"""PeriodArrayWidget class."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
from enum import IntEnum
import os

# 2. Third party modules
import numpy as np
import pandas
import pandas as pd
from PySide2.QtCore import Qt
from PySide2.QtWidgets import (
    QComboBox,
    QHBoxLayout,
    QLabel,
    QLineEdit,
    QPushButton,
    QTabWidget,
    QToolBar,
    QVBoxLayout,
    QWidget,
)

# 3. Aquaveo modules
from xms.core import time
from xms.core.filesystem import filesystem as fs
from xms.coverage.xy import xy_util
from xms.coverage.xy.xy_series import XySeries
from xms.datasets.dataset_reader import DatasetReader, INVALID_REFTIME
from xms.guipy.dialogs import dialog_util, message_box
from xms.guipy.models.qx_pandas_table_model import QxPandasTableModel
from xms.guipy.validators import number_corrector
from xms.guipy.validators.qx_double_validator import QxDoubleValidator
from xms.guipy.validators.qx_int_validator import QxIntValidator
from xms.guipy.widgets import widget_builder

# 4. Local modules
from xms.mf6.data import time_util
from xms.mf6.data.array import Array
from xms.mf6.data.array_layer import ArrayLayer
from xms.mf6.data.grid_info import DisEnum
from xms.mf6.file_io import io_util
from xms.mf6.gui import define_stress_period_dialog, units_util
from xms.mf6.gui import gui_util
from xms.mf6.gui.event_filter import EventFilter
from xms.mf6.gui.resources.svgs import ADD_SVG, DELETE_SVG
from xms.mf6.misc import util


class PeriodsEnum(IntEnum):
    """User's answer when asked which stress periods the 'Dataset to Layer' values go to."""
    ONLY_THIS_PERIOD = 0  # Apply the dataset to the current period only
    ALL_PERIODS = 1  # Apply the dataset to every period in which the array is defined
    CANCEL = 2  # User cancelled; nothing is applied


