"""A dialog for transport constituent values."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
import copy
import webbrowser

# 2. Third party modules
from PySide2.QtCore import QModelIndex

# 3. Aquaveo modules
from xms.guipy.delegates.edit_field_validator import EditFieldValidator
from xms.guipy.dialogs.xms_parent_dlg import XmsDlg
from xms.guipy.models.rename_model import RenameModel
from xms.guipy.validators.number_corrector import NumberCorrector  # noqa: AQU103
from xms.guipy.validators.qx_double_validator import QxDoubleValidator

# 4. Local modules
from xms.adh.gui.time_series_editor import TimeSeriesEditor
from xms.adh.gui.transport_constituents_dialog_ui import Ui_TransportConstituentsDialog
from xms.adh.gui.widgets.adh_table_widget import AdhTableWidget


class FilteredRenameModel(RenameModel):
    """A rename model that additionally hides the first source column."""
    def __init__(self, column_names, parent=None):
        """Construct the model.

        Args:
            column_names (list): The header titles handed to the base rename model.
            parent (Something derived from :obj:`QObject`): The parent object.

        """
        super().__init__(column_names, parent)

    def filterAcceptsColumn(self, source_column: int, source_parent: QModelIndex) -> bool:  # noqa: N802
        """Decide whether a source column is shown; the 'ID' column (index 0) is not.

        Args:
            source_column (int): The column index in the source model.
            source_parent (QModelIndex): The parent index of the view. Not used by this filter.

        Returns:
            True for every column except column 0, which is hidden.
        """
        is_id_column = source_column == 0
        return not is_id_column


class UserConstituentsTableWidget(AdhTableWidget):
    """A table of user defined constituents that stamps each added row with a running id."""
    def __init__(self, parent, data_frame, select_col, filter_model, column_delegates, next_id):
        """Build the table and remember the first id to hand out.

        Args:
            parent (Something derived from :obj:`QObject`): The parent object.
            data_frame (pandas.DataFrame): The model data.
            select_col (int): Column to select when adding/removing rows
            filter_model (QSortFilterProxyModel): A model for sorting, filtering, and changing how the data is viewed.
            column_delegates (dict): A dictionary with column index as the key and a QStyledItemDelegate as the value.
            next_id (int): The id that will be written into the next added row.

        """
        super().__init__(parent, data_frame, select_col, filter_model, column_delegates)
        self.next_id = next_id

    def on_btn_add_row(self):
        """Add a row through the base class, then write the next id into its first column."""
        # The row count taken before the insert is the index the new row is expected to occupy.
        appended_row = self.model.rowCount()
        super().on_btn_add_row()
        id_cell = self.model.index(appended_row, 0)
        self.model.setData(id_cell, self.next_id)
        # Advance the counter so the following row receives a different id.
        self.next_id += 1


class TransportConstituentsDialog(XmsDlg):
    """A dialog for assigning transport constituents."""
    def __init__(self, win_cont, data):
        """Initializes the class, sets up the ui, and writes the model control values.

        Args:
            win_cont (QWidget): Parent window
            data (TransportConstituentsIO): The transport constituents data.

        """
        super().__init__(win_cont, 'xms.adh.gui.transport_constituents_dialog')
        self.param_data = data.param_control
        self.data = data
        self.help_url = 'https://www.xmswiki.com/wiki/SMS:ADH_Sediment_Transport_and_Bed_Layers'

        self.ui = Ui_TransportConstituentsDialog()
        self.ui.setupUi(self)

        # Setup the rename model so our column headers make sense.
        rename_model = FilteredRenameModel(['Constituent name', 'Concentration (ppm)'], self)

        # Set up the delegate
        self.dbl_validator = QxDoubleValidator()
        self.edit_delegate = EditFieldValidator(self.dbl_validator)
        self.number_corrector = NumberCorrector(self)

        user_cons = self.data.user_constituents.to_dataframe()

        # Add the table for user defined constituents.
        self.table = UserConstituentsTableWidget(
            self, user_cons, 1, rename_model, {2: self.edit_delegate}, self.data.info.attrs['next_constituent_id']
        )
        self.table.default_values = [[0, 'Constituent', 0.0]]
        self.table.model.set_default_values(self.table.default_values)
        self.ui.top_layout.insertWidget(self.ui.top_layout.count() - 1, self.table)

        # Set up data in other fields.
        self._set_widget_data()

        self.ui.button_box.helpRequested.connect(self.help_requested)

        # Widen the dialog to 1.5 times its fitted width. QWidget.resize() takes integer pixel sizes and
        # implicit float-to-int conversion is rejected by newer Python versions, so convert explicitly.
        self.adjustSize()
        fitted_size = self.size()
        self.resize(int(fitted_size.width() * 1.5), fitted_size.height())

    def _init_double_edit(self, line_edit, value):
        """Show a value in a line edit and attach the shared double validator and number corrector.

        Args:
            line_edit: The ui line edit to set up.
            value: The value to display; it is converted with ``str``.
        """
        line_edit.setText(str(value))
        line_edit.setValidator(self.dbl_validator)
        line_edit.installEventFilter(self.number_corrector)

    def _set_widget_data(self):
        """Fill the salinity, temperature, and vorticity widgets from the data and hook up the curve buttons."""
        ui = self.ui
        params = self.param_data

        # Set salinity values.
        ui.salinity_group.setChecked(params.salinity)
        self._init_double_edit(ui.salinity_concentration, params.reference_concentration)

        # Set temperature values.
        ui.temperature_group.setChecked(params.temperature)
        self._init_double_edit(ui.temperature, params.reference_temperature)
        ui.use_air_water_heat_transfer.setChecked(params.air_water_heat_transfer)

        # Set vorticity values.
        ui.vorticity_group.setChecked(params.vorticity)
        self._init_double_edit(ui.vorticity_norm, params.vorticity_normalization)
        self._init_double_edit(ui.vorticity_as, params.vorticity_as_term)
        self._init_double_edit(ui.vorticity_ds, params.vorticity_ds_term)

        # Connect the time series buttons so the editor comes up.
        # The series ids are read inside the lambdas, i.e. when the button is pressed.
        ui.dew_point_button.pressed.connect(
            lambda: self._edit_curve(self.data.dew_point_series_id, 'Dew point temperature')
        )
        ui.short_wave_button.pressed.connect(
            lambda: self._edit_curve(self.data.short_wave_series_id, 'Short wave radiation')
        )

    def _set_data_from_widgets(self):
        """Copy the widget and table values back into the data.

        Every value is read and converted before anything is stored, so a failed conversion
        does not leave the data partially updated.

        Raises:
            ValueError: If the text of one of the numeric fields cannot be converted to a float.
        """
        ui = self.ui

        # Read everything that can fail first.
        reference_concentration = float(ui.salinity_concentration.text())
        reference_temperature = float(ui.temperature.text())
        vorticity_normalization = float(ui.vorticity_norm.text())
        vorticity_as_term = float(ui.vorticity_as.text())
        vorticity_ds_term = float(ui.vorticity_ds.text())
        user_constituents = self.table.model.data_frame.to_xarray()

        params = self.param_data

        # Set salinity values.
        params.salinity = ui.salinity_group.isChecked()
        params.reference_concentration = reference_concentration

        # Set temperature values.
        params.temperature = ui.temperature_group.isChecked()
        params.reference_temperature = reference_temperature
        params.air_water_heat_transfer = ui.use_air_water_heat_transfer.isChecked()

        # Set vorticity values.
        params.vorticity = ui.vorticity_group.isChecked()
        params.vorticity_normalization = vorticity_normalization
        params.vorticity_as_term = vorticity_as_term
        params.vorticity_ds_term = vorticity_ds_term

        # Get values from the table.
        self.data.user_constituents = user_constituents
        self.data.info.attrs['next_constituent_id'] = self.table.next_id

    def _edit_curve(self, curve_id, series_name):
        """Brings up the XY series editor for a curve.

        The editor works on a deep copy of the stored series; the stored series is only
        replaced when the editor's ``run()`` returns a truthy value.

        Args:
            curve_id (int): The id of the curve.
            series_name (str): The name of the series.
        """
        curve = copy.deepcopy(self.data.time_series[curve_id].time_series)
        time_series_editor = TimeSeriesEditor(curve, series_name=series_name, icon=self.windowIcon(), parent=self)
        if time_series_editor.run():
            self.data.time_series[curve_id].time_series = time_series_editor.series

    def help_requested(self):
        """Called when the Help button is clicked."""
        webbrowser.open(self.help_url)

    def accept(self):
        """Save data from dialog on OK."""
        self._set_data_from_widgets()
        super().accept()


# def main():
#     """Demonstrates a simple use of the dialog.
#
#     """
#     import sys
#     from PySide2.QtWidgets import (QApplication)
#     from xms.adh.data.transport_constituents_io import TransportConstituentsIO
#
#     data = TransportConstituentsIO('C:/Temp/AdH_Transport_Constituents_file.nc')
#     app = QApplication(sys.argv)
#     dialog = TransportConstituentsDialog(None, data)
#     if dialog.exec():
#         data.param_control = dialog.data
#         data.commit()
#
#     sys.exit(app.exec_())

# if __name__ == "__main__":
#     main()
