"""Dialog to view the attributes of applied tidal constituents."""

__copyright__ = "(C) Copyright Aquaveo 2024"
__license__ = "All rights reserved"

# 1. Standard Python modules
from contextlib import suppress
from typing import Any, Optional
import webbrowser

# 2. Third party modules
from PySide2.QtCore import QAbstractTableModel, QModelIndex, Qt
from PySide2.QtWidgets import QDoubleSpinBox, QItemEditorFactory, QTableView, QWidget
import xarray as xr

# 3. Aquaveo modules
from xms.guipy.dialogs.xms_parent_dlg import XmsDlg
from xms.guipy.validators.number_corrector import NumberCorrector  # noqa: We're implementing a validator.
from xms.guipy.validators.qx_double_validator import QxDoubleValidator
from xms.tides.data import tidal_data as td

# 4. Local modules
from xms.schism.external.mapped_tidal_data import MappedTidalData
from xms.schism.external.mapped_tides_dialog_ui import Ui_MappedTidesDialog

# Column positions in the properties table (MappedTidesPropertiesModel), in display order.
FREQUENCY_COLUMN, NODAL_COLUMN, ARGUMENT_COLUMN = range(3)


class QxDoubleWidgetFactory(QItemEditorFactory):
    """
    A widget factory for a QTableView that uses QxDoubleValidator for numeric fields.

    See xms.guipy.validators.qx_locale for why we can't just use the default one.
    """
    # Value of `user_type` that createEditor receives for string-typed data. PySide2 doesn't expose the QMetaType
    # constants, so it is spelled out here (see the comment in createEditor for how it was determined).
    _STRING_USER_TYPE = 10

    def createEditor(self, user_type: int, parent: QWidget) -> QWidget:  # noqa: N802 - snake case
        """
        Create an editor widget.

        The base factory builds the editor. If the editor is for string data and supports validators, a
        QxDoubleValidator is attached to it. The table models in this module hand their numbers to the view already
        formatted by NumberCorrector.format_double, which is presumably why numeric cells arrive here as strings.

        Args:
             user_type: The data type the editor is being created for.
             parent: The parent widget.

        Returns:
            The editor widget created by the base factory.
        """
        widget = super().createEditor(user_type, parent)
        # A StackOverflow answer suggests that values for user_type are defined on QVariant[1], but a release update
        # for PySide2 says that QVariant was removed in favor of just using any Python object[2]. It says when you need
        # a type you can use a string (the type's name) or the type itself.
        #
        # That's all well and good except that user_type was passed in, and it's an int, even when the field is a float.
        # As far as I can tell, there's no way to get the constants that tell you which value of user_type means what.
        # They might be available in C++ on QMetaType[3], but that doesn't seem to be exposed in PySide2 either.
        #
        # For lack of any actually exposed constants, I just had to figure it out for myself. Testing suggests that this
        # usually comes in as 10. [3] suggests this means "string".
        #
        # [1] https://stackoverflow.com/a/51763156
        # [2] https://pyside.github.io/docs/pyside/pysideapi2.html
        # [3] https://doc.qt.io/qt-6/qmetatype.html#Type-enum
        if user_type == self._STRING_USER_TYPE and hasattr(widget, 'setValidator'):
            # For string data the default editor is presumably a QLineEdit-style widget, which is what provides
            # setValidator (spin boxes have no such method). The validator is parented to the editor rather than to
            # `parent` (typically the view's viewport) so it is destroyed along with the editor instead of
            # accumulating on the view with every edit.
            widget.setValidator(QxDoubleValidator(parent=widget))
        return widget