class PeriodArrayWidget(QWidget):
    """A widget showing the stress period array data, with one tab per array."""
    def __init__(self, data, dlg_input, parent):
        """Initializes the class, sets up the ui.

        Args:
            data: Something derived from ArrayPackageData
            dlg_input (DialogInput): Information needed by the dialog.
            parent (Something derived from QWidget): The parent window.
        """
        super().__init__(parent)

        self._data = data  # Something derived from ArrayPackageData
        self.dlg_input = dlg_input
        self.models = {}  # Array name -> QxPandasTableModel most recently set up for that array's table
        self.loading_stress_period = False  # True while load_period() is populating the widgets
        self.modified_arrays = set()  # Names of arrays modified since the last save_period_to_temp()
        self.actions = {}  # Toolbar actions returned by widget_builder.setup_toolbar()
        self.uix = {}  # Additional (or xtra) widgets created programmatically
        self.ui_tabs = {}  # Dict of dict of widgets in each tab: tab_name -> widget_name -> widget
        self.event_filter = EventFilter(dialog=self)
        self.loaded = False  # Becomes True at the end of setup_ui(); most handlers ignore signals until then
        self.grid_info = self._data.grid_info()
        self._current_period = 1  # Period (1-based) whose edits save_period_to_temp() saves
        self._extrap = False  # If any temporal extrapolation occurred, this will be set to True

        self.setup_ui()
        self.load_period(self._current_period)
        self.setup_signals()

    def setup_ui(self):
        """Sets up the UI: the stress period / toolbar row and one tab per array."""
        # setUpdatesEnabled() only suspends painting; it does not block signals. Handlers are kept quiet while the
        # widgets are built by checking self.loaded instead.
        self.setUpdatesEnabled(False)

        vlayout = QVBoxLayout()
        self.setLayout(vlayout)

        # ---- Stress period spin control row
        hlayout = gui_util.setup_stress_period_spin_box_layout(self, self.uix, self._data.nper(), self.on_spn_period)
        vlayout.addLayout(hlayout)

        # Toolbar
        w = self.uix['toolbar'] = QToolBar(self)
        button_list = [
            [ADD_SVG, 'Define Period', self.on_btn_define_period],
            [DELETE_SVG, 'Delete Period', self.on_btn_delete_period]
        ]
        self.actions = widget_builder.setup_toolbar(w, button_list)
        hlayout.addWidget(w)

        # Period undefined text
        w = self.uix['txt_undefined'] = QLabel('(This period is not defined)')
        w.setStyleSheet('font-weight: bold; color: red')
        hlayout.addWidget(w)
        hlayout.addStretch()
        # -----

        # Tab control (starts on the second tab)
        self.uix['tab_widget'] = QTabWidget()
        self.create_tabs(True)
        self.uix['tab_widget'].setCurrentIndex(1)
        vlayout.addWidget(self.uix['tab_widget'])

        self.loaded = True

        self.setUpdatesEnabled(True)  # Resume painting

    def do_enabling(self):
        """Enables and disables the widgets on the current tab to match its array type and the lock state."""
        tab_index = self.uix['tab_widget'].currentIndex()
        if tab_index < 0:
            return

        tab_name = self.uix['tab_widget'].tabText(tab_index)
        widgets = self.ui_tabs[tab_name]
        text = widgets['cbx'].currentText()
        if text == 'CONSTANT':
            widgets['lbl_constant'].setText('Constant:')
            enable_constant, enable_series, enable_table, enable_dataset_to_layer = True, False, False, True
        elif text == 'ARRAY':
            # For ARRAY storage the edit field shows the array layer's factor
            widgets['lbl_constant'].setText('Factor:')
            enable_constant, enable_series, enable_table, enable_dataset_to_layer = True, False, True, True
        elif text == 'TIME-ARRAY SERIES':
            enable_constant, enable_series, enable_table, enable_dataset_to_layer = False, True, False, False
        else:
            enable_constant, enable_series, enable_table, enable_dataset_to_layer = False, False, False, False

        # Set constant/factor edit field text from the array
        array = self._data.get_array(self.period(), tab_name)
        if array and array.layers:
            array_layer = array.layer(0)
            constant_or_factor = array_layer.constant if text == 'CONSTANT' else array_layer.factor
            if array_layer.numeric_type == 'int':
                widgets['edt_constant'].setText(str(int(constant_or_factor)))
            else:
                widgets['edt_constant'].setText(str(constant_or_factor))

        not_locked = not self.dlg_input.locked

        widgets['cbx'].setEnabled(not_locked)
        widgets['lbl_constant'].setEnabled(not_locked and enable_constant)
        widgets['edt_constant'].setEnabled(not_locked and enable_constant)
        if not self._data.is_layer_indicator(tab_name):  # Layer indicator tabs have no series widgets
            widgets['lbl_series'].setEnabled(not_locked and enable_series)
            widgets['edt_series'].setEnabled(not_locked and enable_series)
        # The table stays enabled when locked so it can still be viewed; setup_model() makes it read-only
        widgets['tbl'].setEnabled(enable_table)
        # 'Dataset to Layer' writes into the arrays, so it must honor the lock
        widgets['btn_dataset_to_layer'].setEnabled(not_locked and enable_dataset_to_layer)

    def setup_signals(self):
        """Sets up any needed signals."""
        self.uix['tab_widget'].currentChanged.connect(self.on_tab_changed)
        # self.on_spn_period was already connected by gui_util.setup_stress_period_spin_box_layout()

    def create_combobox_hlayout(self, page_widgets, array_name):
        """Creates the QHBoxLayout with the combobox and other items.

        Args:
            page_widgets: Dict of widgets on the page: widget_name -> widget. Filled in by this method.
            array_name: Name of array (e.g. 'RECHARGE', 'IRCH' etc.)

        Returns:
            The QHBoxLayout containing the widgets.
        """
        is_indicator = self._data.is_layer_indicator(array_name)
        hlayout = QHBoxLayout()

        # Array type combo box
        cbx = page_widgets['cbx'] = QComboBox()
        cbx.setAccessibleName('Array type')
        cbx.addItems(['UNDEFINED', 'CONSTANT', 'ARRAY', 'TIME-ARRAY SERIES'])
        cbx.currentTextChanged.connect(self.on_cbx)
        hlayout.addWidget(cbx)

        # Constant / factor edit field
        page_widgets['lbl_constant'] = QLabel('Constant:')
        hlayout.addWidget(page_widgets['lbl_constant'])
        edt = page_widgets['edt_constant'] = QLineEdit()
        edt.setAccessibleName('Constant')
        edt.editingFinished.connect(self.on_constant)
        edt.installEventFilter(self.event_filter)
        # Layer indicator arrays hold integers
        if is_indicator:
            validator = QxIntValidator(parent=edt)
        else:
            validator = QxDoubleValidator(parent=edt)
        edt.setValidator(validator)
        gui_util.set_widget_fixed_width(edt)
        hlayout.addWidget(edt)

        if is_indicator:
            hlayout.addStretch()
            return hlayout

        # Time-array series name
        page_widgets['lbl_series'] = QLabel('Series:')
        hlayout.addWidget(page_widgets['lbl_series'])
        page_widgets['edt_series'] = QLineEdit()
        page_widgets['edt_series'].setAccessibleName('Time series')
        page_widgets['edt_series'].editingFinished.connect(self.on_series)
        hlayout.addWidget(page_widgets['edt_series'])

        # Units
        hlayout.addStretch()
        units_str = units_util.string_from_units(self._data, self._data.get_units(array_name))
        page_widgets['lbl_units'] = QLabel(f'Units: {units_str}')
        hlayout.addWidget(page_widgets['lbl_units'])

        return hlayout

    def create_tabs(self, use_aux: bool) -> None:
        """Creates one tab per array and records its widgets in self.ui_tabs.

        Args:
            use_aux: True to include tabs for the AUXILIARY variables.
        """
        for i, array_name in enumerate(self._data.tab_names(use_aux)):
            page = QWidget()
            page.setAccessibleName(f'{array_name} tab')
            vlayout = QVBoxLayout()
            page.setLayout(vlayout)
            page_widgets = {}

            # Combo box row
            hlayout = self.create_combobox_hlayout(page_widgets, array_name)
            vlayout.addLayout(hlayout)

            # Table
            page_widgets['tbl'] = gui_util.new_table_view()
            page_widgets['tbl'].allow_drag_fill = True
            page_widgets['tbl'].setMinimumHeight(150)  # I just think this looks better
            vlayout.addWidget(page_widgets['tbl'])

            self._add_dataset_to_layer_button(vlayout, page_widgets)

            self.uix['tab_widget'].addTab(page, array_name)
            self.uix['tab_widget'].setTabToolTip(i, gui_util.wrap_tip(self._data.get_tool_tip(array_name)))
            self.ui_tabs[array_name] = page_widgets

    def _add_dataset_to_layer_button(self, vlayout: QVBoxLayout, widgets: dict[str, QWidget]) -> None:
        """Adds the 'Dataset to Layer' button.

        Args:
            vlayout: Vertical layout of tab.
            widgets: dict of widget names to widgets.
        """
        hlay_dset = QHBoxLayout()
        vlayout.addLayout(hlay_dset)

        # Dataset to Layer
        w = widgets['btn_dataset_to_layer'] = QPushButton('Dataset to Layer...')
        hlay_dset.addWidget(w)
        w.clicked.connect(self.on_btn_dataset_to_layer)

    def on_btn_dataset_to_layer(self) -> None:
        """Lets user pick an XMS dataset and copies the values to the array."""
        # Prompt user for dataset
        cell_count = self.grid_info.cell_count()
        project_tree = self.dlg_input.query.project_tree
        dataset_uuid, chk_state = gui_util.select_dataset_dialog_check_box(
            self, project_tree, '', cell_count, chk_box_text='Treat as step function', chk_box_state=Qt.Checked
        )
        if not dataset_uuid:
            return

        # Get the dataset (once) and its number of values
        dataset = self.dlg_input.query.item_with_uuid(dataset_uuid)

        # For now, we require both the dataset and TDIS to be able to use datetimes, or neither, but not a mix
        dset_can_do_dates = _dataset_can_do_date_times(dataset)
        if dset_can_do_dates != self._data.mfsim.tdis.can_do_date_times():
            _warn_of_datetime_mismatch(dset_can_do_dates, self.parent())
            return

        num_dset_vals = len(dataset.values[0])

        # Get the array and its shape and size
        array = self.get_current_array()
        shape = array.layer(0).shape
        array_size = shape[0] * shape[1]

        # Make sure size is sufficient
        if num_dset_vals < array_size:
            message_box.message_with_ok(parent=self.parent(), message='Dataset size is too small. Aborting.')
            return

        # If transient, ask if they want to apply to all periods, or just this one
        rv = PeriodsEnum.ONLY_THIS_PERIOD
        if len(self._data.period_data) > 1:
            rv = _ask_which_periods(self.parent())
            if rv == PeriodsEnum.CANCEL:
                return

        # Get list of periods to apply dataset to (must be at least one, or the button couldn't have been clicked).
        tab_name = self.current_tab_name()
        periods = [self._current_period] if rv == PeriodsEnum.ONLY_THIS_PERIOD else self._defined_periods(tab_name)

        # Apply the dataset to the periods
        self._dataset_to_layer(dataset, periods, step_function=chk_state == Qt.Checked)

        # Update the dialog
        self.load_period(self._current_period)
        self.do_enabling()

        # Tell user it worked, and if we had to extrapolate
        msg = 'Dataset applied.'
        if self._extrap:
            msg += (
                '\n\nDataset time steps do not cover the range of times specified by the MODFLOW stress periods.'
                ' Values were extrapolated from the information that is available.'
            )
        message_box.message_with_ok(parent=self.parent(), message=msg)

    def _defined_periods(self, tab_name: str) -> list[int]:
        """Return the periods in which the given array exists and has at least one layer.

        Args:
            tab_name: Name of the array.

        Returns:
            See description.
        """
        defined_periods = []
        for period, block in self._data.period_data.items():
            array = block.array(tab_name) if block else None
            if array and array.layers:
                defined_periods.append(period)
        return defined_periods

    def _dataset_to_layer(self, dset: DatasetReader, periods: list[int], step_function: bool) -> None:
        """Applies the dataset values to the array.

        Args:
            dset: The dataset.
            periods: List of periods to apply the dataset to.
            step_function: If True, series will be created as a step function.
        """
        # See if we should deal with datetimes. Both the dataset and tdis need to be in datetime. We checked for this
        # already.
        tdis = self._data.mfsim.tdis
        assert _dataset_can_do_date_times(dset) == tdis.can_do_date_times()
        dset_as_dates = _dataset_can_do_date_times(dset)

        # Get period times
        period_times_df = tdis.get_period_times(dset_as_dates)
        period_times = period_times_df['Time'].tolist()

        # Get array size and shape
        tab_name = self.current_tab_name()
        array = self._data.period_data[periods[0]].array(tab_name)
        shape = array.layer(0).shape
        array_size = shape[0] * shape[1]

        # Create XySeries from dataset
        dset_xy_series = _xy_series_from_dataset(dset, array_size, step_function)

        # Iterate over the requested periods
        data_type = 'int' if self._data.is_layer_indicator(tab_name) else 'float'
        period_data = self._data.period_data  # for short
        self._extrap = False
        for period in periods:
            array = period_data[period].array(tab_name)
            extrap = _dataset_to_layer_for_period(period, period_times, dset_xy_series, data_type, array)
            if extrap:
                self._extrap = True

        # Don't report extrapolation if dset is steady state.
        if len(dset.times) == 1:
            self._extrap = False

        self.modified_arrays.add(tab_name)

    def period(self) -> int:
        """Returns current stress period in spin control. Shorter for convenience."""
        return self.uix['spn_period'].value()

    def previous_defined_period(self):
        """Returns the first defined stress period before the current one, or -1 if there is none."""
        for sp in range(self.period() - 1, 0, -1):
            if self.is_period_defined(sp):
                return sp
        return -1

    def next_defined_period(self):
        """Returns the first defined stress period after the current one, or -1 if there is none."""
        # Periods are 1-based, so nper itself must be included
        for sp in range(self.period() + 1, self._data.nper() + 1):
            if self.is_period_defined(sp):
                return sp
        return -1

    def copy_period(self, from_sp: int, to_sp: int):
        """Copies a stress period.

        Args:
            from_sp: The stress period being copied.
            to_sp: The stress period being copied to.
        """
        self._data.copy_period(from_sp, to_sp)

    def on_btn_define_period(self) -> None:
        """Allows the user to define data for the current stress period."""
        period = self.period()
        prev_sp = self.previous_defined_period()
        next_sp = self.next_defined_period()
        if prev_sp == -1 and next_sp == -1:
            # No previous or next stress period to copy. Create one.
            self._data.add_period(period)
        else:
            # rv: -1 = cancelled, 0 = copy previous, 1 = copy next, anything else = new period
            rv = define_stress_period_dialog.run(prev_sp, next_sp, self)
            if rv != -1:
                if rv == 0:
                    self.copy_period(prev_sp, period)
                elif rv == 1:
                    self.copy_period(next_sp, period)
                else:
                    self._data.add_period(period)
        self.load_period(period)

    def on_btn_delete_period(self):
        """Deletes the current stress period after asking the user to confirm."""
        rv = message_box.message_with_n_buttons(
            parent=self,
            message='Delete this period?',
            button_list=['Delete Period', 'Cancel'],
            default=1,
            escape=1,
            icon='Warning',
            win_icon=util.get_app_icon()
        )
        if rv == 0:
            del self._data.period_data[self.period()]
            # Pending edits belonged to the deleted period; don't let them be saved into a redefined period
            self.modified_arrays.clear()
            self.load_period(self.period())

    def on_spn_period(self, period: int) -> None:
        """Saves any changes to the current stress period and loads the next stress period.

        Args:
            period: Stress period number (1-based).
        """
        if not self.loaded:
            return

        self.save_period_to_temp()  # Save the old period
        self._current_period = period
        self.load_period(period)
        self.do_enabling()

    def on_tab_changed(self):
        """Called when the current tab changes."""
        if self.loaded:  # This is here to stop it while we're building the tabs
            self.do_enabling()

    def check_aux_change(self, block: str, new_aux: list[str]) -> str:
        """Return an error message if the new aux list is not valid, or '' if it is.

        PyCharm says this method could be a function, but it is kept as a method so this widget's interface matches
        the list-block widget (see change_aux_variables()).

        Args:
            block: Name of the list block.
            new_aux: New aux variable list.

        Returns:
            See description.
        """
        tab_names = self._data.tab_names(use_aux=False)
        msg = ''
        if any(aux in tab_names for aux in new_aux):
            msg = 'Auxiliary variables cannot have the same name as a standard array.'
        return msg

    def change_aux_variables(self, block: str, use_aux: bool):
        """Updates AUX in all stress periods if changes were made.

        Args:
            block: This is only here to match ListBlockTableWidget.change_aux_variables()
            use_aux: True to include AUXILIARY variables.
        """
        # save any edits to temp
        self.save_period_to_temp()

        # See if there are changes
        tab_widget = self.uix['tab_widget']
        tab_names = self._get_tab_names()
        data_tab_names = self._data.tab_names(use_aux)
        if tab_names == data_tab_names:
            return  # No change

        # Remove, from every period, the data of arrays that are no longer in the list
        for array_name in tab_names:
            if array_name in data_tab_names:
                continue
            for period_block in self._data.period_data.values():
                if array_name not in period_block.names:
                    continue
                temp = period_block.array(array_name).temp_external_filename
                if temp:
                    fname = fs.resolve_relative_path(os.path.dirname(self._data.filename), temp)
                    if os.path.isfile(fname):  # Don't fail on a file that is already gone
                        os.remove(fname)
                period_block.delete_array(array_name)
            # Forget the widgets and model of the tab being removed
            self.ui_tabs.pop(array_name, None)
            self.models.pop(array_name, None)

        # delete all the tabs
        current_tab = tab_widget.currentIndex()
        tab_widget.clear()

        # create the tabs, restoring the previous tab position if it still exists
        self.create_tabs(use_aux)
        self.load_period(self.period())
        tab_widget.setCurrentIndex(min(current_tab, tab_widget.count() - 1))

    def _get_tab_names(self) -> list[str]:
        """Return the tab text."""
        tab_widget = self.uix['tab_widget']
        tab_names = [tab_widget.tabText(tab_index) for tab_index in range(tab_widget.count())]
        return tab_names

    def current_tab_name(self) -> str:
        """Returns the name of the current tab."""
        tab_index = self.uix['tab_widget'].currentIndex()
        tab_name = self.uix['tab_widget'].tabText(tab_index)
        return tab_name

    def get_current_array(self) -> Array:
        """Returns the Array class for the current stress period and tab.

        Returns:
            (Array): The Array class.
        """
        tab_name = self.current_tab_name()
        sp = self.period()
        array = self._data.get_array(sp, tab_name)
        return array

    def on_constant(self):
        """Called when finished editing the constant/factor field, and by the event filter."""
        if not self.loaded or self.loading_stress_period:
            return

        tab_name = self.current_tab_name()
        widgets = self.ui_tabs[tab_name]
        try:
            value = float(widgets['edt_constant'].text())
        except ValueError:
            # Empty or incomplete entry: do_enabling() puts the array's current value back in the edit field
            self.do_enabling()
            return

        array_layer = self.get_current_array().layer(0)
        if widgets['cbx'].currentText() == 'CONSTANT':
            array_layer.constant = value
        else:
            array_layer.factor = value

    def on_series(self):
        """Called when finished editing the time-array series field."""
        if not self.loaded or self.loading_stress_period:
            return
        array = self.get_current_array()
        array.time_array_series = self.ui_tabs[self.current_tab_name()]['edt_series'].text()

    def on_cbx(self, text: str) -> None:
        """Called when a combo box changes.

        Args:
            text (str): The new text.
        """
        if not self.loaded or self.loading_stress_period:
            return

        # Get the array for current tab, creating it if necessary
        tab_name = self.current_tab_name()
        sp = self.period()
        is_indicator = self._data.is_layer_indicator(tab_name)
        array = self._data.get_array(sp, tab_name)
        if not array:
            array = self._data.add_transient_array(sp, tab_name)

        if text == 'ARRAY':
            # Create an array_layer if necessary, and give a brand-new one default values
            array_layer = array.ensure_layer_exists(make_int=is_indicator, name=tab_name)
            if not array_layer.external_filename and not array_layer.temp_external_filename:
                _, shape = ArrayLayer.number_of_values_and_shape(layered=True, grid_info=self.grid_info)
                array_layer.shape = shape
                narray = np.ones(shape, dtype=int) if is_indicator else np.zeros(shape)
                data_frame = pandas.DataFrame(data=narray)
                data_frame.index += 1  # 1-based row/column labels
                data_frame.columns += 1
                self.setup_model(data_frame, tab_name)
                self.modified_arrays.add(tab_name)
        elif text == 'CONSTANT':
            array.ensure_layer_exists(make_int=is_indicator, name=tab_name)
        elif text == 'TIME-ARRAY SERIES':
            if not self._data.array_supports_time_array_series(tab_name):
                message_box.message_with_ok(
                    parent=self, message=f'Time-array series not supported for array: {tab_name}'
                )
                # Re-enters on_cbx() with 'CONSTANT', which updates the array and the enabling
                self.ui_tabs[tab_name]['cbx'].setCurrentText('CONSTANT')
                return
            self.ui_tabs[tab_name]['edt_series'].setText(array.time_array_series)

        array.storage = text

        self.do_enabling()

    def setup_model(self, data_frame, array_name):
        """Creates a model and connects it to the table view and some signals.

        Args:
            data_frame (pandas.DataFrame): The DataFrame.
            array_name (str): The name of the array ('RECHARGE' etc)
        """
        model = self.models[array_name] = QxPandasTableModel(data_frame)
        if self.dlg_input.locked:
            # Locked: every column is read-only
            model.set_read_only_columns(set(range(model.columnCount())))
        model.dataChanged.connect(self.on_data_changed)
        self.ui_tabs[array_name]['tbl'].setModel(model)

    def is_period_defined(self, period):
        """Returns true if the stress period is defined in the package data.

        Args:
            period (int): Stress period number (1-based).

        Returns:
             (bool): True if stress period is defined, else False.
        """
        if not self._data:
            return False
        return period in self._data.period_data

    def read_external_array(self, external_filename: str, numeric_type: str) -> pd.DataFrame:
        """Reads the array in the external file and returns a pandas dataframe.

        Values are read as whitespace-separated words, filling the array row by row.

        Args:
            external_filename: The full path to the external file.
            numeric_type: 'int' or 'float'.

        Returns:
            (pandas.DataFrame): pandas DataFrame of the array, with 1-based row and column labels.
        """
        caster = io_util.type_caster_from_string(numeric_type)
        dtype = float if numeric_type == 'float' else int

        # Zero-filled so a short file leaves zeros rather than uninitialized values
        if self.grid_info.dis_enum == DisEnum.DIS:
            narray = np.zeros((self.grid_info.nrow, self.grid_info.ncol), dtype=dtype)
            ncol = self.grid_info.ncol
        else:  # Must be DISV
            narray = np.zeros((self.grid_info.ncpl, 1), dtype=dtype)
            ncol = 1

        row = 0
        col = 0
        with open(external_filename, 'r') as external_file:
            for line in external_file:
                for word in line.split():
                    narray[row, col] = caster(word)
                    col += 1
                    if col == ncol:
                        col = 0
                        row += 1

        data_frame = pandas.DataFrame(data=narray)
        data_frame.columns += 1
        data_frame.index += 1
        return data_frame

    def load_period(self, period: int) -> None:
        """Loads the stress period into the tabs.

        If the stress period is not defined, disables the tabs, shows the '(This period is not defined)' label and
        enables the 'Define Period' action (when not locked).

        Args:
            period: Stress period number (1-based).
        """
        # Keeps on_cbx() etc. from reacting to the widget changes made here. Always reset, even on early return/error.
        self.loading_stress_period = True
        try:
            # Enable/disable tabs and the "Define Period"/"Delete Period" actions
            not_locked = not self.dlg_input.locked
            defined = self.is_period_defined(period)
            self.uix['tab_widget'].setEnabled(defined)
            self.actions[':/resources/icons/add.svg'].setEnabled(not_locked and not defined)
            self.actions[':/resources/icons/delete.svg'].setEnabled(not_locked and defined)

            # Stop here if the stress period is not defined
            self.uix['txt_undefined'].setVisible(not defined)
            if not defined:
                return

            for tab_index in range(self.uix['tab_widget'].count()):
                self._load_tab(period, self.uix['tab_widget'].tabText(tab_index))
        finally:
            self.loading_stress_period = False
        self.do_enabling()

    def _load_tab(self, period: int, tab_name: str) -> None:
        """Shows the array of one tab for the given period in that tab's widgets.

        Args:
            period: Stress period number (1-based).
            tab_name: Name of the tab (array).
        """
        widgets = self.ui_tabs[tab_name]
        array = self._data.get_array(period, tab_name)
        widgets['tbl'].setModel(None)  # Clear the table
        if not array:
            widgets['cbx'].setCurrentText('UNDEFINED')
            return

        storage = array.layer(0).storage if array.layers else array.storage
        if storage == 'CONSTANT':
            widgets['cbx'].setCurrentText('CONSTANT')
        elif storage == 'ARRAY':
            widgets['cbx'].setCurrentText('ARRAY')
            with dialog_util.wait_cursor_context():
                data_frame = self._read_array_layer(array.layer(0))
            self.setup_model(data_frame, tab_name)
        elif storage == 'TIME-ARRAY SERIES':
            widgets['cbx'].setCurrentText('TIME-ARRAY SERIES')
            widgets['edt_series'].setText(array.time_array_series)
        else:
            # Don't leave the previous period's choice showing in the combo box
            widgets['cbx'].setCurrentText('UNDEFINED')

    def _read_array_layer(self, array_layer: ArrayLayer) -> pd.DataFrame:
        """Reads an array layer's values, preferring the temp file over the external file.

        Args:
            array_layer: The array layer.

        Returns:
            The values as a DataFrame.
        """
        if array_layer.temp_external_filename:
            # Temp files are read as CSV with columns named '1'..'ncol' ('1' only for non-DIS grids)
            if self.grid_info.dis_enum == DisEnum.DIS:
                column_names = [str(col + 1) for col in range(self.grid_info.ncol)]
            else:
                column_names = ['1']
            return gui_util.read_csv_file_into_dataframe(array_layer.temp_external_filename, column_names, None)
        return self.read_external_array(array_layer.external_filename, array_layer.numeric_type)

    def on_data_changed(self, top_left_index, bottom_right_index):
        """Called when the data in the table view has changed.

        Args:
            top_left_index (QModelIndex): Top left index.
            bottom_right_index (QModelIndex): Bottom right index.
        """
        del top_left_index, bottom_right_index  # Unused parameters
        self.modified_arrays.add(self.current_tab_name())

    def save_period_to_temp(self):
        """Saves the table data of modified arrays in the current period to new temporary csv files."""
        period = self._current_period
        with dialog_util.wait_cursor_context():
            for array_name in self.modified_arrays:
                array = self._data.get_array(period, array_name)
                model = self.models.get(array_name)
                # Without a model there is nothing to write; keep the existing temp file so the layer stays valid
                if not array or not array.layers or model is None:
                    continue

                array_layer = array.layer(0)
                old_temp = array_layer.temp_external_filename
                array_layer.temp_external_filename = gui_util.dataframe_to_temp_file(model.data_frame)

                # Remove the superseded temp file only after its replacement exists
                if old_temp and old_temp != array_layer.temp_external_filename:
                    os.remove(old_temp)

        self.modified_arrays.clear()

    def clean_up_temp_files(self):
        """Deletes the temp files of all ARRAY-storage arrays in all periods."""
        with dialog_util.wait_cursor_context():
            for block in self._data.period_data.values():
                for array in block.arrays:
                    if array.storage != 'ARRAY':
                        continue
                    temp_filename = array.layer(0).temp_external_filename
                    # Skip files that are already gone so cleanup (e.g. on Cancel) cannot fail
                    if temp_filename and os.path.isfile(temp_filename):
                        os.remove(temp_filename)

    def accept(self):
        """Called when user hits the OK button. Saves any pending edits of the current period to temp files."""
        if not self.dlg_input.locked:
            self.save_period_to_temp()

    def reject(self):
        """Called when the user clicks Cancel."""
        self.clean_up_temp_files()


