"""TdisDialog."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
from datetime import datetime
import os

# 2. Third party modules
import pandas
from PySide2.QtCore import QModelIndex
from PySide2.QtWidgets import QDialog, QLabel
from typing_extensions import override

# 3. Aquaveo modules
from xms.core.filesystem import filesystem as fs
from xms.guipy.delegates.check_box_no_text import CheckBoxNoTextDelegate
from xms.guipy.dialogs import message_box
from xms.guipy.models.qx_pandas_table_model import QxPandasTableModel

# 4. Local modules
from xms.mf6.data import time_util
from xms.mf6.data.tdis_data import TdisColumn, TdisData
from xms.mf6.file_io import io_factory
from xms.mf6.file_io.writer_options import WriterOptions
from xms.mf6.gui.ats_dialog import AtsDialog
from xms.mf6.gui.date_time_dialog import DateTimeDialog
from xms.mf6.gui.options_gui import OptionsGui
from xms.mf6.gui.package_dialog_base import PackageDialogBase
from xms.mf6.gui.widgets.list_block_table_widget import ListBlockTableWidget


class TdisDialog(PackageDialogBase):
    """The TDIS package dialog."""
    def __init__(self, dlg_input, parent=None):
        """Builds the TDIS dialog and its widgets.

        Args:
            dlg_input (DialogInput): Information needed by the dialog.
            parent (Something derived from QWidget): The parent window.
        """
        super().__init__(dlg_input, parent)

        # Widgets of the PERIODDATA section, keyed by name
        self.ui_periods = {}
        # Start date/time text that is put back when the user enters a string that cannot be parsed
        self._old_start_date_time = ''
        # Reentrancy guard checked at the top of update_end_dates()
        self.updating_end_dates = False
        # Kept as an attribute only so that tests can reach the date/time dialog
        self.date_time_dialog = None

        self.setup_ui()

    @override
    def clear_sections(self) -> None:
        """Clear all section widgets.

        Drops the widget references held in self.ui_periods, then defers to the base class.
        """
        self.ui_periods = {}
        super().clear_sections()

    def define_sections(self):
        """Defines the sections that appear in the list of sections.

        self.sections, and self.default_sections should be set here.
        """
        self.sections = ['COMMENTS', 'OPTIONS', 'PERIODDATA']
        self.default_sections = ['PERIODDATA']

    def setup_section(self, section_name):
        """Builds the widgets for one section.

        PERIODDATA is built by this class; every other section is delegated to the base class.

        Args:
            section_name (str): name of the section
        """
        if section_name != 'PERIODDATA':
            super().setup_section(section_name)
            return
        self.setup_perioddata_section()

    # @overrides
    def setup_options(self, vlayout):
        """Builds the options section, which is defined dynamically, not in the ui file.

        Also remembers the initial start date/time text so it can be restored later.

        Args:
            vlayout (QVBoxLayout): The layout that the option widgets will be added to.
        """
        self._sanitize_time_units()

        # Assign the attribute before calling setup(), exactly as before, so it exists while widgets are created
        gui = OptionsGui(self)
        self.options_gui = gui
        gui.setup(vlayout)

        start_edit = gui.uix['edt_start_date_time']
        self._old_start_date_time = start_edit.text()

    def _sanitize_time_units(self):
        """Makes the time units match what we expect."""
        units = self.dlg_input.data.options_block.get('TIME_UNITS')
        if units:
            # Make upper case
            units = units.upper()
            self.dlg_input.data.options_block.set('TIME_UNITS', on=True, value=units)

    def setup_perioddata_section(self):
        """Builds the PERIODDATA section: an explanatory label, the period table, and its model."""
        section = 'PERIODDATA'
        self.add_group_box_to_scroll_area(section)
        layout = self.uix[section]['layout']

        # Explanation of the STEADY-STATE columns; hidden until STO data is actually added
        sto_note = QLabel('The STEADY-STATE column shows period type data from the STO package for convenience.')
        self.ui_periods['txt_sto_data'] = sto_note
        layout.addWidget(sto_note)
        sto_note.setVisible(False)

        # The stress period table
        table_widget = ListBlockTableWidget(self, self.dlg_input, '')
        self.ui_periods['tbl_wgt'] = table_widget
        layout.addWidget(table_widget)

        # STO columns must be in the dataframe before the model is created from it
        self.add_stress_period_info_from_sto()
        self.setup_tdis_model()
        self.update_end_dates()

    def add_stress_period_info_from_sto(self):
        """Adds one STEADY-STATE column to the period dataframe for each STO package given to the dialog."""
        tdis_data = self.dlg_input.data
        for sto_data in self.dlg_input.sto_data_list:
            # 1 = steady-state, 0 = transient. Periods the STO package does not list get 0.
            nper = tdis_data.nper()
            period_types = [sto_data.stress_periods.get(period, 0) for period in range(1, nper + 1)]

            # Column is named after the model so several STO packages can sit side by side
            column_name = f'STEADY-STATE\n{sto_data.mname}'
            tdis_data.period_df[column_name] = pandas.Series(period_types, index=tdis_data.period_df.index)

            # There is STO data, so show the text that explains it
            self.ui_periods['txt_sto_data'].setVisible(True)

    def setup_tdis_model(self):
        """Creates the table model, configures it, and attaches it to the period table.

        END_DATE is read-only because update_end_dates() fills it in. When STO packages were given to the
        dialog, their STEADY-STATE columns are shown as checkboxes with a default of 0. When there are no
        STO packages, any column at or beyond TdisColumn.STEADY_STATE is hidden. If the dialog is locked,
        every column is read-only.
        """
        period_df = self.dlg_input.data.period_df
        sto_data_list = self.dlg_input.sto_data_list
        table = self.ui_periods['tbl_wgt'].table

        model = QxPandasTableModel(period_df)
        model.set_read_only_columns({TdisColumn.END_DATE})
        _, defaults = self.dlg_input.data.get_column_info('')

        if sto_data_list:
            # One checkbox column per STO package, starting at STEADY_STATE
            check_delegate = CheckBoxNoTextDelegate(self)
            checkbox_columns = set()
            for col_offset in range(len(sto_data_list)):
                column = TdisColumn.STEADY_STATE + col_offset
                defaults[period_df.columns[column]] = 0
                checkbox_columns.add(column)
                table.setItemDelegateForColumn(column, check_delegate)
            model.set_checkbox_columns(checkbox_columns)

        model.set_default_values(defaults)
        if self.dlg_input.locked:
            model.set_read_only_columns(set(range(model.columnCount())))
        model.dataChanged.connect(self.update_end_dates)
        self.ui_periods['tbl_wgt'].set_model(model)

        # Hide steady-state columns when there is no STO package to show them for. The previous loop
        # iterated over the (empty) STO list here, so it could never hide anything.
        if not sto_data_list:
            for column in range(int(TdisColumn.STEADY_STATE), model.columnCount()):
                table.setColumnHidden(column, True)

    def _update_tool_tips(self, checked: bool, time_units: str):
        """Stores the TIME_UNITS option and then refreshes the period table's tool tips.

        Args:
            checked: Whether the TIME_UNITS option is on.
            time_units: The time units text to store.
        """
        options = self.dlg_input.data.options_block
        options.set('TIME_UNITS', checked, time_units)

        table_widget = self.ui_periods['tbl_wgt']
        table_widget.update_tool_tips()

    def on_chk_time_units(self, checked: int):
        """Called when the time units checkbox is clicked. Updates the end dates and the tool tips.

        Args:
            checked: The check state; it is converted with bool() before being stored.
        """
        self.update_end_dates()
        units_text = self.options_gui.uix['cbx_time_units'].currentText()
        self._update_tool_tips(bool(checked), units_text)

    def on_cbx_time_units(self, text: str):
        """Called when a different time unit is chosen. Updates the end dates and the tool tips.

        Args:
            text: The newly selected time units text.
        """
        self.update_end_dates()
        self._update_tool_tips(checked=True, time_units=text)

    def on_btn_start_date_time(self):
        """Opens a date/time picker and, if accepted, puts the picked value into the starting date field.

        The picked string also becomes the value that on_edt_start_date_time() restores after a failed
        parse, so a later bad edit reverts to this date rather than to the text from dialog setup.
        """
        date, time = self._date_time_from_options()
        self.date_time_dialog = DateTimeDialog(date, time, self)
        if self.date_time_dialog.exec() != QDialog.Accepted:
            return

        picked = self.date_time_dialog.get_date_time_string()
        self.options_gui.uix['edt_start_date_time'].setText(picked)
        # NOTE(review): assumes the picker always produces a parsable string - confirm against DateTimeDialog
        self._old_start_date_time = picked
        self.update_end_dates()

    def on_chk_start_date_time(self):
        """Called when starting date/time checkbox is clicked. Recomputes the end dates in the table."""
        self.update_end_dates()

    def on_chk_ats6_filein(self, checked):
        """Called when the ATS6 FILEIN checkbox is clicked."""
        if checked:
            options_block = self.dlg_input.data.options_block
            option = 'ATS6 FILEIN'
            the_list = options_block.get(option, [])
            if not the_list:
                the_list.append(f'{os.path.splitext(self.dlg_input.data.filename)[0]}.ats')
                options_block.set(option, True, the_list)

    def on_btn_ats6_filein(self):
        """Called when the ATS6 FILEIN button is clicked. Opens the ATS dialog on the option's file.

        If the option names no file yet, the same default that on_chk_ats6_filein() uses (the TDIS file
        name with a '.ats' extension) is stored first. If the file does not exist on disk, the base model
        ATS file is copied to that location before it is read.
        """
        options_block = self.dlg_input.data.options_block
        option = 'ATS6 FILEIN'
        the_list = options_block.get(option, [])
        if not the_list:
            # Indexing an empty list below would raise IndexError, so fall back to the default file name
            the_list = [f'{os.path.splitext(self.dlg_input.data.filename)[0]}.ats']
            options_block.set(option, True, the_list)
        ats_file = the_list[0]

        # Create a default file if necessary
        if not os.path.isfile(ats_file):
            from_file = os.path.join(os.path.dirname(__file__), '..', 'components/resources/base_model/files/model.ats')
            fs.copyfile(from_file, ats_file)

        reader = io_factory.reader_from_ftype('ATS6')
        ats_data = reader.read(ats_file, mfsim=self.dlg_input.data.mfsim, model=self.dlg_input.data.model)
        AtsDialog.run_dialog_on_file(ats_file, ats_data, self.dlg_input.locked, parent=self)

    def _date_time_from_options(self):
        """Returns a tuple of QDate and QTime by parsing the date/time string from the START_DATE_TIME option.

        Both the date and time will be None of the string cannot be parsed.

        Returns:
            (tuple[QDate,QTime]): See description.
        """
        date = None
        time = None
        date_time_str = self.options_gui.uix['edt_start_date_time'].text()
        if date_time_str:
            q_date_time = time_util.q_date_time_from_arbitrary_string(date_time_str)
            if q_date_time.isValid():
                date = q_date_time.date()
                time = q_date_time.time()
        return date, time

    def parse_starting_date_time(self):
        """Try to parse starting date/time string into a standard format.

        Returns:
            (datetime): The starting date as a QDateTime. May be invalid.
        """
        start_date = self.options_gui.uix['edt_start_date_time'].text()
        if not start_date:
            return 0.0

        date_time = time_util.datetime_from_arbitrary_string(start_date)
        return date_time

    def on_edt_start_date_time(self):
        """Called when the text in the START_DATE_TIME field is edited.

        Text that cannot be parsed is reported to the user and replaced by the last accepted text. Text
        that parses (or an empty field) becomes the new last accepted text. The end dates are recomputed
        either way.
        """
        edit = self.options_gui.uix['edt_start_date_time']
        date_time = self.parse_starting_date_time()
        if not date_time and edit.text():
            message_box.message_with_ok(parent=self, message='Could not parse date/time string.')
            # Put back the old string
            edit.setText(self._old_start_date_time)
        else:
            # Remember what was accepted; otherwise a later bad edit would revert to the text from dialog
            # setup and silently discard every valid edit made since then.
            self._old_start_date_time = edit.text()
        self.update_end_dates()

    def update_end_dates(self):
        """Recomputes the END_DATE column of the period table and refreshes the view.

        Returns immediately if an update is already running; the model's dataChanged signal is connected
        to this method, so the changes made here would otherwise call back into it. The start value is the
        parsed START_DATE_TIME only when TIME_UNITS is checked and not 'UNKNOWN' and START_DATE_TIME is
        checked; otherwise, or if the text cannot be parsed, it is 0.0.
        """
        if self.updating_end_dates:
            return
        self.updating_end_dates = True
        try:
            # Get units and start date.
            prev_end_date = 0.0
            units = ''
            if self.options_gui.uix['chk_time_units'].isChecked():
                units = self.options_gui.uix['cbx_time_units'].currentText()
                if units != 'UNKNOWN' and self.options_gui.uix['chk_start_date_time'].isChecked():
                    # A failed parse gives a falsy result; fall back to 0.0 rather than passing it on
                    prev_end_date = self.parse_starting_date_time() or 0.0

            model = self.ui_periods['tbl_wgt'].get_model()
            self.calculate_new_end_dates_and_times(prev_end_date, units)
            model.submit()

            # Announce exactly the END_DATE cells that were rewritten (first row through last row)
            last_row = model.data_frame.shape[0] - 1
            if last_row >= 0:
                top_left = model.index(0, TdisColumn.END_DATE)
                bottom_right = model.index(last_row, TdisColumn.END_DATE)
                model.dataChanged.emit(top_left, bottom_right)
            self.ui_periods['tbl_wgt'].table.resizeColumnToContents(TdisColumn.END_DATE)
        finally:
            # Always release the guard; if it stayed set after an error, end dates would never update again
            self.updating_end_dates = False

    def calculate_new_end_dates_and_times(self, prev_end_date, units: str):
        """Fills the END_DATE column of the period table from the period lengths.

        If the START_DATE_TIME option is checked, prev_end_date is a datetime, and units is non-empty and
        not 'UNKNOWN', each period's end is written as a 'YYYY-MM-DD HH:MM:SS' string. Otherwise the
        running total of the period lengths is written as text.

        Args:
            prev_end_date (datetime | float): Where the first period starts. In the running-total case any
                value that is not a number (e.g. None from a failed parse, or a datetime) is treated as 0.0.
            units: The time units text.
        """
        model = self.ui_periods['tbl_wgt'].get_model()
        durations = model.data_frame.iloc[:, TdisColumn.PERLEN].tolist()

        start_date_checked = self.options_gui.uix['chk_start_date_time'].isChecked()
        has_units = bool(units) and units != 'UNKNOWN'
        if start_date_checked and has_units and isinstance(prev_end_date, datetime):
            # times[0] is the start itself, so the end of period `row` is times[row + 1]
            times = TdisData.get_date_times(units, prev_end_date, durations)
            for row in range(len(times) - 1):
                end_date_string = times[row + 1].strftime('%Y-%m-%d %H:%M:%S')
                model.data_frame.iat[row, TdisColumn.END_DATE] = end_date_string
        else:
            # Adding a period length to None or to a datetime raises TypeError, so only numbers are kept
            elapsed = prev_end_date if isinstance(prev_end_date, (int, float)) else 0.0
            for row, duration in enumerate(durations):
                elapsed += duration
                model.data_frame.iat[row, TdisColumn.END_DATE] = str(elapsed)

    def get_data(self):
        """Returns the self.dlg_input.data object.

        Returns:
            The object stored in self.dlg_input.data; the same object is returned, not a copy.
        """
        return self.dlg_input.data

    @override
    def widgets_to_data(self) -> None:
        """Copies the period table back into dlg_input.data, unless the dialog is locked.

        The base class is always given its turn first.
        """
        super().widgets_to_data()
        if self.dlg_input.locked:
            return

        data = self.dlg_input.data
        period_df = self.ui_periods['tbl_wgt'].get_model().data_frame
        data.period_df = period_df
        # Keep the period count in step with the number of table rows
        data._nper = period_df.shape[0]

    @staticmethod
    def update_sto_packages(dialog, out_dir):
        """Writes the STEADY-STATE columns of the dialog's period table back to the STO package files.

        Each STO package is read again from its file, given the values of its own column, and written out.

        Args:
            dialog: The dialog.
            out_dir (str): Path to output directory. If truthy, each STO file is written there under its
                original base name; otherwise the file name that was read is kept.
        """
        # The n-th STO package owns the n-th column counted from STEADY_STATE
        for col_offset, sto_entry in enumerate(dialog.dlg_input.sto_data_list):
            # Update the sto package with changes from the dialog
            reader = io_factory.reader_from_ftype('STO6')
            sto_data = reader.read(sto_entry.filename, mfsim=sto_entry.mfsim, model=sto_entry.model)
            column = TdisColumn.STEADY_STATE + col_offset
            steady_state = dialog.dlg_input.data.period_df.iloc[:, column].tolist()
            sto_data.set_stress_periods_from_list(steady_state)

            if out_dir:
                sto_data.filename = os.path.join(out_dir, os.path.basename(sto_data.filename))
            parent_dir = os.path.join(os.path.dirname(sto_data.filename), '..')
            writer_options = WriterOptions(
                mfsim_dir=out_dir,
                use_open_close=True,
                use_input_dir=False,
                use_output_dir=False,
                dmi_sim_dir=os.path.normpath(parent_dir)
            )
            io_factory.writer_from_ftype('STO6', writer_options).write(sto_data)
