"""PeriodListWidget class."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
import csv
import os
from pathlib import Path
import shlex
import sys
from typing import Any

# 2. Third party modules
import numpy as np
import pandas as pd
from PySide2.QtCore import QEvent, QItemSelectionModel, QModelIndex, QObject, QPoint, Qt
from PySide2.QtGui import QIcon
from PySide2.QtSql import QSqlDatabase, QSqlQuery, QSqlRecord, QSqlTableModel
from PySide2.QtWidgets import QAbstractButton, QDialog, QHBoxLayout, QLabel, QToolBar, QVBoxLayout, QWidget

# 3. Aquaveo modules
from xms.api.tree import tree_util
from xms.core.filesystem import filesystem as fs
from xms.guipy.delegates.edit_field_validator import EditFieldValidator
from xms.guipy.delegates.qx_cbx_delegate import QxCbxDelegate
from xms.guipy.dialogs import dialog_util, message_box
from xms.guipy.dialogs.xy_series_editor import XySeriesEditor
from xms.guipy.models.qx_sql_table_model import QxSqlTableModel
from xms.guipy.resources import resources_util
from xms.guipy.validators.qx_double_validator import QxDoubleValidator
from xms.guipy.widgets import widget_builder
from xms.guipy.widgets.widget_builder import ActionRec

# 4. Local modules
from xms.mf6.data.grid_info import DisEnum
from xms.mf6.data.mfcellid import MfCellId
from xms.mf6.file_io import database, io_util
from xms.mf6.file_io.qx_queries import QxQueries
from xms.mf6.gui import copy_period_dialog, define_stress_period_dialog
from xms.mf6.gui import gui_util
from xms.mf6.gui.add_stresses_dialog import AddStressesData, AddStressesDialog, WhichPeriods
from xms.mf6.gui.delegates.table_text_range_color_delegate import TableTextRangeColorDelegate
from xms.mf6.gui.dialog_input import DialogInput
from xms.mf6.gui.resources.svgs import (
    ADD_DIS_SVG, ADD_SVG, COPY_DIS_SVG, COPY_SVG, DELETE_DIS_SVG, DELETE_SVG, DOWNLOAD_DIS_SVG, DOWNLOAD_SVG,
    FILTER_SVG, OPEN_DIS_SVG, OPEN_SVG, PASTE_SVG, PLOT_DIS_SVG, PLOT_SVG, ROW_ADD_DIS_SVG, ROW_ADD_SVG,
    ROW_DELETE_DIS_SVG, ROW_DELETE_SVG, ROW_INSERT_SVG, TO_COLUMN_DIS_SVG, TO_COLUMN_SVG, UPLOAD_DIS_SVG, UPLOAD_SVG
)
from xms.mf6.misc import util

# Constants
XY_UNEDITABLE = 'XY_UNEDITABLE'  # Special text in y value telling XySeriesEditor that the period is not defined

# Columns in df_orig
DF_ORIG_COL_PERIOD = 0  # 'PERIOD' column
DF_ORIG_COL_ID = 2  # 'id' (primary key) column
DF_ORIG_COL_Y = 3  # value (y) column

# Columns in df2
DF2_COL_Y = 1


class CtrlAWatcher(QObject):
    """Event filter that runs a callback whenever CTRL+A is pressed in the watched object."""
    def __init__(self, func):
        """Store the callback.

        Args:
            func (callable): Function taking no arguments, run on every CTRL+A key press.
        """
        # Deliberately parentless: giving this object a Qt parent leads to double deletes.
        super().__init__(None)
        self.func = func

    def eventFilter(self, receiver, event):  # noqa N802
        """Run the callback on a CTRL+A key press without consuming the event.

        Args:
            receiver (QObject): Unused
            event (QEvent): The event

        Returns:
            (bool): False, always.
        """
        # Only key press events are asked for their modifiers and key.
        if event.type() != QEvent.KeyPress:
            return False
        is_ctrl_a = event.modifiers() == Qt.ControlModifier and event.key() == Qt.Key_A
        if is_ctrl_a:
            self.func()
        return False


class PeriodListWidget(QWidget):
    """A widget showing the stress period data in a table."""
    def __init__(self, data, dlg_input: DialogInput, parent: QWidget):
        """Stores the inputs, initializes member state, then builds the ui, database and model.

        After everything is built, period 1 is loaded and the enabled state of the controls is updated.

        Args:
            data: The ListPackageData derived package.
            dlg_input: Information needed by the dialog.
            parent: The parent window.
        """
        super().__init__(parent)

        # Inputs
        self._data = data  # Something derived from xms.mf6.ListPackageData class
        self.dlg_input = dlg_input

        # Widgets, actions and model
        self.uix = {}  # Additional (or xtra) widgets created programmatically
        self.actions = {}  # Filled from widget_builder.setup_toolbar() when the toolbars are built
        self.model = None  # Table model (not created yet)

        # Database
        self.db_filename = database.database_filepath(self._data.filename)
        self._queries: QxQueries | None = None

        # Bookkeeping
        self.data_changed = False  # Flag set when a change is made in the table
        self.time_series_lookup: dict[str, str] = {}  # Dict of time series names and their files
        self._period_times: pd.DataFrame | None = None  # Period start times (the last one is an end time)
        self._last_csv_file = ''  # Last file picked in the CSV import/export dialogs
        self._appending_rows = False

        # Build everything and show the first period
        self.setup_ui()
        self.setup_db()
        self.setup_model()
        self.load_period(1)
        self.do_enabling()

    def setup_ui(self):
        """Builds the widget's layout: period controls, filter text, table and table toolbar."""
        # Updates are switched off while the child widgets are created. This was meant to
        # avoid signals, but doesn't seem to work.
        self.setUpdatesEnabled(False)

        main_layout = QVBoxLayout()
        self.setLayout(main_layout)
        section_builders = (
            self._set_up_stress_period_hlay,
            self._set_up_filter_text,
            self._set_up_table,
            self._set_up_table_toolbar,
        )
        for build_section in section_builders:
            build_section(main_layout)

        self.setUpdatesEnabled(True)

    def _set_up_filter_text(self, vlayout: QVBoxLayout) -> None:
        """Create the initially hidden label that reports how many selected cells are filtered on.

        Args:
            vlayout: The main vertical layout.
        """
        count = len(self.dlg_input.selected_cells)
        noun = 'cell' if count <= 1 else 'cells'

        label = QLabel(f'Filtering on {count} selected {noun}')
        label.setStyleSheet('font-weight: bold; color: red')
        label.setVisible(False)  # This is made visible in do_enabling()

        self.uix['txt_filtering'] = label
        vlayout.addWidget(label)

    def _set_up_stress_period_hlay(self, vlayout: QVBoxLayout) -> None:
        """Builds the row holding the stress period spin box, its toolbar and the 'not defined' label.

        Args:
            vlayout: The main vertical layout.
        """
        period_layout = gui_util.setup_stress_period_spin_box_layout(
            self, self.uix, self._data.nper(), self.on_spn_period
        )
        vlayout.addLayout(period_layout)

        toolbar = QToolBar()
        self.uix['sp_toolbar'] = toolbar
        period_layout.addWidget(toolbar)

        # Every possible toolbar button
        define_rec = ActionRec((ADD_SVG, ADD_DIS_SVG), 'Define Period', self.on_btn_define_period)
        delete_rec = ActionRec((DELETE_SVG, DELETE_DIS_SVG), 'Delete Period', self.on_btn_delete_period)
        copy_rec = ActionRec((COPY_SVG, COPY_DIS_SVG), 'Copy Period', self.on_btn_copy_period)
        import_rec = ActionRec((DOWNLOAD_SVG, DOWNLOAD_DIS_SVG), 'Import CSV File', self.on_btn_import_csv_file)
        export_rec = ActionRec((UPLOAD_SVG, UPLOAD_DIS_SVG), 'Export CSV File', self.on_btn_export_csv_file)
        filter_rec = ActionRec((FILTER_SVG, FILTER_SVG), 'Filter on Selected Cells', self.on_btn_filter, True)
        column_rec = ActionRec((TO_COLUMN_SVG, TO_COLUMN_DIS_SVG), 'Apply Dataset to Column', self._on_tool_set_column)
        separator = ActionRec()
        # Buttons after the separator are enabled/disabled based on the selected table cell
        plot_rec = ActionRec((PLOT_SVG, PLOT_DIS_SVG), 'Plot All Periods', self.on_btn_plot)
        open_ts_rec = ActionRec((OPEN_SVG, OPEN_DIS_SVG), 'Open Time Series', self.on_btn_open_time_series)

        # Pick the buttons that apply; the filter and time series buttons are optional
        buttons = [define_rec, delete_rec, copy_rec, import_rec, export_rec]
        if self._include_filter_icon():
            buttons.append(filter_rec)
        buttons += [column_rec, separator, plot_rec]
        if self._data.get_time_series_columns():
            buttons.append(open_ts_rec)
        self.actions.update(widget_builder.setup_toolbar(toolbar, buttons))

        # The filter button (when present) starts out checked if we are filtering
        filtering = self._can_filter_on_selected_cells()
        if FILTER_SVG in self.actions:
            self.actions[FILTER_SVG].setChecked(filtering)

        # Label telling the user that the period is undefined
        undefined_label = QLabel('(This period is not defined)')
        self.uix['txt_undefined'] = undefined_label
        undefined_label.setStyleSheet('font-weight: bold; color: red')
        period_layout.addWidget(undefined_label)
        period_layout.addStretch()

    def _can_filter_on_selected_cells(self) -> bool:
        """Return True if we should filter on selected cells."""
        return self.dlg_input.filter_on_selected_cells and bool(self.dlg_input.selected_cells)

    def _include_filter_icon(self) -> bool:
        """Return True if the filter icon should be included in the toolbar."""
        if 'PERIODS' == self._data.block_with_cellids and self._can_filter_on_selected_cells():
            return True
        return False

    def _set_up_table(self, vlayout: QVBoxLayout):
        """Creates the table view, its context menus, combo box delegates and select-all hooks.

        Args:
            vlayout: The main vertical layout.
        """
        table = gui_util.new_table_view()
        self.uix['tbl'] = table
        table.allow_drag_fill = True
        vlayout.addWidget(table)

        # Context menus are only hooked up when the dialog is not locked
        if not self.dlg_input.locked:
            gui_util.set_vertical_header_menu_method(table, self.on_index_column_click)
            gui_util.set_horizontal_header_menu_method(table, self._on_horizontal_header_click)
            gui_util.set_table_menu_method(table, self.on_right_click)

        # Combo box delegates; each entry holds the column at [0] and the combo box strings at [1]
        for entry in self._data.get_column_delegate_info('') or ():
            combo_delegate = QxCbxDelegate(self)
            combo_delegate.set_strings(entry[1])
            table.setItemDelegateForColumn(entry[0], combo_delegate)

        # The top-left corner header button selects everything
        corner_button = table.findChild(QAbstractButton)
        if corner_button:
            corner_button.clicked.connect(self.on_select_all)

        # So does CTRL+A, which is caught by an event filter
        table.installEventFilter(CtrlAWatcher(func=self.on_select_all))

    def _set_up_table_toolbar(self, vlayout) -> None:
        """Adds a toolbar holding the 'Add Rows' and 'Delete Rows' buttons.

        Args:
            vlayout: The main vertical layout.
        """
        toolbar_row = QHBoxLayout()
        vlayout.addLayout(toolbar_row)

        toolbar = QToolBar()
        self.uix['table_toolbar'] = toolbar
        toolbar_row.addWidget(toolbar)

        row_buttons = [
            ActionRec((ROW_ADD_SVG, ROW_ADD_DIS_SVG), 'Add Rows', self.on_btn_add_rows),
            ActionRec((ROW_DELETE_SVG, ROW_DELETE_DIS_SVG), 'Delete Rows', self.on_btn_delete_rows),
        ]
        self.actions.update(widget_builder.setup_toolbar(toolbar, row_buttons))

    def setup_db(self):
        """Builds the database, opens a SQLite connection to it and begins a transaction.

        The connection is named after the database file, and the queries object is created on it.
        """
        db_path = database.build(self._data)
        self.db_filename = db_path

        connection = QSqlDatabase.addDatabase('QSQLITE', db_path)
        connection.setDatabaseName(db_path)
        connection.open()

        begin_query = QSqlQuery(connection)
        begin_query.exec_('BEGIN TRANSACTION')
        self._queries = QxQueries(connection, self)

    def period(self):
        """Returns the stress period currently shown in the period spin box."""
        spin_box = self.uix['spn_period']
        return spin_box.value()

    def on_right_click(self, point: QPoint):
        """Shows the Copy/Paste context menu when the user right-clicks in the table.

        Args:
            point: The point clicked.
        """
        table = self.uix['tbl']
        entries = [
            [COPY_SVG, 'Copy', table.on_copy],
            [PASTE_SVG, 'Paste', table.on_paste],
        ]
        context_menu = widget_builder.setup_context_menu(self, entries)
        context_menu.popup(table.viewport().mapToGlobal(point))

    def _fix_pasted_data(self):
        """After pasting, fixes the pasted data to have the right PERIOD and PERIOD_ROW info."""
        self._set_period_column(self.period())
        self._renumber_period_row_column()

    def _set_period_column(self, period):
        """Writes period into the PERIOD column of every row currently in the model.

        Args:
            period (int): The period to store.
        """
        model = self.model
        row_total = model.rowCount()
        column = self.period_column()
        for row in range(row_total):
            model.setData(model.createIndex(row, column), period, Qt.EditRole)

    def _on_horizontal_header_click(self, point: QPoint) -> None:
        """Selects the right-clicked column and, for time series columns, shows a context menu.

        The menu is only shown when the cellids are in the PERIODS block and the clicked column is one
        of the time series columns.

        Args:
            point: The point clicked
        """
        table = self.uix['tbl']
        column = table.horizontalHeader().logicalIndexAt(point)
        table.selectColumn(column)

        cellids_in_periods = self._data.block_with_cellids == 'PERIODS'
        if not (cellids_in_periods and column in self._data.get_time_series_columns()):
            return  # Nothing to offer for this column

        entries = [[TO_COLUMN_SVG, 'Apply Dataset to Column', self._on_menu_set_column]]
        context_menu = widget_builder.setup_context_menu(self, entries)
        context_menu.popup(table.viewport().mapToGlobal(point))

    def on_index_column_click(self, point):
        """Selects the right-clicked row and shows the Insert/Delete/Copy/Paste context menu.

        Called on a right-click event in the index column (vertical header).

        Args:
            point (QPoint): The point clicked
        """
        table = self.uix['tbl']
        clicked_row = table.verticalHeader().logicalIndexAt(point)
        table.selectRow(clicked_row)

        entries = [
            [ROW_INSERT_SVG, 'Insert', self.on_insert_rows],
            [ROW_DELETE_SVG, 'Delete', self.on_btn_delete_rows],
            [COPY_SVG, 'Copy', table.on_copy],
            [PASTE_SVG, 'Paste', table.on_paste],
        ]
        context_menu = widget_builder.setup_context_menu(self, entries)
        context_menu.popup(table.viewport().mapToGlobal(point))

    def on_modflow_cellid_change(self, index):
        """Recomputes the cell index column(s) of a row after one of its MODFLOW cellid columns changed.

        Called when one of the MODFLOW cellid columns (LAY, ROW, COL, CELL2D, CELLID) changes. Rows with a
        CELLIDX column get that column updated; rows with CELLIDX1 (HFB) get CELLIDX1 and CELLIDX2
        updated. Nothing happens while rows are being appended.

        Args:
            index (QModelIndex): Index of the cell that changed.
        """
        if self._appending_rows:
            return

        row = index.row()
        record = self.model.record(row)

        # (column name, which cellidx, number of cellidx columns)
        if record.contains('CELLIDX'):
            targets = [('CELLIDX', 0, 1)]
        elif record.contains('CELLIDX1'):  # For HFB
            targets = [('CELLIDX1', 0, 2), ('CELLIDX2', 1, 2)]
        else:
            return

        # Compute every new value first, then write them to the model
        updates = [
            (
                self.model.createIndex(row, record.indexOf(name)),
                self.cellidx_from_modflow_cellid(record, which, count),
            )
            for name, which, count in targets
        ]
        for target_index, cellidx in updates:
            self.model.setData(target_index, cellidx, Qt.EditRole)

    def on_btn_open_time_series(self):
        """Opens the time series file referenced by the first selected cell.

        Does nothing if no cell is selected. If the cell's time series cannot be found among the time
        series files, a message is shown instead.
        """
        selected_list = self.uix['tbl'].selectedIndexes()
        if not selected_list:
            return  # Nothing selected, so there is no time series to open

        time_series = self.model.data(selected_list[0])
        file = self.find_time_series_file(time_series)
        if not file:
            msg = 'Cannot find time series in list of time series files.'
            message_box.message_with_ok(parent=self, message=msg)
        else:
            os.startfile(file, 'open')  # NOTE: os.startfile is only available on Windows

    def on_btn_filter(self):
        """Toggles filtering on the selected cells and reloads the current period."""
        dlg_input = self.dlg_input
        dlg_input.filter_on_selected_cells = not dlg_input.filter_on_selected_cells
        current_period = self.period()
        self.load_period(current_period)

    def find_matching_records(self, row: int, column: int) -> tuple[list[int], list[int], list[str]]:
        """Finds, in all periods, the stresses matching the stress at the given table row.

        The stress is identified by its stress id columns (typically 'CELLIDX') and, when the table has
        a BOUNDNAME column, by its boundname.

        Args:
            row: A row.
            column: A column.

        Returns:
            (tuple): tuple containing:
                - ids (list): Primary key ids.
                - periods (list): Periods.
                - values (list): Values from column.
        """
        record = self.model.record(row)
        stress_ids = [int(record.value(name)) for name in self._data.stress_id_columns()]
        column_name = self.model.headerData(column, Qt.Horizontal, Qt.DisplayRole)

        # An empty boundname is passed when the table has no BOUNDNAME column
        boundname = str(record.value('BOUNDNAME')) if record.contains('BOUNDNAME') else ''
        ids, periods, values = self.find_stress(row, stress_ids, column_name, boundname)
        return ids, periods, values

    def _get_bc_df(self, selected_index: QModelIndex) -> pd.DataFrame:
        """Builds the dataframe of values to plot for the bc in the selected cell.

        Args:
            selected_index: The selected cell.

        Returns:
            Dataframe with columns 'id' (primary key), 'PERIOD' and one named after the selected column.
        """
        row, column = selected_index.row(), selected_index.column()
        ids, periods, values = self.find_matching_records(row, column)
        value_header = self.model.headerData(column, Qt.Horizontal, Qt.DisplayRole)
        return pd.DataFrame(data={'id': ids, 'PERIOD': periods, value_header: values})

    def _get_bc_df_all_times(self, df: pd.DataFrame) -> pd.DataFrame:
        """Left-joins the bc dataframe onto the period times so every period time has a row.

        The period times are fetched from tdis the first time they are needed and then cached. Missing
        values in the last column are replaced with XY_UNEDITABLE.

        Args:
            df: Result of _get_bc_df().

        Returns:
            See description.
        """
        period_times = self._period_times
        if period_times is None or period_times.empty:
            period_times = self._data.mfsim.tdis.get_period_times(as_date_times=True)
            self._period_times = period_times

        by_period = df.set_index('PERIOD')
        merged = period_times.join(other=by_period, on='PERIOD', how='left', sort=True)

        # The last column holds the values (the y column in the XySeriesEditor). XY_UNEDITABLE in
        # place of nan signals XySeriesEditor to disable the table cell.
        y_column = merged.columns[-1]
        merged[y_column] = merged[y_column].fillna(XY_UNEDITABLE)
        return merged

    def on_btn_plot(self):
        """Shows the selected bc's values for all periods in the XY series editor and saves any edits.

        Does nothing unless exactly one table cell is selected.
        """
        selected = self.uix['tbl'].selectedIndexes()
        if len(selected) != 1:
            return
        cell = selected[0]

        # The database must be up to date before it is queried
        self.model.submitAll()

        # All periods and values of the selected bc, spread over all the period times
        plot_df = self._get_bc_df_all_times(self._get_bc_df(cell))

        # Deep copy that keeps the 'id' and 'PERIOD' columns, used later to detect changes
        snapshot = plot_df.copy(deep=True)

        # The editor only gets the x and y columns
        for dropped in ('id', 'PERIOD'):
            del plot_df[dropped]

        icon = QIcon(resources_util.get_resource_path(':/resources/icons/gms.ico'))
        header = self.model.headerData(cell.column(), Qt.Horizontal, Qt.DisplayRole)
        editor = XySeriesEditor(
            data_frame=plot_df,
            series_name=f'{header}',
            icon=icon,
            parent=self,
            readonly_x=True,
            can_add_rows=False,
            stair_step=True,
            read_only=self.dlg_input.locked
        )
        if editor.exec() != QDialog.Accepted:
            return

        # Write the changed values back to the database and refresh the table
        ids, new_values = _get_changes(snapshot, plot_df)
        if ids:
            self._queries.update_column_values(header, ids, new_values)
            self.model.select()

    def _enable_selection_dependent_items(self):
        """Enables or disables the Delete Rows, Plot and Open Time Series actions for the selection."""
        selected = self.uix['tbl'].selectedIndexes()
        locked = self.dlg_input.locked
        defined = self.is_period_defined(self.period())
        one_selected = len(selected) == 1

        # Deleting rows needs a selection, an unlocked dialog and a defined period
        self.actions[ROW_DELETE_SVG].setEnabled(len(selected) > 0 and not locked and defined)

        # Plot: exactly one cell selected, in a plottable column
        if PLOT_SVG in self.actions:
            plot_action = self.actions[PLOT_SVG]
            plot_action.setEnabled(False)
            if one_selected and selected[0].column() in self._data.plottable_columns():
                plot_action.setEnabled(True)

        # Open Time Series: unlocked, exactly one cell selected, the cell holds a word rather than
        # a number, and its column may contain time series
        if OPEN_SVG in self.actions:
            open_action = self.actions[OPEN_SVG]
            open_action.setEnabled(False)
            if not locked and one_selected:
                cell = selected[0]
                holds_word = not util.is_number(self.model.data(cell))
                if holds_word and cell.column() in self._data.get_time_series_columns():
                    open_action.setEnabled(True)

    def on_selection_changed(self):
        """Updates the enabled state of everything that depends on the table selection."""
        self._enable_selection_dependent_items()

    def _get_old_columns_in_new(self, old_columns, new_columns):
        """Returns the list of the old columns that are in the new columns.

        Args:
            old_columns (list of str): List of the old columns.
            new_columns (list of str): List of the new columns.

        Returns:
            (list[str]): See description.
        """
        # old_columns_in_new = list(set(new_columns) & set(old_columns))  # Order doesn't matter
        old_columns_in_new = []
        for new_column in new_columns:
            for old_column in old_columns:
                if old_column == new_column:
                    old_columns_in_new.append(old_column)
        return old_columns_in_new

    def _alter_table(self, old_columns: list[str], new_columns: list[str], use_aux: bool = True) -> None:
        """Replaces the data table with one that has the new columns, then recreates its indexes.

        Args:
            old_columns: List of the old columns.
            new_columns: List of the new columns.
            use_aux: True to include aux.
        """
        create_strings, column_names, orig_column_count = database.get_create_strings(self._data, use_aux=use_aux)

        # Columns kept from the old table, plus the column names past the original column count
        kept_columns = [
            *self._get_old_columns_in_new(old_columns, new_columns),
            *column_names[orig_column_count:],
        ]
        self._queries.replace_data_table(', '.join(create_strings), ', '.join(kept_columns))
        self._queries.create_indexes(self._data, 'data')

    def change_columns(self, old_columns, new_columns, use_aux):
        """Rebuilds the database table and the table model after the set of columns changes.

        Args:
            old_columns (list of str): List of old columns.
            new_columns (list of str): List of new columns.
            use_aux (bool): True to include aux.
        """
        # Submit any pending changes
        self.model.submitAll()

        # Drop the model and detach it from the table view before we change the database
        self.model.clear()
        self.model = None
        self.uix['tbl'].setModel(None)

        # Change the database, create a new model and reload the current period
        self._alter_table(old_columns, new_columns, use_aux=use_aux)
        self.setup_model(use_aux=use_aux)
        self.load_period(self.period())

    def check_aux_change(self, block: str, new_aux: list[str]) -> str:
        """Return an error message if a new aux name clashes with a standard column, or '' if none does.

        PyCharm says this method could be a function, but it's also defined in PeriodArrayWidget and
        ListBlockTableWidget so that wouldn't make sense.

        Args:
            block: Name of the list block (unused here).
            new_aux: New aux variable list.

        Returns:
            See description.
        """
        standard_columns, _, _ = self._data.get_column_info('', use_aux=False)
        for aux_name in new_aux:
            if aux_name in standard_columns:
                return 'Auxiliary variables cannot have the same name as a standard column.'
        return ''

    def change_aux_variables(self, block, use_aux):
        """Switches the table between its column sets with and without AUXILIARY variables.

        Args:
            block (str): This is only here to match ListBlockTableWidget.change_aux_variables()
            use_aux (bool): True to include AUXILIARY variables.
        """
        column_info = self._data.get_column_info
        columns_before, _, _ = column_info('', use_aux=not use_aux)
        columns_after, _, _ = column_info('', use_aux=use_aux)
        self.change_columns(columns_before, columns_after, use_aux=use_aux)

    def change_boundnames(self, on, use_aux):
        """Sets the BOUNDNAMES option and changes the table columns to match.

        Args:
            on: Value stored for the BOUNDNAMES option.
            use_aux (bool): True to include aux.
        """
        package = self._data
        # The columns are read both before and after the option changes
        columns_before, _, _ = package.get_column_info('', use_aux=use_aux)
        package.options_block.set('BOUNDNAMES', on, None)
        columns_after, _, _ = package.get_column_info('', use_aux=use_aux)
        self.change_columns(columns_before, columns_after, use_aux=use_aux)

    def previous_defined_period(self):
        """Returns the closest defined stress period before the current one, or -1 if there is none."""
        earlier_periods = reversed(range(1, self.period()))
        return next((sp for sp in earlier_periods if self.is_period_defined(sp)), -1)

    def next_defined_period(self):
        """Returns the closest defined stress period after the current one, or -1 if there is none.

        Periods are 1-based, so the last period (number nper) is included in the search.
        """
        last_period = self._data.nper()
        for sp in range(self.period() + 1, last_period + 1):
            if self.is_period_defined(sp):
                return sp
        return -1

    def on_btn_delete_period(self):
        """Asks for confirmation, then deletes the current stress period and reloads it."""
        choice = message_box.message_with_n_buttons(
            parent=self,
            message='Delete this period?',
            button_list=['Delete Period', 'Cancel'],
            default=1,
            escape=1,
            icon='Warning',
            win_icon=util.get_app_icon()
        )
        if choice != 0:  # Anything but the 'Delete Period' button
            return
        self._delete_period(self.period())
        self.load_period(self.period())

    def _delete_period(self, period: int) -> None:
        """Delete the period.

        Args:
            period: The period (1-based).
        """
        self._data.period_files.pop(period, None)
        self._queries.delete_period(period)

    def on_btn_copy_period(self):
        """Runs the copy period dialog and copies the chosen period to the checked periods.

        Each checked target period (other than the source) is deleted first. It is then copied from
        the source only if the source is defined, so copying an undefined period undefines the target.
        """
        start_period = self.period()
        defined_periods = self._get_defined_periods_list()
        rv = copy_period_dialog.run_dialog(self.dlg_input, start_period, defined_periods, self)
        if rv is None:
            return

        self.model.submitAll()  # Save any pending changes
        source = rv.copy_from
        for target, wanted in enumerate(rv.copy_to, start=1):
            if target == source or not wanted:
                continue
            self._delete_period(target)
            if defined_periods[source - 1]:
                self.copy_period(source, target)
        self.load_period(start_period)

    def _get_defined_periods_list(self) -> list[int]:
        """Return a list, size of number of periods, with a 0 if not defined, 1 if defined."""
        return [int(self.is_period_defined(sp)) for sp in range(1, self._data.nper() + 1)]

    def on_btn_define_period(self):
        """Defines the current stress period, either empty or as a copy of a neighboring defined period.

        If no other period is defined before or after, an empty period is created. Otherwise a dialog
        lets the user pick: 0 copies the previous defined period, 1 copies the next one, -1 cancels and
        anything else creates an empty period.
        """
        period = self.period()
        prev_sp = self.previous_defined_period()
        next_sp = self.next_defined_period()

        if prev_sp != -1 or next_sp != -1:
            choice = define_stress_period_dialog.run(prev_sp, next_sp, self)
            if choice == -1:
                return
            if choice == 0:
                self.copy_period(prev_sp, period)
            elif choice == 1:
                self.copy_period(next_sp, period)
            else:
                self._data.period_files[period] = ''
        else:
            # No previous or next stress period to copy. Create an empty one.
            self._data.period_files[period] = ''

        self.load_period(period)

    def _get_export_csv_filepath(self) -> Path | None:
        """Ask the user where to export the csv file and remember the answer.

        The dialog starts at the last csv file used, or at '<tree node name>.csv' if there is none.

        Returns:
            The chosen filepath, or None if no file was chosen.
        """
        filter_str = 'CSV (Comma delimited) Files (*.csv);;All Files (*.*)'
        default_path = self._last_csv_file or f'{self._data.tree_node.name}.csv'
        chosen = gui_util.run_save_file_dialog(self, 'Export CSV File', default_path, filter_str)
        if not chosen:
            return None
        self._last_csv_file = chosen
        return Path(chosen)

    def _get_import_csv_filepath(self) -> Path | None:
        """Ask the user which csv file to import and remember the answer.

        The dialog starts at the last csv file used.

        Returns:
            The chosen filepath, or None if no file was chosen.
        """
        filter_str = 'CSV (Comma delimited) Files (*.csv);;All Files (*.*)'
        chosen = gui_util.run_open_file_dialog(self, 'Import CSV File', self._last_csv_file, filter_str)
        if not chosen:
            return None
        self._last_csv_file = chosen
        return Path(chosen)

    def on_btn_export_csv_file(self) -> None:
        """Export the data to a CSV file picked by the user; do nothing if no file is picked."""
        filepath = self._get_export_csv_filepath()
        if filepath:
            self._export_to_csv(filepath)

    def on_btn_import_csv_file(self) -> None:
        """Import a CSV file picked by the user; do nothing if no file is picked.

        The import runs under a wait cursor. A ValueError raised by the import is shown to the user in
        a message box.
        """
        filepath = self._get_import_csv_filepath()
        if filepath:
            try:
                with dialog_util.wait_cursor_context():
                    self._import_from_csv(filepath)
            except ValueError as error:
                message_box.message_with_ok(self, str(error), 'GMS', icon='Critical')
    def _export_to_csv(self, filepath: Path) -> None:
        """Export the table model to a .csv file.

        Args:
            filepath: The filepath.
        """
        self.model.submitAll()
        with filepath.open('w', newline='') as file:
            writer = csv.writer(file)

            # Get fields to write
            rec = self.model.record()
            skip_fields = {'CELLIDX', 'CELLIDX1', 'CELLIDX2', 'PERIOD_ROW', 'id'}
            fields = [rec.fieldName(i) for i in range(rec.count()) if rec.fieldName(i) not in skip_fields]

            writer.writerow(fields)  # Write the header row
            self._queries.write_rows_to_csv(fields, writer)

    def _import_from_csv(self, filepath: Path) -> None:
        """Imports the .csv file.

        Args:
            filepath: The filepath.
        """
        self.model.submitAll()

        # Get fields that we expect
        rec = self.model.record()
        skip_fields = {'CELLIDX', 'CELLIDX1', 'CELLIDX2', 'PERIOD_ROW', 'id'}
        file_fields = [rec.fieldName(i) for i in range(rec.count()) if rec.fieldName(i) not in skip_fields]

        self._queries.clear_errors()
        with filepath.open(newline='') as file:
            reader = csv.reader(file)

            # Read header
            row = next(reader)
            if row != file_fields:
                raise ValueError('Field names in file do not match table.')

            self._queries.delete_all_table_rows('data')

            ftype = self._data.ftype
            grid_info = self._data.grid_info()
            self._queries.insert_rows_from_csv(ftype, grid_info, reader, 'data')
            if self._queries.errors:
                msg = f'Errors reading file:\n{self._queries.errors[0]}'
                details = '\n'.join(self._queries.errors) if len(self._queries.errors) > 1 else ''
                message_box.message_with_ok(parent=self, message=msg, details=details)

            # Update defined periods
            defined_periods = self._queries.defined_periods()
            for period in defined_periods:
                self._data.period_files[period] = ''
        self.model.select()
        self.do_enabling()

    def copy_period(self, from_sp, to_sp):
        """Copies the rows of one stress period into another and marks the target period as defined.

        Args:
            from_sp (int): The stress period being copied.
            to_sp (int): The stress period being copied to.
        """
        table_record = self.model.record()
        source_columns = []
        for i in range(table_record.count()):
            source_columns.append(table_record.fieldName(i))
        del source_columns[-1]  # The id column comes last and is not copied

        # Same column list, but with the target period substituted for the PERIOD column
        target_columns = list(source_columns)
        target_columns[target_columns.index('PERIOD')] = f'"{to_sp}"'

        self._queries.copy_period(', '.join(source_columns), from_sp, ', '.join(target_columns))
        self._data.period_files[to_sp] = ''

    def cellidx_from_modflow_cellid(self, record: QSqlRecord, which_cellidx: int, cellidx_count: int) -> int:
        """Returns the cell index (0-based) for the CELLIDX column based on MODFLOW CELLID columns.

        Handles the different MODFLOW CELLID columns for the different DIS, DISV, and DISU packages.

        Args:
            record: The record.
            which_cellidx: 0 typically, 0 or 1 for HFB.
            cellidx_count: 1 typically, 2 for HFB.

        Returns:
            The cell index. If the table lacks MODFLOW CELLID columns, returns 1. If the grid rejects the cellid
            (ValueError), returns -1.
        """
        mfcellid = _mfcellid_from_record(record, which_cellidx, cellidx_count)
        if mfcellid is None:  # This happens with SFR and RNO
            return 1
        try:
            cell_index = self._data.grid_info().cell_index_from_modflow_cellid(mfcellid)
        except ValueError:
            cell_index = -1
        return cell_index

    def get_modflow_cellid_columns(self):
        """Returns a set of the column indexes where the MODFLOW CELLID is located.

        Returns:
            (set[int]): The leading column indexes that hold the cell id; empty if the table has none.
        """
        record = self.model.record()
        # Checked in order: the first field found decides how many leading columns make up the cell id(s).
        marker_fields = (
            ('COL', 3),
            ('CELL2D', 2),
            ('CELLID', 1),
            ('COL1', 6),  # HFB
            ('CELL2D1', 4),  # HFB
            ('CELLID1', 2),  # HFB
        )
        for field_name, column_count in marker_fields:
            if record.contains(field_name):
                return set(range(column_count))
        return set()

    def record_from_defaults(self, period_row):
        """Creates a QSqlRecord filled with the default values for this table.

        Besides the column defaults, the stress id column(s), PERIOD (the current period) and PERIOD_ROW are set.

        Args:
            period_row (int): Number to put in the PERIOD_ROW column.

        Returns:
            (QSqlRecord): See description.
        """
        new_record = self.model.record()
        defaults = self.column_info()[1]
        for idx in range(new_record.count()):
            name = new_record.fieldName(idx)
            if name in defaults:
                new_record.setValue(idx, defaults[name])

        # Derive the stress id column(s) (typically just 'CELLIDX') from the cell id defaults set above
        id_columns = self._data.stress_id_columns()
        id_count = len(id_columns)
        for which, column in enumerate(id_columns):
            cellidx = self.cellidx_from_modflow_cellid(new_record, which, id_count)
            new_record.setValue(column, cellidx)

        new_record.setValue('PERIOD', self.period())
        new_record.setValue('PERIOD_ROW', period_row)
        return new_record

    def period_column(self):
        """Returns the number of the PERIOD column.

        Returns:
            (int): The model's column count minus 3.
        """
        # Counting from the end: id (primary key) is last, PERIOD_ROW is before it, and PERIOD before that.
        column_count = self.model.columnCount()
        return column_count - 3

    def period_row_column(self):
        """Returns the number of the PERIOD_ROW column.

        Returns:
            (int): The model's column count minus 2.
        """
        # Counting from the end: id (primary key) is last and PERIOD_ROW is right before it.
        column_count = self.model.columnCount()
        return column_count - 2

    def get_period_rows(self, unique_rows: list[int]) -> list[int]:
        """Returns the numbers from the PERIOD_ROW column for the given model rows.

        Args:
            unique_rows: Rows in the model.

        Returns:
            The PERIOD_ROW values, ordered by ascending model row.
        """
        return [self.model.index(row, self.period_row_column()).data() for row in sorted(unique_rows)]

    def get_selected_rows(self) -> list[int]:
        """Returns the list of selected row numbers (0-based)."""
        selected_indexes = self.uix['tbl'].selectedIndexes()
        return gui_util.get_unique_selected_rows(selected_indexes)

    def temporarily_turn_off_cell_filtering(self):
        """If filtering on CELLIDX, removes that part of the filter.

        Filtering on PERIOD is kept.

        Returns:
            True if we were filtering on CELLIDX, otherwise False.
        """
        was_filtering_on_cells = 'CELLIDX' in self.model.filter()
        if was_filtering_on_cells:
            self.set_filter(period=self.period(), include_cells=False)
        return was_filtering_on_cells

    def increment_period_rows(self, period_rows):
        """Increments PERIOD_ROW for rows below where we will be inserting a new row.

        Rows from the i-th insertion point up to the next insertion point (or the end of the model) are shifted
        by i + 1, since that many rows will be inserted at or above them. The changes are then submitted.

        Args:
            period_rows: List of PERIOD_ROW where we are inserting rows.
        """
        for position, first_row in enumerate(period_rows):
            following = position + 1
            if following < len(period_rows):
                stop_row = period_rows[following]
            else:
                stop_row = self.model.rowCount()
            shift = following  # Number of rows inserted at or above this run of rows
            for row in range(first_row, stop_row):
                cell = self.model.createIndex(row, self.period_row_column())
                self.model.setData(cell, self.model.data(cell) + shift)
        self.model.submitAll()

    def insert_rows(self, period_rows):
        """Inserts default-valued rows at the locations specified by period_rows and submits them.

        Args:
            period_rows: List of PERIOD_ROW where rows are to be inserted.

        Returns:
            List of lists: for each newly added row, the values of its stress id column(s) (CELLIDX).
        """
        added_cellidxs = []
        for insert_at in period_rows:
            new_record = self.record_from_defaults(insert_at)
            self.model.insertRecord(insert_at, new_record)
            added_cellidxs.append([new_record.value(column) for column in self._data.stress_id_columns()])
        self.model.submitAll()
        return added_cellidxs

    def restore_cell_filtering(self, changes, new_cellidxs):
        """If changes is True, turns filtering on cells back on, including the new_cellidxs, and reselects.

        Args:
            changes (bool): True if we turned off filtering on CELLIDX before.
            new_cellidxs (list of list of int): List of lists of cellidxs that were added.
        """
        if not changes:
            return
        self.set_filter(period=self.period(), include_cells=True, add_cellidxs=new_cellidxs)
        self.model.select()

    def reselect_rows(self, selected_rows):
        """Selects the rows in the table that were selected before as indicated by selected_rows.

        Any existing selection is replaced. Nothing is changed if selected_rows is empty.

        Args:
            selected_rows: List of rows.
        """
        selection_model = self.uix['tbl'].selectionModel()
        # ClearAndSelect wipes the current selection before selecting, so it may only be used for the first row.
        # Using it for every row would leave just the last row selected.
        flags = QItemSelectionModel.ClearAndSelect | QItemSelectionModel.Rows
        for row in selected_rows:
            idx = self.model.createIndex(row, 0)
            selection_model.select(idx, flags)
            flags = QItemSelectionModel.Select | QItemSelectionModel.Rows  # Add the remaining rows

    def _get_cell_count(self) -> int:
        """Return the number of cells in the grid, taken from the grid's node in the project tree."""
        project_tree = self.dlg_input.query.project_tree
        grid_node = tree_util.find_tree_node_by_uuid(project_tree, self.dlg_input.ugrid_uuid)
        return grid_node.num_cells

    def _get_selected_column(self) -> int | None:
        selection_list = self.uix['tbl'].selectedIndexes()
        if not selection_list:
            return None
        return selection_list[0].column()

    def _on_set_column_values(self, cbx_items: dict[str, int]) -> None:
        """Set the column values for the current period from a dataset selected by the user.

        Args:
            cbx_items: Column header -> column index choices to offer in the dialog. Empty when the column is
                already known (right-click), in which case the currently selected column is used.
        """
        # Run select dataset dialog
        cell_count = self._get_cell_count()
        col_idx = self._get_selected_column()
        tree = self.dlg_input.query.project_tree
        if not cbx_items:  # From right-click - we know the column
            dset_uuid = gui_util.select_dataset_dialog(self, tree, '', cell_count)
        else:  # From tool button - we don't know the column
            # Start with current selected column, if there is one
            if col_idx is None:
                initial_choice = ''
            else:
                initial_choice = self.model.headerData(col_idx, Qt.Horizontal)
            dset_uuid, col_idx = gui_util.select_dataset_dialog_combo_box(
                self, tree, '', cell_count, cbx_items, initial_choice
            )
        if dset_uuid:
            # Get dataset values. We currently don't handle transient - just get timestep 0
            dataset = self.dlg_input.query.item_with_uuid(dset_uuid)
            self._set_column_values_from_dataset(col_idx, dataset.values[0])

    def _on_menu_set_column(self) -> None:
        """Set the column values for the current period from a dataset selected by the user."""
        self._on_set_column_values({})

    def _on_tool_set_column(self) -> None:
        """Set column values for the current period from a dataset, letting the user pick the column.

        The columns offered are the ones the data reports as time series columns, keyed by their header text.
        """
        cbx_items = {}
        for column_idx in self._data.get_time_series_columns():
            header_text = self.model.headerData(column_idx, Qt.Horizontal)
            cbx_items[header_text] = column_idx
        self._on_set_column_values(cbx_items)

    def _set_column_values_from_dataset(self, column_clicked: int, ts_data) -> None:
        """Set column values from dataset.

        Args:
            column_clicked: Index of the column.
            ts_data: Values from the dataset.
        """
        # I'm doing this by iterating over rows, but maybe it should be done using a query?
        stress_id_column = self._data.stress_id_columns()[0]  # Typically 'CELLIDX'
        row_count = self.model.rowCount()  # This takes into account filtering on selected rows
        for row in range(row_count):
            record = self.model.record(row)
            stress_id = int(record.value(stress_id_column))  # 'CELLIDX'
            idx = self.model.createIndex(row, column_clicked)  # Index of table cell we want to change
            self.model.setData(idx, ts_data[stress_id])
        self.model.submitAll()

    def on_insert_rows(self):
        """Inserts rows into the spreadsheet at the selected rows.

        When inserting, only the current period is modified. See append_rows for adding to all periods.
        """
        rows = self.get_selected_rows()
        insert_at = self.get_period_rows(rows)

        # Cell filtering is switched off while the rows are shifted and inserted, then put back
        was_filtering = self.temporarily_turn_off_cell_filtering()
        self.increment_period_rows(insert_at)
        added_cellidxs = self.insert_rows(insert_at)
        self.restore_cell_filtering(was_filtering, added_cellidxs)

        self.reselect_rows(rows)

    def set_modflow_cellid_columns_from_selected_cells(self, row, record):
        """Sets MODFLOW CELLID columns in record based on selected cells.

        The first cell id comes from selected cell `row`. Tables with two cell ids (HFB) take the second one
        from selected cell `row + 1`, or reuse the first if there is no such cell.

        Args:
            row (int): Index into self.dlg_input.selected_cells
            record: The record.
        """
        grid_info = self._data.grid_info()  # for convenience
        cells = self.dlg_input.selected_cells
        first_id = grid_info.modflow_cellid_from_cell_index(cells[row])
        next_row = row + 1
        if len(cells) > next_row:
            second_id = grid_info.modflow_cellid_from_cell_index(cells[next_row])
        else:
            second_id = first_id

        def assign(field_names, cellid):
            """Copies the components of cellid, in order, into the named fields of the record."""
            for position, field_name in enumerate(field_names):
                record.setValue(field_name, cellid[position])

        if grid_info.dis_enum == DisEnum.DIS:
            if record.contains('LAY'):
                assign(('LAY', 'ROW', 'COL'), first_id)
            else:  # HFB
                assign(('LAY1', 'ROW1', 'COL1'), first_id)
                assign(('LAY2', 'ROW2', 'COL2'), second_id)
        elif grid_info.dis_enum == DisEnum.DISV:
            if record.contains('LAY'):
                assign(('LAY', 'CELL2D'), first_id)
            else:  # HFB
                assign(('LAY1', 'CELL2D1'), first_id)
                assign(('LAY2', 'CELL2D2'), second_id)
        elif record.contains('CELLID'):
            record.setValue('CELLID', first_id)
        else:
            record.setValue('CELLID1', first_id)
            record.setValue('CELLID2', second_id)

    def append_rows(self, period, add_stresses_data: AddStressesData):
        """Appends rows to the end of the given period.

        The rows are inserted into the model but not submitted. self._appending_rows is True while the rows are
        being added and is always reset to False afterwards, even if adding a row fails.

        Args:
            period (int): The period.
            add_stresses_data: Data from the AddStressesDialog.
        """
        self._appending_rows = True
        try:
            period_rows = self._queries.count_rows_in_period(period)
            stress_id_columns = self._data.stress_id_columns()  # Typically 'CELLIDX'

            for row in range(add_stresses_data.rows_to_add):
                record = self.model.record()

                # Set record data to info from AddStressesDialog
                for field_idx in range(record.count()):
                    field_name = record.fieldName(field_idx)
                    if field_name in add_stresses_data.row_data:
                        record.setValue(field_idx, add_stresses_data.row_data[field_name])

                if self._can_filter_on_selected_cells():
                    self.set_modflow_cellid_columns_from_selected_cells(row, record)

                for i, col in enumerate(stress_id_columns):
                    record.setValue(col, self.cellidx_from_modflow_cellid(record, i, len(stress_id_columns)))
                record.setValue('PERIOD', period)
                record.setValue('PERIOD_ROW', period_rows + row)  # Continue numbering after the existing rows
                record.setValue('id', None)
                self.model.insertRecord(-1, record)
        finally:
            # Don't leave the flag stuck on if something above raised
            self._appending_rows = False

    def _dis_cells_are_horizontally_adjacent(self, cellid_0, cellid_1):
        """Returns True if the two cells are horizontally adjacent (DIS package only).

        Args:
            cellid_0: First cellid.
            cellid_1: Second cellid.

        Returns:
            (bool): See description.
        """
        if cellid_0[0] == cellid_1[0]:
            adj_columns = cellid_0[1] == cellid_1[1] and abs(cellid_0[2] - cellid_1[2]) == 1
            adj_rows = cellid_0[2] == cellid_1[2] and abs(cellid_0[1] - cellid_1[1]) == 1
            return adj_columns or adj_rows
        return False

    def _hfb_two_adjacent_cells(self):
        """Warns user and returns False if HFB and not just two selected horizontally adjacent cells.

        Horizontal adjacency is only checked for DIS. User beware if DISV or DISU.

        Returns:
            (bool): True if not HFB, or if the selected cells pass the checks. Otherwise False.
        """
        if self._data.ftype != 'HFB6':
            return True

        selected = self.dlg_input.selected_cells
        valid = len(selected) == 2  # There must be exactly two selected cells
        if valid:
            grid_info = self._data.grid_info()
            first = grid_info.modflow_cellid_from_cell_index(selected[0])
            second = grid_info.modflow_cellid_from_cell_index(selected[1])

            # They must be in the same layer. If DISU, no layer so, user beware
            if grid_info.dis_enum in {DisEnum.DIS, DisEnum.DISV}:
                if first[0] != second[0]:
                    valid = False

            # If DIS, we can also make sure they are horizontally adjacent
            if valid and grid_info.dis_enum == DisEnum.DIS:
                if not self._dis_cells_are_horizontally_adjacent(first, second):
                    valid = False

        if not valid:
            msg = 'With HFB, exactly two horizontally adjacent cells must be selected to add rows.'
            message_box.message_with_ok(parent=self, message=msg)
        return valid

    def on_btn_add_rows(self):
        """Adds rows to the bottom of the spreadsheet.

        For tables with stress id columns the AddStressesDialog is shown and the rows are added to all defined
        periods or to just this one, as chosen. Otherwise (OC) a row is added to this period without showing
        the dialog.
        """
        if not self._hfb_two_adjacent_cells():
            return

        if self.dlg_input.filter_on_selected_cells:
            selected_count = len(self.dlg_input.selected_cells)
        else:
            selected_count = 0
        if self._data.stress_id_columns():
            which = WhichPeriods.ALL_DEFINED
        else:
            which = WhichPeriods.ONLY_THIS
        dialog = AddStressesDialog(data=self._data, selected_count=selected_count, which_periods=which, parent=self)

        if self._data.stress_id_columns():
            if dialog.exec() != QDialog.Accepted:
                return
            stresses = dialog.get_data()
            if stresses.which_periods == WhichPeriods.ALL_DEFINED:
                for period in self._data.period_files.keys():
                    self.append_rows(period, stresses)
            else:  # Add to just this period
                self.append_rows(self.period(), stresses)
        else:  # OC - just add a row
            stresses = dialog.get_data()
            self.append_rows(self.period(), stresses)

        self.model.submitAll()
        self.load_period(self.period())

    def find_stress(self, row, stress_ids, column_name, boundname):
        """Returns list of ids, periods and values from rows with nth occurrence of stress_id and BOUNDNAME.

        stress_id is CELLIDX. The search is over all periods, where n is which occurrence the selected row is
        within the current period. If a period has less than nth occurrences of stress_id (CELLIDX), nothing is
        returned for that period.

        Args:
            row (int): The selected row.
            stress_ids (list of int): The stress ids (typically [CELLIDX] or [CELLIDX1, CELLIDX2]).
            column_name (str): Column to include in query results.
            boundname (str): The boundname to match. May be ''.

        Returns:
            (tuple): tuple containing:
                - ids (list): Primary key ids.
                - periods (list): Periods.
                - values (list): Values from column.
        """
        id_columns = self._data.stress_id_columns()
        assert len(id_columns) == len(stress_ids)
        primary_key = self.model.record(row).value('id')
        current_period = self.period()
        occurrence = self._queries.get_nth_instance_of_stress(
            primary_key, stress_ids, id_columns, boundname, current_period
        )
        return self._queries.find_stress(stress_ids, id_columns, boundname, column_name, occurrence)

    def delete_selected_rows_from_all_periods(self) -> list[int]:
        """Deletes the selected cells from all periods by finding matches in other periods.

        Returns:
            The selected rows.
        """
        rows = self.get_selected_rows()
        self.model.submitAll()

        # Gather the ids of every record matching one of the (distinct) selected rows
        doomed_ids = []
        for row in set(rows):
            matching_ids = self.find_matching_records(row, 0)[0]
            doomed_ids.extend(matching_ids)

        self._queries.delete_rows(doomed_ids)
        self.model.select()
        return rows

    def on_btn_delete_rows(self):
        """Called when the delete rows button is clicked.

        If the table has stress id columns and there is more than one period, the user is asked whether to
        delete the matching stresses from all periods or from just this one (or to cancel).
        """
        choice = 1  # Just this period
        # OC doesn't have stress ID columns so skip this
        if self._data.stress_id_columns() and self._data.nper() > 1:
            choice = message_box.message_with_n_buttons(
                parent=self,
                message='Delete matching stresses from all periods?',
                button_list=['Delete From All Periods', 'Delete From Just This Period', 'Cancel'],
                default=0,
                escape=2,
                icon='Question',
                win_icon=util.get_app_icon()
            )
            if choice == 2:  # Cancel
                return

        with dialog_util.wait_cursor_context():
            if choice == 0:  # All periods
                self.delete_selected_rows_from_all_periods()
            elif choice == 1:  # Just this period
                self.delete_selected_rows()
            else:  # Nothing was deleted so there is nothing to save or reload
                return
            self.model.submitAll()
            self.load_period(self.period())

    def delete_selected_rows(self):
        """Deletes the selected rows from the spreadsheet and renumbers the PERIOD_ROW column.

        Returns:
            The rows that were selected.
        """
        rows = self.get_selected_rows()
        for doomed_row in rows:
            self.model.removeRow(doomed_row)
        self.model.submitAll()  # Commit the removals before renumbering what's left

        self._renumber_period_row_column()
        self.model.submitAll()
        return rows

    def _renumber_period_row_column(self):
        # Renumber PERIOD_ROW column
        for row in range(0, self.model.rowCount()):
            idx = self.model.createIndex(row, self.period_row_column())
            self.model.setData(idx, row)

    def on_spn_period(self, period):
        """Saves any changes to the current stress period and loads the given stress period.

        Args:
            period (int): Stress period number (1-based).
        """
        model = self.model
        if model:
            model.submitAll()  # Save any pending changes
        self.load_period(period)
        self.do_enabling()

    def on_data_changed(self, top_left_index, bottom_right_index):
        """Called when the data in the table view has changed.

        Flags the data as changed and updates the items whose enabling depends on the selection.

        Args:
            top_left_index (QModelIndex): Top left index. Unused.
            bottom_right_index (QModelIndex): Bottom right index. Unused.
        """
        # The changed range doesn't matter here; only the fact that something changed
        self.data_changed = True
        self._enable_selection_dependent_items()

    def _load_all_rows(self) -> None:
        """Fetch rows into the model until there are no more to fetch (used before large item selections)."""
        with dialog_util.wait_cursor_context():
            model = self.model
            while model.canFetchMore():
                model.fetchMore()

    def on_select_column(self, logical_index: int) -> None:
        """Slot for when an entire column is selected.

        Args:
            logical_index (int): 0-based index of the column being selected
        """
        # All rows have to be loaded from the database first, otherwise the selection would only cover the rows
        # that happen to be loaded into the table so far.
        self._load_all_rows()
        table = self.uix['tbl']
        table.selectColumn(logical_index)

    def on_select_all(self) -> None:
        """Slot called when the top-left corner header button is clicked."""
        # All rows have to be loaded from the database first, otherwise the selection would only cover the rows
        # that happen to be loaded into the table so far.
        self._load_all_rows()
        table = self.uix['tbl']
        table.selectAll()

    def create_time_series_lookup(self):
        """Creates a dictionary of time series and the files they are in.

        Reads through the time series files listed under 'TS6 FILEIN' in the options block. In each file, the
        first line that starts with NAME or NAMES gives the time series names in that file. Files that don't
        exist are skipped.
        """
        # "win" in sys.platform would also match "darwin", so test the prefix instead
        use_posix = not sys.platform.startswith('win')
        for time_series_file in self._data.options_block.get('TS6 FILEIN', []):
            time_series_file = fs.resolve_relative_path(self._data.filename, time_series_file)
            if not os.path.isfile(time_series_file):
                continue
            with open(time_series_file, 'r') as file:
                for line in file:
                    # Use shlex to handle quoted strings
                    words = shlex.split(line, posix=use_posix)
                    # Blank lines produce no words, so check for that before looking at the first word
                    if words and words[0] in ('NAME', 'NAMES'):
                        for name in words[1:]:
                            self.time_series_lookup[name] = time_series_file
                        break  # Done reading this file

    def find_time_series_file(self, time_series):
        """Given a time series, returns the time series file it is in, or None.

        Builds the lookup dictionary first if it is empty.

        Args:
            time_series (str): Name of the time series.

        Returns:
             (str): The time series file path, or None.
        """
        lookup_is_empty = not self.time_series_lookup
        if lookup_is_empty:
            self.create_time_series_lookup()
        return self.time_series_lookup.get(time_series)

    def column_info(self):
        """Returns the columns and their defaults.

        Returns:
            (tuple): tuple containing:
                - column_names (list): Column names.
                - default (dict of str -> value): Column names -> default values.
        """
        info = self._data.get_column_info('')
        column_names, _unused, defaults = info  # The middle item isn't needed here
        return column_names, defaults

    def setup_model(self, use_aux=True):
        """Creates a model and connects it to the table view and some signals.

        The model is on the 'data' table of the database, uses manual submit and is sorted by PERIOD_ROW.
        Delegates are added to the table and the columns after the data columns are hidden.

        Args:
            use_aux (bool): Passed on to _add_double_delegates, which uses it when getting the column info.
        """
        # Create the model
        self.model = QxSqlTableModel(self, QSqlDatabase.database(self.db_filename))
        self.model.setTable('data')
        self.model.setEditStrategy(QSqlTableModel.OnManualSubmit)  # Changes are only saved on submitAll()
        self.model.setSort(self.period_row_column(), Qt.AscendingOrder)
        if self.dlg_input.locked:
            # When locked, every column is read-only
            read_only_columns = set(range(self.model.columnCount()))
            self.model.set_read_only_columns(read_only_columns)

        self.model.select()

        # Tell the table
        self.uix['tbl'].setModel(self.model)
        self.uix['tbl'].selectionModel().selectionChanged.connect(self.on_selection_changed)

        self._add_mfcellid_delegates()
        column_names = self._add_double_delegates(use_aux)

        # Signals
        self.model.set_watched_columns(self.get_modflow_cellid_columns())
        self.model.watched_column_change.connect(self.on_modflow_cellid_change)
        self.model.dataChanged.connect(self.on_data_changed)
        self.uix['tbl'].pasted.connect(self._fix_pasted_data)
        self.uix['tbl'].horizontalHeader().sectionClicked.connect(self.on_select_column)

        # Hide extra columns (unless the debug file exists, in which case the columns are left as they are)
        if not os.path.isfile('C:/temp/debug_mf6_show_all_columns.dbg'):
            data_column_count = len(column_names)
            for column in range(self.model.columnCount()):
                # Columns at or beyond the number of data columns are hidden
                self.uix['tbl'].setColumnHidden(column, column >= data_column_count)

        self.model.set_horizontal_header_tooltips(self._data.get_column_tool_tips(block=''))

    def _add_mfcellid_delegates(self) -> None:
        """Adds delegates for the columns that are components of the MfCellId.

        Each delegate is given the range 1 to the grid's limit for that component.
        """
        # Get fields that are parts of the MfCellId
        record = self.model.record()
        if record.contains('CELLIDX'):
            field_names = _mfcellid_field_names_for_record(record, 0, 1)
        elif record.contains('CELLIDX1'):  # For HFB
            field_names = _mfcellid_field_names_for_record(record, 0, 2)
            field_names = field_names + _mfcellid_field_names_for_record(record, 1, 2)
        else:
            field_names = []

        # Get the limits. The first token found in the field name decides which grid_info attribute is the limit.
        limit_attributes = (
            ('LAY', 'nlay'),
            ('ROW', 'nrow'),
            ('COL', 'ncol'),
            ('CELL2D', 'ncpl'),
            ('CELLID', 'nodes'),
        )
        maxes = {}
        for field_name in field_names:
            column = record.indexOf(field_name)
            for token, attribute in limit_attributes:
                if token in field_name:
                    maxes[column] = getattr(self._data.grid_info(), attribute)
                    break

        # Assign the delegate to the columns
        for column, upper_limit in maxes.items():
            delegate = TableTextRangeColorDelegate(self, self.model, mn=1, mx=upper_limit)
            self.uix['tbl'].setItemDelegateForColumn(column, delegate)

    def _add_double_delegates(self, use_aux):
        """Adds delegates to the table to handle columns with doubles.

        Why? The default delegate is inadequate for doubles. Also, this should probably be rolled into
        get_column_delegate_info but that only works for combo boxes now. Also, we return the column names
        because they're used by other code afterwards.

        Args:
            use_aux (bool): True to include aux.

        Returns:
            (list[str]): The column names.
        """
        column_names, column_types, _ = self._data.get_column_info('', use_aux=use_aux)
        for column, name in enumerate(column_names):
            if column_types[name] != np.float64:
                continue  # Only the double columns get the validating delegate
            double_validator = QxDoubleValidator(parent=self)
            double_delegate = EditFieldValidator(double_validator, self)
            self.uix['tbl'].setItemDelegateForColumn(column, double_delegate)
        return column_names

    def is_period_defined(self, period) -> bool:
        """Returns true if the stress period is defined in the list package data.

        Args:
            period (int): Stress period number (1-based).

        Returns:
             (bool): True if there is data and the period is in its period files, else False.
        """
        if self._data:
            return period in self._data.period_files
        return False

    def do_enabling(self):
        """Enables and disables the widgets appropriately.

        What is enabled depends on whether the dialog is locked and whether the current period is defined.
        """
        unlocked = not self.dlg_input.locked
        defined = self.is_period_defined(self.period())
        editable = unlocked and defined

        # Undefined text
        self.uix['txt_undefined'].setVisible(not defined)

        # Stress period toolbar
        self.actions[ADD_SVG].setEnabled(unlocked and not defined)
        self.actions[DELETE_SVG].setEnabled(editable)
        self.actions[COPY_SVG].setEnabled(unlocked)
        self.actions[DOWNLOAD_SVG].setEnabled(unlocked)

        # Handle filtering. The filtering text can only show if there is a filter action.
        has_filter_action = FILTER_SVG in self.actions
        if has_filter_action:
            self.actions[FILTER_SVG].setEnabled(defined)
        self.uix['txt_filtering'].setVisible(has_filter_action and self._can_filter_on_selected_cells())

        self.actions[TO_COLUMN_SVG].setEnabled(editable)

        # Table
        self.uix['tbl'].setEnabled(defined)

        self.actions[ROW_ADD_SVG].setEnabled(editable)

        self.on_selection_changed()  # This calls self._enable_selection_dependent_items()

    def get_external_filename(self, period):
        """Returns the absolute external filename corresponding to the given stress period.

        Args:
            period (int): Stress period number (1-based).

        Returns:
             (str): The full, normalized file path. If the period has no filename, returns what is stored for
             it ('' or None).
        """
        stored_name = self._data.period_files.get(period, None)
        if not stored_name:
            return stored_name
        # Resolve relative to the folder of the package file
        package_dir = os.path.dirname(self._data.filename)
        resolved = fs.resolve_relative_path(package_dir, stored_name)
        return os.path.normpath(resolved)

    def set_filter(self, period, include_cells, add_cellidxs=None):
        """Sets the model filter to filter on this period and, maybe, the selected cells.

        The cell part of the filter is only added if include_cells is True, the cell ids are in the PERIODS
        block, filtering on selected cells is on, and there is something to filter on. The selected cells in
        self.dlg_input are not modified.

        Args:
            period (int): The period to filter on.
            include_cells (bool): True if we want to filter on selected cells.
            add_cellidxs (list of list of int): Additional cells to filter on.
        """
        if add_cellidxs is None:
            add_cellidxs = []
        model_filter = f'PERIOD={period}'
        if (
            include_cells
            and self._data.block_with_cellids == 'PERIODS'  # noqa: W503
            and self.dlg_input.filter_on_selected_cells  # noqa: W503
            and (self.dlg_input.selected_cells or add_cellidxs)  # noqa: W503
        ):
            stress_id_columns = self._data.stress_id_columns()
            # Work on a copy. Appending to self.dlg_input.selected_cells itself would permanently add the extra
            # cells to the user's selection, which other code relies on (selected cell counts, HFB checks).
            cellidxs = list(self.dlg_input.selected_cells or [])
            for cellidx_list in add_cellidxs:
                cellidxs.extend(cellidx_list)
            cellidxs_list = ", ".join(map(str, cellidxs))
            if len(stress_id_columns) > 1:  # HFB: a match on either cell will do
                model_filter += f' AND (CELLIDX1 IN ({cellidxs_list}) OR CELLIDX2 IN ({cellidxs_list}))'
            else:
                model_filter += f' AND CELLIDX IN ({cellidxs_list})'
        self.model.setFilter(model_filter)

    def load_period(self, period):
        """Loads the stress period into the table.

        If stress period is not defined, displays "Undefined" label and
        undims the "Define Stress Period" button.

        Args:
            period (int): Stress period number (1-based).
        """
        with dialog_util.wait_cursor_context():
            table = self.uix['tbl']

            # Enable/disable table view and "Define Stress Period" button
            is_defined = self.is_period_defined(period)
            table.setEnabled(is_defined)
            self.actions[ADD_SVG].setEnabled(not is_defined)
            self.actions[DELETE_SVG].setEnabled(is_defined)
            self.actions[COPY_SVG].setEnabled(True)
            self.uix['txt_undefined'].setVisible(not is_defined)

            # Show just this period (and, maybe, just the selected cells)
            self.set_filter(period=period, include_cells=True)
            self.do_enabling()
            widget_builder.resize_columns_to_contents(table)

    def clean_up_temp_files(self):
        """Cleans up the temp files, passing the filenames of all the periods to io_util."""
        period_filenames = self._data.period_files.values()
        io_util.clean_up_temp_files(period_filenames)

    def close_db(self):
        """Closes the database, if there is one."""
        if not self.db_filename:
            return
        # Only closed here; the connection is not removed (QSqlDatabase.removeDatabase is not called)
        QSqlDatabase.database(self.db_filename).close()

    def accept(self):
        """Saves the current changes.

        Pending model edits are submitted and, if there is a database, it is committed and closed.
        """
        self.model.submitAll()
        if not self.db_filename:
            return
        database_connection = QSqlDatabase.database(self.db_filename)
        commit_query = QSqlQuery(database_connection)
        commit_query.exec_('COMMIT')
        database_connection.close()

    def reject(self):
        """Called when the user clicks Cancel.

        If there is a database, it is closed and the temp files are cleaned up.
        """
        if not self.db_filename:
            return
        QSqlDatabase.database(self.db_filename).close()
        self.clean_up_temp_files()