def _dataset_to_layer_for_period(
    period: int, period_times: list, dset_series: list[XySeries], data_type: str, array: Array
) -> bool:
    """Fill the array's layer with the per-cell average of the dataset over one stress period.

    Args:
        period: Stress period number (1-based).
        period_times: Times bounding the periods (floats or datetimes); period N spans entries N-1 to N.
        dset_series: Time series for each cell, created from the dataset.
        data_type: 'int' (values are rounded to int) or anything else (values are kept as floats).
        array: The array whose layer receives the values.

    Returns:
        True if time extrapolation was needed for at least one cell.
    """
    # Period bounds; datetimes are converted to Julian so they are numeric like the series x values.
    begin, finish = period_times[period - 1], period_times[period]
    if not util.is_number(begin):
        begin, finish = time.datetime_to_julian(begin), time.datetime_to_julian(finish)

    layer_shape = array.layer(0).shape
    cell_count = layer_shape[0] * layer_shape[1]
    as_int = data_type == 'int'

    extrapolated = False
    values = []
    for idx in range(cell_count):
        avg, used_extrap = xy_util.average_y_from_x_range(dset_series[idx], begin, finish)
        # Ints are rounded; floats go through format_float to trim 8-byte float "precision reveal" noise.
        values.append(int(round(avg)) if as_int else float(number_corrector.format_float(avg)))
        extrapolated = extrapolated or bool(used_extrap)

    # Dataset activity is intentionally ignored here.
    array.dataset_to_layer(values, 1, layer_shape)
    array.dump_to_temp_files()
    return extrapolated


