"""A dialog for controlling model options."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
import logging
from typing import TYPE_CHECKING
import webbrowser

# 2. Third party modules
from adhparam.time_series import TimeSeries
import pandas
from PySide2.QtCore import Qt, QTimer
from PySide2.QtWidgets import (
    QCheckBox, QComboBox, QDialogButtonBox, QHBoxLayout, QLabel, QPushButton, QScrollArea, QTabWidget, QVBoxLayout,
    QWidget
)
import xarray as xr

# 3. Aquaveo modules
import xms.api._xmsapi.dmi as xmd
from xms.api.tree import tree_util
from xms.guipy.dialogs.treeitem_selector import TreeItemSelectorDlg
from xms.guipy.dialogs.xms_parent_dlg import XmsDlg

# 4. Local modules
from xms.adh.data.hot_start_helper import HotStartHelper
from xms.adh.gui.param_qt_helper import ParamQtHelper
from xms.adh.gui.widgets.adh_table_widget import AdhTableWidget
from xms.adh.gui.widgets.data_frame_table import DataFrameTable
from xms.adh.gui.widgets.output_autobuild_table import OutputAutobuildTableWidget
from xms.adh.gui.widgets.output_frequency_table import OutputFrequencyTableWidget

if TYPE_CHECKING:
    from xms.adh.data.xms_query_data import XmsQueryData


class ModelControlDialog(XmsDlg):
    """A dialog for assigning AdH model control values.

    The dialog has Time, Iteration, Operation, Output, Constants, Hot Start and Advanced tabs. Edited values are
    written back to the model control data in accept() (OK button).
    """

    # (widget key prefix, tab title) for each scrollable tab page, in display order.
    _TABS = (
        ('time', 'Time'),
        ('iteration', 'Iteration'),
        ('operation', 'Operation'),
        ('output', 'Output'),
        ('constants', 'Constants'),
        ('hotstart', 'Hot Start'),
        ('advanced', 'Advanced'),
    )

    # Output control params hidden (precedence -1) from the generic param layout, with the reason for each.
    _HIDDEN_OUTPUT_PARAMS = (
        'screen_output_mass_error',  # will use AdH 4.6 wording
        'screen_output_worst_nonlinear_node',  # not in AdH 4.6
        'screen_output_worst_linear_node',  # not in AdH 4.6
        'file_output_wind',  # not in AdH 4.6
        'file_output_wave',  # not in AdH 4.6
        'file_output_adapted_grid',  # not in AdH 4.6
        'file_output_adapted_solution',  # not in AdH 4.6
        'file_output_adapted_transport',  # not in AdH 4.6
        'file_output_sediment',  # not in AdH 4.6
        'output_flow_strings',  # option moved to bc dialog
        'output_control_option',  # will use AdH 4.6 wording
        'oc_time_series_id',  # will use a different table
        'nodal_output',  # set from Output coverage
    )

    def __init__(self, win_cont: QWidget, title: str, xms_data: 'XmsQueryData'):
        """Initializes the class and sets up the ui from the model control values.

        Args:
            win_cont (QWidget): Parent window.
            title (str): Window title.
            xms_data (XmsQueryData): The data for the model control dialog.
        """
        super().__init__(win_cont, 'xms.adh.gui.model_control_dialog')
        self.xms_data = xms_data
        self.data = xms_data.adh_data.model_control
        self.param_data = self.data.param_control
        self.hot_starts = self.data.get_hot_starts()
        self.hot_starts_data_frame = None
        self.hot_start_helper = None
        self.help_url = 'https://www.xmswiki.com/wiki/SMS:ADH_Model_Control'
        self.param_helper = ParamQtHelper(self)
        self.param_helper.time_series = self.data.time_series
        self.widgets = dict()
        self._pe_tree = None  # project explorer tree, set by _get_and_filter_project_tree()
        # Read in vessel uuids (if any)
        if 'uuids' in self.data.vessel_uuids:
            self._selected_vessels = self.data.vessel_uuids['uuids'].values.tolist()
        else:
            self._selected_vessels = []

        self.setWindowTitle(title)
        self.setup_ui()
        self.adjustSize()
        # Widen the dialog by 50%; QWidget.resize() takes integer sizes.
        self.resize(int(self.size().width() * 1.5), self.size().height())

    def setup_ui(self):
        """Sets up the widgets and layouts of the dialog."""
        # Dialog QVBoxLayout with QTabWidget then QDialogButtonBox
        self._set_layout('', 'top_layout', QVBoxLayout())
        self.widgets['tab_widget'] = QTabWidget()
        self.widgets['top_layout'].addWidget(self.widgets['tab_widget'])
        self.widgets['btn_box'] = QDialogButtonBox()
        self.widgets['top_layout'].addWidget(self.widgets['btn_box'])

        # QTabWidget with one scrollable page per entry in _TABS
        for key, title in self._TABS:
            self._add_scroll_tab(key, title)

        self._setup_ui_time()
        self._setup_ui_iteration()
        self._setup_ui_operation()
        self._setup_ui_output()
        self._setup_ui_constants()
        self._setup_ui_hot_start()
        self._setup_ui_advanced()

        # set all widget values and hide/show
        self.param_helper.do_param_widgets(None)

        # QDialogButtonBox with Ok and Cancel buttons
        self.widgets['btn_box'].setOrientation(Qt.Horizontal)
        self.widgets['btn_box'].setStandardButtons(
            QDialogButtonBox.Cancel | QDialogButtonBox.Ok | QDialogButtonBox.Help
        )
        self.widgets['btn_box'].accepted.connect(self.accept)
        self.widgets['btn_box'].rejected.connect(super().reject)
        self.widgets['btn_box'].helpRequested.connect(self.help_requested)

        self._update_select_vessels_button_state()

    def _add_scroll_tab(self, key, title):
        """Add a scrollable page to the tab widget.

        Stores the QScrollArea as self.widgets[f'{key}_scroll'] and the page QWidget it wraps as
        self.widgets[f'{key}_tab'].

        Args:
            key (str): Prefix for the widget keys.
            title (str): Tab title.
        """
        scroll = QScrollArea()
        scroll.setWidgetResizable(True)
        page = QWidget()
        scroll.setWidget(page)
        self.widgets[f'{key}_scroll'] = scroll
        self.widgets[f'{key}_tab'] = page
        self.widgets['tab_widget'].addTab(scroll, title)

    def _setup_ui_time(self):
        """Build the Time tab from the time control params."""
        self._set_layout('time_tab', 'time_layout', QVBoxLayout())
        self.param_helper.time_series_names.append('max_time_step_size_time_series')
        self.param_helper.add_params_to_layout(self.widgets['time_layout'], self.param_data.time_control)
        self.widgets['time_layout'].addStretch()

    def _setup_ui_iteration(self):
        """Build the Iteration tab from the iteration params."""
        self._set_layout('iteration_tab', 'iteration_layout', QVBoxLayout())
        self.param_helper.add_params_to_layout(self.widgets['iteration_layout'], self.param_data.iteration_parameters)
        self.widgets['iteration_layout'].addStretch()

    def _setup_ui_operation(self):
        """Build the Operation tab, pairing the vessel checkbox with a Select Vessels button."""
        self._set_layout('operation_tab', 'operation_layout', QVBoxLayout())

        operation = self.param_data.operation_parameters
        operation.param.transport.precedence = -1
        operation.param.vessel_entrainment.precedence = -1
        operation.param.diffusive_wave.precedence = -1  # not in AdH 4.6

        self.param_helper.add_params_to_layout(self.widgets['operation_layout'], operation)
        self.widgets['operation_layout'].addStretch()

        vessel_checkbox = self.param_helper.get_widget('vessel')

        if vessel_checkbox:
            hbox = QHBoxLayout()
            hbox.addWidget(vessel_checkbox)

            # Create the Select Vessels button
            self.widgets['select_vessels_button'] = QPushButton("Select Vessels...")
            self.widgets['select_vessels_button'].clicked.connect(self._select_vessels)
            vessel_checkbox.clicked.connect(self._update_select_vessels_button_state)
            hbox.addWidget(self.widgets['select_vessels_button'])

            # Replace the vessel checkbox layout with our new one.
            # NOTE(review): index 10 is hard-coded and presumably matches the vessel row's position in the layout
            # built by add_params_to_layout; confirm it if operation params are added or hidden.
            self.widgets['operation_layout'].insertLayout(10, hbox)
            QTimer.singleShot(0, self._update_select_vessels_button_state)

    def _update_select_vessels_button_state(self):
        """Enable Select Vessels only while the vessel checkbox is checked; unchecking clears the selection."""
        vessel_checkbox = self.param_helper.get_widget('vessel')
        if vessel_checkbox:
            # Enable/disable the button based on the checkbox state
            self.widgets['select_vessels_button'].setEnabled(vessel_checkbox.isChecked())
            if not vessel_checkbox.isChecked():
                # Empty the selected vessels list if the vessels checkbox is unchecked
                self._selected_vessels = []

    def _next_time_series_id(self):
        """Return one more than the largest existing time series id, or 1 if there are no time series."""
        time_series = self.param_helper.time_series
        return max(time_series.keys()) + 1 if time_series else 1

    def _ensure_oc_time_series(self):
        """Return the OC time series id, creating a one-row X/Y series of zeros if none is assigned.

        Returns:
            The id (key into self.param_helper.time_series) of the OC time series.
        """
        output_control = self.param_data.output_control
        oc_series = output_control.oc_time_series_id
        if oc_series <= 0:
            oc_series = self._next_time_series_id()
            output_control.oc_time_series_id = oc_series
            series = TimeSeries()
            series.time_series = pandas.DataFrame({
                'X': pandas.Series([0.0], dtype='float'),
                'Y': pandas.Series([0.0], dtype='float'),
            })
            self.param_helper.time_series[oc_series] = series
        return oc_series

    def _ensure_os_time_series(self):
        """Return the OS (autobuild) time series id, creating a one-row 'SERIES AWRITE' series if none is assigned.

        Returns:
            The id (key into self.param_helper.time_series) of the OS time series.
        """
        os_series = self.data.info.attrs['os_time_series']
        if os_series <= 0:
            os_series = self._next_time_series_id()
            self.data.info.attrs['os_time_series'] = os_series
            series = TimeSeries()
            series.series_type = 'SERIES AWRITE'
            series.time_series = pandas.DataFrame({
                'START_TIME': pandas.Series([0.0], dtype='float'),
                'END_TIME': pandas.Series([0.0], dtype='float'),
                'TIME_STEP_SIZE': pandas.Series([0.0], dtype='float'),
                'UNITS': pandas.Series([0], dtype='int'),
            })
            self.param_helper.time_series[os_series] = series
        return os_series

    def _setup_ui_output(self):
        """Build the Output tab.

        The generic output control params are laid out first. The OC/OS option combobox, the OC units combobox, the
        OC table and the OS table are then inserted above them, and the mass error checkbox is added below.
        """
        self._set_layout('output_tab', 'output_tab_layout', QVBoxLayout())
        output_control = self.param_data.output_control
        for param_name in self._HIDDEN_OUTPUT_PARAMS:
            getattr(output_control.param, param_name).precedence = -1
        self.param_helper.add_params_to_layout(self.widgets['output_tab_layout'], output_control)

        # create a combobox to select either OC or OS
        self.widgets['oc_os_layout'] = QHBoxLayout()
        self.widgets['oc_os_layout'].addWidget(QLabel('Output control option:'))
        self.widgets['oc_os_combo'] = QComboBox()
        self.widgets['oc_os_combo'].addItems(['Specify output frequency (OC)', 'Specify autobuild (OS)'])
        option_index = output_control.param.output_control_option.objects.index(output_control.output_control_option)
        self.widgets['oc_os_combo'].setCurrentIndex(option_index)
        self.widgets['oc_os_combo'].currentIndexChanged.connect(self._os_oc_changed)
        self.widgets['oc_os_layout'].addWidget(self.widgets['oc_os_combo'])

        # create a combobox and a table for OC (OC is ensured before OS so a newly created OS id does not collide)
        oc_series = self._ensure_oc_time_series()
        self.widgets['oc_units_layout'] = QHBoxLayout()
        self.widgets['oc_units_label'] = QLabel('Output control units:')
        self.widgets['oc_units_layout'].addWidget(self.widgets['oc_units_label'])
        self.widgets['oc_units_combo'] = QComboBox()
        self.widgets['oc_units_combo'].addItems(['seconds', 'minutes', 'hours', 'days', 'weeks'])
        self.widgets['oc_units_combo'].setCurrentIndex(self.param_helper.time_series[oc_series].units)
        self.widgets['oc_units_layout'].addWidget(self.widgets['oc_units_combo'])
        self.widgets['oc_table'] = OutputFrequencyTableWidget(
            self.widgets['output_tab'], self.param_helper.time_series[oc_series]
        )
        self.widgets['output_tab_layout'].insertLayout(0, self.widgets['oc_os_layout'])
        self.widgets['output_tab_layout'].insertLayout(1, self.widgets['oc_units_layout'])
        self.widgets['output_tab_layout'].insertWidget(2, self.widgets['oc_table'])

        # create a table for OS
        os_series = self._ensure_os_time_series()
        self.widgets['os_table'] = OutputAutobuildTableWidget(
            self.widgets['output_tab'], self.param_helper.time_series[os_series]
        )
        self.widgets['output_tab_layout'].insertWidget(3, self.widgets['os_table'])

        self.widgets['meo_check'] = QCheckBox('Screen output mass error active', self.widgets['output_tab'])
        self.widgets['meo_check'].setChecked(output_control.screen_output_mass_error != 0)
        self.widgets['output_tab_layout'].addWidget(self.widgets['meo_check'])
        # show the widgets for the current option
        self._os_oc_changed(option_index)
        self.widgets['output_tab_layout'].addStretch()

    def _setup_ui_constants(self):
        """Build the Constants tab from the model constants params."""
        self._set_layout('constants_tab', 'constants_tab_layout', QVBoxLayout())
        self.param_helper.add_params_to_layout(self.widgets['constants_tab_layout'], self.param_data.model_constants)
        self.widgets['constants_tab_layout'].addStretch()

    def _setup_ui_hot_start(self):
        """Build the Hot Start tab with a table of the hot start values."""
        # build table definition
        self.hot_start_helper = HotStartHelper(self.xms_data, self.param_data)
        table_definition = self.hot_start_helper.create_hot_start_table_definition()
        # build data frame
        self.hot_starts_data_frame = self.hot_start_helper.build_data_frame(self.hot_starts)
        # create UI
        self._set_layout('hotstart_tab', 'hotstart_tab_layout', QVBoxLayout())
        hot_start_table = DataFrameTable()
        self.widgets['hotstart_table'] = hot_start_table
        self.widgets['hotstart_tab_layout'].addWidget(hot_start_table)
        hot_start_table.setup(table_definition, self.hot_starts_data_frame)
        self.hot_start_helper.previous_hot_starts = self.hot_starts_data_frame.copy()
        hot_start_table.data_changed.connect(lambda: self.on_hot_start_changed(hot_start_table))

    def _setup_ui_advanced(self):
        """Build the Advanced tab containing the advanced cards table."""
        self._set_layout('advanced_tab', 'advanced_layout', QVBoxLayout())
        self.widgets['advanced_table'] = AdhTableWidget(
            self.widgets['advanced_tab'], self.param_data.advanced_cards.advanced_cards, 0, None, {}
        )
        self.widgets['advanced_table'].setObjectName('advanced_table')
        self.widgets['advanced_table'].setMinimumHeight(600)
        self.widgets['advanced_table'].setMinimumWidth(500)
        # Put the table under layout management (like the Output tab tables) so it resizes with the tab and the
        # scroll area accounts for its minimum size.
        self.widgets['advanced_layout'].addWidget(self.widgets['advanced_table'])
        self.widgets['advanced_layout'].addStretch()

    def on_hot_start_changed(self, hot_start_widget):
        """Get value from hot start table and update.

        Args:
            hot_start_widget: The hot start table widget.
        """
        new_hot_starts = self.hot_starts_data_frame
        self.hot_start_helper.update_time_steps(new_hot_starts)
        hot_start_widget.set_values(new_hot_starts)
        self.hot_start_helper.previous_hot_starts = new_hot_starts.copy()

    def _set_layout(self, parent_name, layout_name, layout):
        """Stores a layout in self.widgets and attaches it to its parent.

        Args:
            parent_name (str): Key of the parent widget or layout in self.widgets, or '' for the dialog itself.
            layout_name (str): Key under which the layout is stored in self.widgets.
            layout (QLayout): The layout to attach.
        """
        self.widgets[layout_name] = layout
        parent = self.widgets[parent_name] if parent_name else self
        if isinstance(parent, (QVBoxLayout, QHBoxLayout)):
            # a layout parent nests the new layout
            parent.addLayout(layout)
        else:
            # a widget parent (including the dialog) takes it as its layout
            parent.setLayout(layout)

    def _os_oc_changed(self, index):
        """This is called when the output control option changes.

        Args:
            index (int): The index of the output control combobox.
        """
        is_oc = index == 0
        self.widgets['oc_units_label'].setVisible(is_oc)
        self.widgets['oc_units_combo'].setVisible(is_oc)
        self.widgets['oc_table'].setVisible(is_oc)
        self.widgets['os_table'].setVisible(not is_oc)

    def _get_and_filter_project_tree(self):
        """Get the XMS project explorer tree and filter out non-Vessel coverage items.

        Returns:
            (:obj:`bool`): True if no problems
        """
        # Get the project tree
        self._pe_tree = self.xms_data.project_tree
        if not self._pe_tree:
            self.error_msg = 'Unable to retrieve SMS project explorer tree.'
            # This class never sets self._logger, so fall back to the module logger if no base class provides one.
            logger = getattr(self, '_logger', None) or logging.getLogger(__name__)
            logger.error(self.error_msg)
            return False

        # Filter the tree for the dataset selector dialog.
        tree_util.filter_project_explorer(self._pe_tree, self.is_vessel_if_coverage)
        return True

    @staticmethod
    def is_vessel_if_coverage(item):
        """Check if a tree item is a Vessel, but only if it is a coverage.

        Args:
            item (:obj:`TreeNode`): The item to check

        Returns:
            (:obj:`bool`): True if the tree item is a Vessel coverage type or is not a coverage.
        """
        if type(item.data) is xmd.CoverageItem:
            if item.model_name == 'AdH' and item.coverage_type == 'Vessel':
                return True
            return False
        return True

    def _select_vessels(self):
        """Display a Vessel coverage selector with multi-select enabled.

        Returns:
            (:obj:`bool`): True if at least one Vessel coverage was selected
        """
        if not self._get_and_filter_project_tree():
            return False

        current_selection_list = self._selected_vessels

        # Display the vessel coverage selector dialog
        selector = TreeItemSelectorDlg(
            title='Select Vessel Coverages',
            target_type=xmd.CoverageItem,
            pe_tree=self._pe_tree,
            parent=None,
            allow_multi_select=True,
            previous_selection=current_selection_list
        )

        if not selector.exec():
            return False

        # Get uuids for selected vessel coverages
        self._selected_vessels = selector.get_selected_item_uuid()
        return bool(self._selected_vessels)

    def help_requested(self):
        """Called when the Help button is clicked."""
        webbrowser.open(self.help_url)

    def accept(self):
        """Save data from dialog on OK."""
        self.data.time_series = self.param_helper.time_series
        self.param_data.advanced_cards.advanced_cards = self.widgets['advanced_table'].model.data_frame
        output_control = self.param_data.output_control
        # the OC/OS combobox items are in the same order as the param's objects
        output_control.output_control_option = \
            output_control.param.output_control_option.objects[self.widgets['oc_os_combo'].currentIndex()]
        oc_series = output_control.oc_time_series_id
        os_series = self.data.info.attrs['os_time_series']
        self.data.time_series[oc_series].units = self.widgets['oc_units_combo'].currentIndex()
        self.data.time_series[oc_series].time_series = self.widgets['oc_table'].model.data_frame
        self.data.time_series[os_series].time_series = self.widgets['os_table'].model.data_frame
        output_control.screen_output_mass_error = self.widgets['meo_check'].isChecked()
        hot_starts = self.hot_start_helper.get_hot_starts(self.hot_starts_data_frame)
        self.data.set_hot_starts(hot_starts)
        self.data.vessel_uuids = xr.Dataset({'uuids': self._selected_vessels})
        return super().accept()