def _mfcellid_field_names(which_cellidx: int, cellidx_count: int) -> tuple[str, str, str, str, str]:
    """Return the name of the MfCellId fields.

    Args:
        which_cellidx: Which cellidx. Typically, there's only 1, so 0 is typical. Will be 0 or 1 for HFB.
        cellidx_count: How many cellidx fields there are in the record. 1 typically, 2 for HFB.

    Returns:
        See description.
    """
    lay = 'LAY'
    row = 'ROW'
    col = 'COL'
    cell2d = 'CELL2D'
    cellid = 'CELLID'

    if cellidx_count > 1:  # This is done for HFB
        lay += str(which_cellidx + 1)
        row += str(which_cellidx + 1)
        col += str(which_cellidx + 1)
        cell2d += str(which_cellidx + 1)
        cellid += str(which_cellidx + 1)
    return lay, row, col, cell2d, cellid


def _mfcellid_field_names_for_record(record: QSqlRecord, which_cellidx: int, cellidx_count: int) -> list[str]:
    """Return the names of the MfCellId fields that are present in the record.

    Args:
        record: A record from the model.
        which_cellidx: 0 typically, 0 or 1 for HFB.
        cellidx_count: 1 typically, 2 for HFB.

    Returns:
        The field names making up the cell id: [lay, row, col] for DIS, [lay, cell2d] for DISV,
        [cellid] for DISU, or an empty list if the record contains none of the identifying fields.
    """
    lay, row, col, cell2d, cellid = _mfcellid_field_names(which_cellidx, cellidx_count)

    # (field whose presence identifies the grid type, fields that make up the cell id), checked in order.
    candidates = (
        (row, [lay, row, col]),  # DIS
        (cell2d, [lay, cell2d]),  # DISV
        (cellid, [cellid]),  # DISU
    )
    for marker_field, field_names in candidates:
        if record.contains(marker_field):
            return field_names
    return []