def _dataset_can_do_date_times(dset: DatasetReader) -> bool:
    """Report whether the dataset's times can be expressed as datetimes.

    That requires a truthy ref_time that is not INVALID_REFTIME, plus truthy time_units that are not the string 'None'.

    Args:
        dset: The dataset.

    Returns:
        True when both conditions hold, otherwise False.
    """
    has_ref_time = dset.ref_time and dset.ref_time != INVALID_REFTIME
    if not has_ref_time:
        return False
    has_time_units = dset.time_units and dset.time_units != 'None'
    return bool(has_time_units)


def _xy_series_from_dataset(dset: DatasetReader, first_n_cells: int, step_function: bool) -> list[XySeries]:
    """Return one XySeries for each of the first n cells of the dataset.

    If the dataset has a valid ref_time and time_units (see _dataset_can_do_date_times), the time step times are
    converted to Julian.

    Args:
        dset: The dataset.
        first_n_cells: How many cells to create XySeries for.
        step_function: If True, each series is created as a step function.

    Returns:
        A list of XySeries, one per cell, with series_id set to the 1-based cell number.
    """
    # Get x vals (times). They are the same for every cell / series.
    x_vals = list(dset.times)
    ts_count = len(x_vals)

    # Convert times to Julian if there's a ref_time
    if _dataset_can_do_date_times(dset):
        time_units = dset.time_units.upper()
        x_vals = [
            time.datetime_to_julian(time_util.compute_end_date_py(dset.ref_time, ts_time, time_units))
            for ts_time in x_vals
        ]

    # Read the first n values of every time step into RAM
    first_n_values_all_ts = [dset.values[ts_idx][:first_n_cells] for ts_idx in range(ts_count)]

    # Create an XySeries for each cell
    xy_series = []
    for cell_idx in range(first_n_cells):
        y_vals = [float(number_corrector.format_float(ts_values[cell_idx])) for ts_values in first_n_values_all_ts]
        cell_x_vals = x_vals
        if step_function:
            # Use per-cell names. Reassigning x_vals here would feed the already-stepped times into the next
            # cell's step function, giving it x values that no longer match its y values.
            cell_x_vals, y_vals = xy_util.get_step_function(x_vals, y_vals)
        xy_series.append(XySeries(cell_x_vals, y_vals, series_id=cell_idx + 1))
    return xy_series