class MappedTidesPropertiesModel(QAbstractTableModel):
    """
    Model for a QTableView that will display mapped tidal properties.

    Properties are values that are global to the entire dataset. They include things like the equilibrium argument and
    nodal factor. They do not vary by node, only by constituent.
    """
    # Name of the variable in `tidal_data.properties` displayed in each table column.
    _COLUMN_VARIABLES = {
        FREQUENCY_COLUMN: 'frequency',
        NODAL_COLUMN: 'factor',
        ARGUMENT_COLUMN: 'argument',
    }

    # Horizontal header label for each table column.
    _COLUMN_LABELS = {
        FREQUENCY_COLUMN: 'Frequency',
        NODAL_COLUMN: 'Nodal Factor',
        ARGUMENT_COLUMN: 'Equilibrium\nArgument',
    }

    def __init__(self, tidal_data: MappedTidalData, parent: QWidget):
        """
        Initialize the model.

        Args:
            tidal_data: Data to display in the model.
            parent: Parent widget.
        """
        super().__init__(parent)
        self.tidal_data = tidal_data

    def rowCount(self, parent: QModelIndex = None) -> int:  # noqa: N802 - snake case
        """
        Get the number of rows in an item in the table.

        This is a flat table: only the table itself (an invalid/None parent) has rows, one per entry in
        `tidal_data.properties['name']`. Every other item is declared to have no rows.

        Args:
            parent: The parent item. An invalid index is interpreted as meaning the table itself.

        Returns:
            The number of rows inside the requested item.
        """
        if parent is None:
            parent = QModelIndex()
        if parent.isValid():
            # Not a hierarchical table, so no cell has rows of its own.
            return 0
        return len(self.tidal_data.properties['name'])

    def columnCount(self, parent: QModelIndex = None) -> int:  # noqa: N802 - snake case
        """
        Get the number of columns in an item in the table.

        This is a flat table: only the table itself (an invalid/None parent) has columns. Every other item is declared
        to have no columns.

        Args:
            parent: The parent item. An invalid index is interpreted as meaning the table itself.

        Returns:
            The number of columns inside the requested item.
        """
        if parent is None:
            parent = QModelIndex()
        if parent.isValid():
            # Not a hierarchical table, so no cell has columns of its own.
            return 0
        return len(self._COLUMN_VARIABLES)

    def _variable_for(self, index: QModelIndex) -> Optional[str]:
        """
        Get the name of the properties variable displayed at an index.

        Args:
            index: Index of a cell in the table.

        Returns:
            The variable name, or None if the index is invalid or is not in a known column.
        """
        if not index.isValid():
            return None
        return self._COLUMN_VARIABLES.get(index.column())

    def data(self, index: QModelIndex, role: int = Qt.DisplayRole) -> Any:
        """
        Get the data for a cell in the table.

        Args:
            index: Index of the cell to get data for.
            role: The role to get data for. Only the display and edit roles are supported.

        Returns:
            The cell's value formatted by NumberCorrector.format_double, or None if the role or index isn't supported.
        """
        if role not in (Qt.DisplayRole, Qt.EditRole):
            return None

        variable = self._variable_for(index)
        if variable is None:
            return None

        value = self.tidal_data.properties[variable][index.row()].item()

        # This kind of thing should probably be refactored into xmsguipy at some point.
        return NumberCorrector.format_double(value)

    def setData(self, index: QModelIndex, value: Any, role: int = Qt.DisplayRole):  # noqa: N802 - snake case
        """
        Set the data in a cell.

        Args:
            index: Index of the cell to set data for.
            value: Value to assign.
            role: Role to assign data for. Only the edit role is handled here; others go to the base class.

        Returns:
            Whether the data was set.
        """
        if role != Qt.EditRole:
            return super().setData(index, value, role)

        variable = self._variable_for(index)
        if variable is None:
            return False

        self.tidal_data.properties[variable][index.row()] = value
        self.dataChanged.emit(index, index, [role])
        return True

    def flags(self, index: QModelIndex) -> int:
        """
        Get the flags for an item in the table.

        Flags are a combination of flags from Qt, e.g. Qt.ItemIsUserCheckable.

        This removes the checkbox flag, which the base class version seems to like applying for some reason, and makes
        every valid cell editable.

        Args:
            index: Index of the item to get flags for.

        Returns:
            Flags for the chosen item.
        """
        flags = super().flags(index) & (~Qt.ItemIsUserCheckable)
        if index.isValid():
            flags = flags | Qt.ItemIsEditable
        return flags

    def headerData(self, index: int, orientation: Qt.Orientation, role=Qt.DisplayRole):  # noqa: N802 - snake case
        """
        Get the label for a header.

        Horizontal headers are the column titles. Vertical headers are the constituent names from
        `tidal_data.properties['name']`.

        Args:
            index: Index of the item to get the label for.
            orientation: Whether this is the horizontal or vertical header.
            role: What role the data is for.

        Raises:
            ValueError: If a horizontal header is requested for an unknown column.
        """
        if role != Qt.DisplayRole:
            return super().headerData(index, orientation, role)

        if orientation != Qt.Horizontal:
            return self.tidal_data.properties['name'].values[index]

        if index not in self._COLUMN_LABELS:
            raise ValueError(f'Unknown column: {index}')
        return self._COLUMN_LABELS[index]