def _mfcellid_from_record(record: QSqlRecord, which_cellidx: int, cellidx_count: int) -> MfCellId | None:
    """Return the MfCellId in the record.

    Args:
        record: A record from the model.
        which_cellidx: 0 typically, 0 or 1 for HFB.
        cellidx_count: 1 typically, 2 for HFB.

    Returns:
        The MODFLOW cell id, or None if the record has no cell id fields or if any of the cell id
        values cannot be converted to an int.
    """
    mfcellid_field_names = _mfcellid_field_names_for_record(record, which_cellidx, cellidx_count)
    if not mfcellid_field_names:
        return None

    try:
        values = [int(record.value(field)) for field in mfcellid_field_names]
    except (TypeError, ValueError):
        # int() raises TypeError for non-numeric objects such as None, and ValueError for text that is
        # not an integer literal (e.g. '' or '3.0'). Either way there is no usable cell id.
        return None
    return MfCellId(*values)


def _get_changes(df_orig: pd.DataFrame, df2: pd.DataFrame) -> tuple[list[int], list[Any]]:
    """Return the Y values in df2 that differ from df_orig, along with the ids of the changed rows.

    Rows are matched by position. Values are compared as strings, and an original value equal to
    XY_UNEDITABLE is treated as an empty string for the comparison.

    Args:
        df_orig: Original dataframe.
        df2: Dataframe after XySeriesEditor. Its index is reset in place.

    Returns:
        Tuple of (ids, new values) with one entry per changed row. The id is -1 when the original
        row's id is NaN. New values are returned as strings.
    """
    changed_ids = []
    changed_values = []
    df2.reset_index(drop=True, inplace=True)  # Start index at 0 so our indexing in here is always consistent
    for pos in range(len(df_orig)):
        raw_id = df_orig.iloc[pos, DF_ORIG_COL_ID]
        row_id = int(raw_id) if not np.isnan(raw_id) else -1
        edited = str(df2.iloc[pos, DF2_COL_Y])
        original = str(df_orig.iloc[pos, DF_ORIG_COL_Y])
        if original == XY_UNEDITABLE:
            original = ''
        if edited == original:
            continue
        changed_values.append(edited)
        changed_ids.append(row_id)
    return changed_ids, changed_values