def _warn_of_datetime_mismatch(dataset_can_do_date_times: bool, win_cont: QWidget) -> None:
    """Show an OK message explaining that the dataset and TDIS disagree about using datetimes.

    A message is always shown; presumably the caller has already determined that a mismatch exists. A dataset can
    use datetimes if it has a ref_time and time_units; TDIS can use datetimes if it has TIME_UNITS and START_TIME.
    The user is required to make the two consistent rather than one side being silently ignored.

    Args:
        dataset_can_do_date_times: True if the dataset side can use datetimes (so TDIS is the side that can't).
        win_cont: The window container used as the message box parent.
    """
    dataset_has_ref_msg = (
        'Dataset has a reference time, but TDIS package does not (START_TIME). Either disable the'
        ' reference time for the dataset, or turn on the START_TIME option in TDIS.'
    )
    tdis_has_ref_msg = (
        'Dataset does not have a reference time, but TDIS package does (START_TIME). Either define a'
        ' reference time for the dataset, or turn off the START_TIME option in TDIS.'
    )
    chosen = dataset_has_ref_msg if dataset_can_do_date_times else tdis_has_ref_msg
    message_box.message_with_ok(parent=win_cont, message=chosen)


def _ask_which_periods(win_cont: QWidget) -> PeriodsEnum:
    """Prompt the user for which stress periods the dataset should be applied to.

    Args:
        win_cont: The window container used as the message box parent.

    Returns:
        The value returned by message_box.message_with_n_buttons, compared by callers against PeriodsEnum. The button
        order below matches PeriodsEnum (0 = only this period, 1 = all periods, 2 = cancel).
    """
    button_labels = ['Only this period', 'All defined periods', 'Cancel']
    return message_box.message_with_n_buttons(
        win_cont,
        'Apply dataset to which stress periods?',
        'GMS',
        button_list=button_labels,
        default=int(PeriodsEnum.ONLY_THIS_PERIOD),
        escape=int(PeriodsEnum.CANCEL),
    )