class MappedTidesValuesModel(QAbstractTableModel):
    """
    A model that can be used with a QTableView to display and edit mapped tidal values data.

    The values are things that vary by node, like velocity and elevation. They do not include things that are global
    to the whole dataset, like the nodal factor and equilibrium argument.

    This model only supports one tidal constituent at a time. Passing it data for both the M2 and K1 constituents at
    the same time, for example, will give confusing results (probably repeated node IDs). The provided dataset should be
    filtered to a single constituent before passing it to this model.
    """
    def __init__(self, tidal_data: xr.Dataset, columns: list[tuple[str, str]], parent: QWidget):
        """
        Initialize the model.

        Args:
            tidal_data: The data to show in the table.
            columns: List of tuples of (name, label) identifying the columns the model's table should display. Names are
                the names of variables in tidal_data, and labels are what should be displayed in the table as the column
                header for that variable. The first column is treated as read-only.
            parent: Parent widget.
        """
        super().__init__(parent)
        self.tidal_data = tidal_data
        self.column_names = columns

    def rowCount(self, parent: QModelIndex = None) -> int:  # noqa: N802 - snake case
        """
        Get the number of rows in an item in the table.

        This is a flat table: only the table itself (an invalid/None parent) has rows, one per value of the first
        column's variable. Every other item is declared to have no rows.

        Args:
            parent: The parent item. An invalid index is interpreted as meaning the table itself.

        Returns:
            The number of rows inside the requested item.
        """
        if parent is None:
            parent = QModelIndex()
        if parent.isValid():
            # Not a hierarchical table, so no cell has rows of its own.
            return 0
        name, _label = self.column_names[0]
        return len(self.tidal_data[name])

    def columnCount(self, parent: QModelIndex = None) -> int:  # noqa: N802 - snake case
        """
        Get the number of columns in an item in the table.

        This is a flat table: only the table itself (an invalid/None parent) has columns, one per entry in
        `column_names`. Every other item is declared to have no columns.

        Args:
            parent: The parent item. An invalid index is interpreted as meaning the table itself.

        Returns:
            The number of columns inside the requested item.
        """
        if parent is None:
            parent = QModelIndex()
        if parent.isValid():
            # Not a hierarchical table, so no cell has columns of its own.
            return 0
        return len(self.column_names)

    def _variable_for(self, index: QModelIndex) -> Optional[str]:
        """
        Get the name of the variable displayed at an index.

        Args:
            index: Index of a cell in the table.

        Returns:
            The variable name, or None if the index is invalid or its column is outside `column_names`.
        """
        if not index.isValid() or not 0 <= index.column() < len(self.column_names):
            return None
        name, _label = self.column_names[index.column()]
        return name

    def data(self, index: QModelIndex, role: int = Qt.DisplayRole) -> Any:
        """
        Get the data for a cell in the table.

        Args:
            index: Index of the cell to get data for.
            role: The role to get data for. Only the display and edit roles are supported.

        Returns:
            The cell's value formatted by NumberCorrector.format_double, or None if the role or index isn't supported.
        """
        if role not in (Qt.DisplayRole, Qt.EditRole):
            return None

        name = self._variable_for(index)
        if name is None:
            return None

        value = self.tidal_data[name][index.row()].item()
        return NumberCorrector.format_double(value)

    def setData(self, index: QModelIndex, value: Any, role: int = Qt.DisplayRole):  # noqa: N802 - snake case
        """
        Set the data in a cell.

        Args:
            index: Index of the cell to set data for.
            value: Value to assign.
            role: Role to assign data for. Only the edit role is handled here; others go to the base class.

        Returns:
            Whether the data was set.
        """
        if role != Qt.EditRole:
            return super().setData(index, value, role)

        name = self._variable_for(index)
        if name is None:
            # setData reports success as a bool; previously this returned None.
            return False

        self.tidal_data[name][index.row()] = value
        self.dataChanged.emit(index, index, [role])
        return True

    def flags(self, index: QModelIndex) -> int:
        """
        Get the flags for an item in the table.

        Flags are a combination of flags from Qt, e.g. Qt.ItemIsUserCheckable. The checkbox flag is always removed.
        The first column is read-only and every other valid cell is editable.

        Args:
            index: Index of the item to get flags for.

        Returns:
            Flags for the chosen item.
        """
        flags = super().flags(index) & (~Qt.ItemIsUserCheckable)
        if not index.isValid():
            return flags
        if index.column() == 0:
            return flags & (~Qt.ItemIsEditable)  # No editing node IDs
        return flags | Qt.ItemIsEditable  # Everything else is editable though

    def headerData(self, index: int, orientation: Qt.Orientation, role=Qt.DisplayRole):  # noqa: N802 - snake case
        """
        Get the label for a header.

        Horizontal display-role headers come from the labels in `column_names`. Everything else is delegated to the
        base class.

        Args:
            index: Index of the item to get the label for.
            orientation: Whether this is the horizontal or vertical header.
            role: What role the data is for.

        Raises:
            IndexError: If `index` is not a valid column index.
        """
        if role != Qt.DisplayRole or orientation != Qt.Horizontal:
            return super().headerData(index, orientation, role)

        # Bounds-check explicitly. Indexing a list raises IndexError (not KeyError), and a negative index would
        # otherwise silently wrap around to a label from the end of the list.
        if not 0 <= index < len(self.column_names):
            raise IndexError(f'{index} out of bounds in MappedTidesValuesModel.headerData')
        _name, label = self.column_names[index]
        return label


