"""Module for MappedCoverageDialog."""

__copyright__ = "(C) Copyright Aquaveo 2024"
__license__ = "All rights reserved"

# 1. Standard Python modules
from typing import Any, cast

# 2. Third party modules
from PySide2 import QtCore
from PySide2.QtCore import QAbstractTableModel, QModelIndex, Qt
from PySide2.QtWidgets import QWidget

# 3. Aquaveo modules
from xms.guipy.delegates.node_id_button_delegate import NodeIdButtonDelegate
from xms.guipy.dialogs.xms_parent_dlg import XmsDlg

# 4. Local modules
from xms.schism.data.mapped_bc_data import MappedBcData
from xms.schism.data.model import get_model
from xms.schism.gui.mapped_coverage_dialog_ui import Ui_MappedCoverageDialog

# Column indices of the mapped-coverage table, in display order (0, 1, 2, 3).
NODE_IDS_COLUMN, ELEVATION_COLUMN, FLOW_COLUMN, NAME_COLUMN = range(4)


class MappedCoverageModel(QAbstractTableModel):
    """
    A model that can be used with a QTableView to display a mapped coverage's data.

    Each table row corresponds to one entry of ``data.open_arcs``. The node-ID column has no display text of its own;
    the other columns show parameters from the ``'open'`` group of the model's arc parameters.
    """

    # Header label for every column. Its size is also the table's column count, so adding a column only needs an
    # entry here (plus one in _PARAMETER_NAMES if it shows a parameter).
    _HEADERS = {
        NODE_IDS_COLUMN: 'Node IDs',
        ELEVATION_COLUMN: 'Elevation type',
        FLOW_COLUMN: 'Flow type',
        NAME_COLUMN: 'Name',
    }

    # Name of the 'open'-group parameter displayed in each text column. Columns missing from here display nothing.
    _PARAMETER_NAMES = {
        ELEVATION_COLUMN: 'wse-type',
        FLOW_COLUMN: 'flow-type',
        NAME_COLUMN: 'name',
    }

    def __init__(self, data: MappedBcData):
        """
        Initialize the model.

        Args:
            data: The data to show in the table.
        """
        super().__init__()
        self.open_arcs = data.open_arcs
        # Shared parameter section; data() restores each row's stored values into it before reading a parameter.
        self.generic_model_section = get_model().arc_parameters

    def rowCount(self, parent: QModelIndex = None) -> int:  # noqa: N802 - snake case
        """
        Get the number of rows in an item in the table.

        This is a flat (non-hierarchical) table, so only the table itself (an invalid parent index) has rows; every
        other item is reported as having none.

        Args:
            parent: The parent item. An invalid index or None is interpreted as meaning the table itself.

        Returns:
            The number of rows inside the requested item.
        """
        if parent is None:
            parent = QModelIndex()
        if parent.isValid():
            # It wants the number of rows inside another row. This is not a hierarchical table, so there are none.
            return 0
        return len(self.open_arcs)

    def columnCount(self, parent: QModelIndex = None) -> int:  # noqa: N802 - snake case
        """
        Get the number of columns in an item in the table.

        This is a flat (non-hierarchical) table, so only the table itself (an invalid parent index) has columns;
        every other item is reported as having none.

        Args:
            parent: The parent item. An invalid index or None is interpreted as meaning the table itself.

        Returns:
            The number of columns inside the requested item.
        """
        if parent is None:
            parent = QModelIndex()
        if parent.isValid():
            # It wants the number of columns inside another column. This is not a hierarchical table, so there are none.
            return 0
        # Derived from the header table so the count can't drift from the column constants.
        return len(self._HEADERS)

    def data(self, index: QModelIndex, role: int = Qt.DisplayRole) -> Any:
        """
        Get the data for a cell in the table.

        Args:
            index: Index of the cell to get data for.
            role: The role to get data for. Only Qt.DisplayRole produces a value.

        Returns:
            The parameter value for the cell, '(none)' for an empty name, or None for invalid indexes, non-display
            roles, and columns that show no parameter (e.g. the node-ID column).
        """
        if role != Qt.DisplayRole or not index.isValid():
            return None

        name = self._PARAMETER_NAMES.get(index.column())
        if name is None:
            return None

        # Element 1 of an open-arc entry holds the parameter values restored into the shared section.
        values = self.open_arcs[index.row()][1]
        self.generic_model_section.restore_values(values)
        value = self.generic_model_section.group('open').parameter(name).value

        if index.column() == NAME_COLUMN and not value:
            value = '(none)'

        return value

    def flags(self, index: QModelIndex) -> int:
        """
        Get the flags for an item in the table.

        Flags are a combination of flags from Qt, e.g. Qt.ItemIsUserCheckable.

        This removes Qt.ItemIsUserCheckable from whatever the base class reports, so no cell offers a checkbox.

        Args:
            index: Index of the item to get flags for.

        Returns:
            Flags for the chosen item.
        """
        return super().flags(index) & (~Qt.ItemIsUserCheckable)

    def headerData(self, item: int, orientation: Qt.Orientation, role=QtCore.Qt.DisplayRole):  # noqa: N802 - snake case
        """
        Get the label for a header.

        Args:
            item: Index of the item to get the label for.
            orientation: Whether this is the horizontal or vertical header.
            role: What role the data is for.

        Returns:
            The column label for horizontal display-role requests; otherwise whatever the base class returns.

        Raises:
            ValueError: If a horizontal display-role label is requested for a column this model doesn't define.
        """
        if orientation == QtCore.Qt.Horizontal and role == QtCore.Qt.DisplayRole:
            if item not in self._HEADERS:
                raise ValueError(f'Unknown column: {item}')
            return self._HEADERS[item]
        return super().headerData(item, orientation, role)


class MappedCoverageDialog(XmsDlg):
    """Dialog that lists the open arcs of a mapped coverage in a read-only table."""
    def __init__(self, data: MappedBcData, parent: QWidget):
        """
        Build the dialog's widgets and attach a table model for the given data.

        Args:
            data: Mapped coverage data whose open arcs are listed in the table.
            parent: Window that owns this dialog.
        """
        super().__init__(parent, 'xms.schism.gui.mapped_coverage_dialog')
        self.ui = Ui_MappedCoverageDialog()
        self.ui.setupUi(self)
        table_view = self.ui.m_table_view

        # The node-ID column gets a dedicated delegate, which is handed get_node_ids as its lookup callback
        # (presumably called with a row index when the cell is activated).
        node_id_delegate = NodeIdButtonDelegate(self.get_node_ids, self)
        table_view.setItemDelegateForColumn(NODE_IDS_COLUMN, node_id_delegate)

        self.open_arcs = data.open_arcs
        table_view.setModel(MappedCoverageModel(data))

    def get_node_ids(self, row: int) -> list[int]:
        """
        Look up the node IDs belonging to one table row.

        Args:
            row: Zero-based index of the row (and of its entry in the open arcs).

        Returns:
            The node IDs stored as element 0 of that row's open-arc entry.
        """
        arc_entry = self.open_arcs[row]
        # The stored IDs may be any Sequence; cast() only informs the type checker and does not convert. Callers
        # should not rely on isinstance(..., list).
        return cast(list[int], arc_entry[0])