class MappedTidesDialog(XmsDlg):
    """A dialog for viewing and editing mapped tidal data."""
    # Per-node velocity variables. They are hidden from a constituent's tab when it has no velocity data, and are split
    # back out of the combined per-constituent data when saving.
    _VELOCITY_VARIABLES = ('amp_x', 'amp_y', 'phase_x', 'phase_y')

    def __init__(self, tidal_data: MappedTidalData, parent: Optional[QWidget] = None):
        """Initializes the class, sets up the ui.

        Args:
            tidal_data: Mapped tidal data to display in the dialog.
            parent: The parent window
        """
        super().__init__(parent, self.__module__)
        self.help_url = 'https://www.xmswiki.com/wiki/SMS:Tidal_Constituents'
        self.tidal_data = tidal_data
        self.ui = Ui_MappedTidesDialog()
        self.ui.setupUi(self)

        # Hold a Python reference to the editor factory. The delegate is not documented to take ownership of it, so a
        # temporary passed straight in could be garbage collected while the delegate still uses it.
        self._editor_factory = QxDoubleWidgetFactory()
        self.ui.properties_table.itemDelegate().setItemEditorFactory(self._editor_factory)
        self.ui.properties_table.setModel(MappedTidesPropertiesModel(self.tidal_data, self))

        # One dataset per constituent tab, in tab order; these are concatenated and saved on accept.
        self.tab_datasets = []

        source_index = self.tidal_data.source
        with suppress(KeyError):
            source_label = td.TDB_SOURCES[source_index]
            self.ui.source_label.setText(source_label)

        self._populate_value_tabs()

    def _populate_value_tabs(self):
        """Add one tab per constituent showing its per-node elevation (and velocity, if present) values."""
        merged_tables = xr.merge([self.tidal_data.elevation, self.tidal_data.velocity])

        tab_widget = self.ui.value_tabs
        elevation_columns = [('node', 'Node ID'), ('amplitude', 'Amplitude'), ('phase', 'Phase')]
        velocity_columns = elevation_columns + [
            ('amp_x', 'Velocity\nAmplitude X'),
            ('amp_y', 'Velocity\nAmplitude Y'),
            ('phase_x', 'Velocity\nPhase X'),
            ('phase_y', 'Velocity\nPhase Y'),
        ]

        for name in self.tidal_data.elevation['name'].values:
            dset = merged_tables.sel(name=name)
            # Store the full dataset (velocity variables included) so saving can rebuild both tables.
            self.tab_datasets.append(dset)
            if dset['amp_x'].isnull().any():
                # Any missing velocity value hides all velocity columns for this constituent.
                # NOTE(review): edits made through the reduced dataset are expected to reach the stored one because
                # dropping variables shares the remaining variables' data rather than copying it — confirm.
                dset = dset.drop_vars(list(self._VELOCITY_VARIABLES))
                columns = elevation_columns
            else:
                columns = velocity_columns

            table = QTableView(tab_widget)
            table.setModel(MappedTidesValuesModel(dset, columns, tab_widget))
            tab_widget.addTab(table, name.upper())

    def help_requested(self):
        """Called when the Help button is clicked."""
        webbrowser.open(self.help_url)

    def accept(self):
        """
        Save data from the dialog, then close it on OK.

        The save happens first because QDialog.accept() emits the accepted/finished signals. Anything connected to
        them should see the saved data, and a failed save leaves the dialog open.
        """
        self._save_constituent_values()
        super().accept()

    def _save_constituent_values(self):
        """Write the edited per-constituent values back to the elevation and velocity tables, then commit."""
        # xr.concat raises on an empty list, so only rebuild the value tables when there is at least one constituent.
        if self.tab_datasets:
            concatenated = xr.concat(self.tab_datasets, dim='name')
            self.tidal_data.elevation = concatenated.drop_vars(list(self._VELOCITY_VARIABLES))
            self.tidal_data.velocity = concatenated.drop_vars(['amplitude', 'phase'])
        self.tidal_data.commit()
