"""mapping_tables_dialog.py."""

__copyright__ = '(C) Copyright Aquaveo 2024'
__license__ = 'All rights reserved'

# 1. Standard Python modules

# 2. Third party modules
import numpy as np
from PySide2.QtCore import QModelIndex, QObject
from PySide2.QtWidgets import QHBoxLayout, QLabel, QPushButton, QVBoxLayout, QWidget

# 3. Aquaveo modules
from xms.api.dmi import Query
from xms.api.tree import tree_util, TreeNode
from xms.constraint import read_grid_from_file
from xms.core.filesystem import filesystem
from xms.datasets.dataset_reader import DatasetReader
from xms.datasets.dataset_writer import DatasetWriter
from xms.gmi.data.generic_model import GroupSet, Section
from xms.gmi.gui.custom_dialog_base import CustomDialogBase
from xms.gmi.gui.parameter_widgets import widget_name_from_strings
from xms.guipy import file_io_util
from xms.guipy.dialogs import message_box
from xms.guipy.widgets.table_with_tool_bar import TableWithToolBar
from xms.tool_gui.tool_dialog import run_tool_dialog

# 4. Local modules
from xms.gssha.data import data_util, sim_generic_model
from xms.gssha.data import mapping_tables
from xms.gssha.data.sim_generic_model import InfilType, RichardsCOption
from xms.gssha.file_io import io_util
from xms.gssha.gui.mapping_tables_dialog_ui import Ui_dlg_mapping_tables
from xms.gssha.tools.mapping_table_tool import MappingTableTool

# Constants
# If a dataset has more unique (rounded) integer values than this, 'Generate IDs' asks the user to confirm first.
MAX_UNIQUE_INTS = 100


class MappingTablesDialog(Ui_dlg_mapping_tables, CustomDialogBase):
    """Dialog for editing the GSSHA mapping tables, one page per table.

    Uses xmsgmi and a generic model but with a custom .ui file via CustomDialogBase.

    The tool MappingTableTool is run from the dialog, which causes some complications. The new dataset created
    by the tool is not sent back to XMS because the python process doesn't end with the tool. So, new datasets must
    be kept track of and manually added to a copy of the project tree so that they show up in the tree when the user
    selects the button to choose a dataset. The new datasets are added to xms when the dialog exits.
    """
    def __init__(
        self,
        parent: 'QWidget | None',
        section: GroupSet,
        query: Query,
        sim_component,
        ugrid_node: TreeNode,
        dataset_callback,
    ) -> None:
        """Initializes the class.

        Args:
            parent: Parent widget.
            section: Section in the model to display and edit values for.
            query (Query): The query.
            sim_component (SimComponent): The sim component.
            ugrid_node (TreeNode): The linked UGrid.
            dataset_callback: A callback for picking a dataset or getting its label. See `DatasetCallback` for details.
                This can be `None` if there are no dataset parameters.
        """
        super().__init__(
            parent, section, ['mapping_tables'], None, None, 'xms.gssha.gui.mapping_tables_dialog', dataset_callback
        )
        self.setupUi(self)
        self._query = query
        self._ugrid_node = ugrid_node
        self._sim_component = sim_component

        self._project_tree = sim_component.project_tree
        self._co_grid = None  # Read lazily by _ensure_grid_has_been_read()
        self._ugrid = None  # Read lazily by _ensure_grid_has_been_read()
        self._pages: dict[str, QWidget] = {}  # Dict of page text -> widget
        # Dataset label -> the 'Generate IDs' button on the same page (see eventFilter)
        self._dataset_labels: dict[QLabel, QPushButton] = {}
        # dataset uuid -> dataset. Datasets created by MappingTableTool while the dialog is open
        self._new_datasets: dict[str, DatasetWriter] = {}
        # Output file of the last tool run; reused as the input file of the next run
        self._output_file: str = ''

        qt_widgets_map = self._prepare_pages()
        self.setup_widgets(qt_widgets_map)  # Calls base class
        self._init_warnings()
        self._setup_signals()
        self.lst_tables.setCurrentRow(0)

    def _prepare_pages(self) -> dict[str, QWidget]:
        """Builds the widgets on every page prior to calling setup_widgets().

        Also fills self._pages with page text -> page widget.

        Returns:
            Dict of widget name -> widget for the parameter widgets on all pages. __init__ passes this to
            CustomDialogBase.setup_widgets().
        """
        qt_widgets_map: dict[str, QWidget] = {}  # widget name -> widget
        for i in range(self.lst_tables.count()):
            page_name = self.lst_tables.item(i).text()
            page = self.stackedWidget.widget(i)
            self._pages[page_name] = page
            # Parameter names are built from the lowercase page name with spaces replaced by underscores
            page_name_lower = page_name.lower().replace(' ', '_')
            self._add_mapping_table_widgets(page, page_name_lower, qt_widgets_map)
        return qt_widgets_map

    def _add_mapping_table_widgets(self, page: QWidget, page_name_lower: str, qt_widgets_map: dict) -> None:
        """Adds the widgets to the page.

        Widgets are stored in page.ui by name. The widgets tied to generic model parameters (dataset button,
        dataset label, table) are also added to qt_widgets_map.

        Args:
            page: The page widget.
            page_name_lower: Lowercase page name (from the list).
            qt_widgets_map: Dict of the widget names -> widgets.
        """
        page.ui = {}
        page.setLayout(QVBoxLayout())

        # 'Index map:' label
        hlay = QHBoxLayout()
        page.layout().addLayout(hlay)
        w = page.ui['txt_index_map'] = QLabel('Index map:')
        hlay.addWidget(w)

        # Dataset button
        w = page.ui['btn_index_map_dataset'] = QPushButton('Select...')
        hlay.addWidget(w)
        parameter_name = f'{page_name_lower}_index_map_dataset'
        widget_name = widget_name_from_strings(QPushButton, 'mapping_tables', parameter_name)
        qt_widgets_map[widget_name] = w

        # Dataset label (same parameter as the dataset button)
        txt_dataset = page.ui['txt_index_map_dataset'] = QLabel('dataset')
        hlay.addWidget(txt_dataset)
        widget_name = widget_name_from_strings(QLabel, 'mapping_tables', parameter_name)
        qt_widgets_map[widget_name] = txt_dataset
        hlay.addStretch()

        # Table
        table_with_tool_bar = page.ui['tbl'] = TableWithToolBar()
        page.layout().addWidget(table_with_tool_bar)
        parameter_name = f'{page_name_lower}_table'
        widget_name = widget_name_from_strings(TableWithToolBar, 'mapping_tables', parameter_name)
        qt_widgets_map[widget_name] = table_with_tool_bar
        table_with_tool_bar.rows_inserted.connect(self._on_rows_inserted)

        # Generate IDs button
        btn_generate_ids = page.ui['btn_generate_ids'] = QPushButton('Generate IDs')
        table_with_tool_bar.ui.hlay_tool_bar.addWidget(btn_generate_ids)
        btn_generate_ids.clicked.connect(self._on_btn_generate_ids)

        # From Shapefile / Raster button
        btn_from_shapefile_raster = page.ui['btn_from_shapefile_raster'] = QPushButton('From Shapefile / Raster...')
        table_with_tool_bar.ui.hlay_tool_bar.addWidget(btn_from_shapefile_raster)
        btn_from_shapefile_raster.clicked.connect(self._on_btn_from_shapefile_raster)

        table_with_tool_bar.ui.hlay_tool_bar.addStretch()

        # Make it so we can enable/disable the Generate IDs button when the dataset changes
        self._dataset_labels[txt_dataset] = btn_generate_ids
        txt_dataset.installEventFilter(self)

        # Warning label (text and visibility are set by _init_warnings())
        w = page.ui['txt_warning'] = QLabel('The depth varying roughness option is turned off in Model Control.')
        w.setStyleSheet(u"QLabel { background-color : rgb(255,174,117); color : rgb(0, 0, 0); }")
        w.setWordWrap(True)
        page.layout().addWidget(w)

    def _setup_signals(self):
        """Sets up the signals."""
        self.lst_tables.currentRowChanged.connect(self._on_lst_row_changed)
        self.btn_model_control.clicked.connect(self._on_btn_model_control)

    def _set_page_warning(self, page_name: str, msg: str) -> None:
        """Sets the warning label text on a page and shows the label only if there is a message.

        Args:
            page_name: Text of the page in the list (key into self._pages).
            msg: The warning message, or '' for no warning.
        """
        txt_warning = self._pages[page_name].ui['txt_warning']
        txt_warning.setText(msg)
        txt_warning.setVisible(bool(msg))

    def _init_warnings(self) -> None:
        """Shows or hides the warning label on each page based on the current Model Control options.

        Called from __init__ and again after the Model Control dialog is closed.
        """
        section = self._get_global_parameters()
        overland_flow = section.group('overland_flow')
        infiltration = section.group('infiltration')

        self._pages['ROUGHNESS'].ui['txt_warning'].setVisible(False)  # This one never has a warning

        # Pages that depend on an on/off overland flow option. The page name matches the option name.
        for option_name in ('ROUGH_EXP', 'INTERCEPTION', 'RETENTION', 'AREA_REDUCTION'):
            turned_off = overland_flow.parameter(option_name).value is False
            msg = f'The {option_name} option is turned off in Model Control.' if turned_off else ''
            self._set_page_warning(option_name, msg)

        # Green-Ampt pages require one of the Green-Ampt infiltration types
        infil_type = infiltration.parameter('infil_type').value
        green_ampt_types = [InfilType.INF_REDIST, InfilType.INF_LAYERED_SOIL]
        msg = ''
        if infil_type not in green_ampt_types:
            msg = (
                f'The infiltration option is not "{green_ampt_types[0]}" or "{green_ampt_types[1]}" '
                'in Model Control.'
            )
        for page_name in ('GREEN_AMPT_INFILTRATION', 'GREEN_AMPT_INITIAL_SOIL_MOISTURE'):
            self._set_page_warning(page_name, msg)

        # Richards' equation pages require Richards infiltration AND the matching RICHARDS_C_OPTION
        richards_c_option = infiltration.parameter('RICHARDS_C_OPTION').value
        richards_pages = {
            'RICHARDS_EQN_INFILTRATION_BROOKS': RichardsCOption.BROOKS,
            'RICHARDS_EQN_INFILTRATION_HAVERCAMP': RichardsCOption.HAVERCAMP,
        }
        for page_name, required_c_option in richards_pages.items():
            problems = []
            if infil_type != InfilType.INF_RICHARDS:
                problems.append(f'The infiltration option is not "{InfilType.INF_RICHARDS}" in Model Control.')
            if richards_c_option != required_c_option:
                problems.append(f'The RICHARDS_C_OPTION option is not "{required_c_option}" in Model Control.')
            # Joining avoids a leading space when only the second problem applies
            self._set_page_warning(page_name, ' '.join(problems))

    def _get_global_parameters(self) -> Section:
        """Returns the global_parameters Section of the simulation generic model, filled with the sim's values."""
        generic_model = sim_generic_model.create(default_values=False)
        generic_model.global_parameters.restore_values(self._sim_component.data.global_values)
        return generic_model.global_parameters

    def _on_lst_row_changed(self, row: int) -> None:
        """Shows the page matching the current list row."""
        self.stackedWidget.setCurrentIndex(row)

    def _ensure_grid_has_been_read(self):
        """Reads the grid (self._co_grid and self._ugrid) if it hasn't been already."""
        if self._co_grid is None:
            do_ugrid = self._query.item_with_uuid(self._ugrid_node.uuid)
            self._co_grid = read_grid_from_file(do_ugrid.cogrid_file)
            self._ugrid = self._co_grid.ugrid

    def _on_rows_inserted(self, parent: QModelIndex, first: int, last: int) -> None:
        """Defaults the 'Index value' of a newly inserted row to the lowest unused int >= 1.

        Args:
            parent: Parent index of the inserted rows (unused).
            first: Model row of the first inserted row. Only this row is given a default.
            last: Model row of the last inserted row (unused).
        """
        table_with_tool_bar = self._get_current_table()
        model = table_with_tool_bar.ui.table.model()
        df = table_with_tool_bar.get_values()
        new_index_value = 1
        for index, value in df['Index value'].sort_values().items():
            # Don't consider the row we just added, as the new row's default 'Index value' will always be 1.
            # NOTE(review): assumes get_values() uses a 1-based index (model row `first` is index `first + 1`).
            if index != first + 1 and value == new_index_value:
                new_index_value += 1
        index_value_column = 0
        model.setData(model.index(first, index_value_column), new_index_value)

    def _get_current_table(self) -> TableWithToolBar:
        """Returns the TableWithToolBar on the currently displayed page."""
        return self.stackedWidget.currentWidget().ui['tbl']

    def _on_btn_generate_ids(self) -> None:
        """Called when the 'Generate IDs' button is clicked."""
        # Warn and allow to cancel if there are existing IDs
        rv = 0  # 'Yes' (delete existing IDs); used as-is when the table is empty
        if self._ids_exist():
            msg = 'Delete the existing IDs? (Selecting no will simply add any IDs that are not already added.)'
            rv = message_box.message_with_n_buttons(self, msg, 'SMS', ['Yes', 'No', 'Cancel'], default=2, escape=2)
            if rv == 2:  # Cancel
                return

        delete_existing = rv == 0
        dataset_parameter = mapping_tables.dataset_param(self.lst_tables.currentItem().text())
        dataset_uuid = self.section.group('mapping_tables').parameter(dataset_parameter).value
        self._generate_ids(dataset_uuid, delete_existing, self._get_current_table())

    def _on_btn_from_shapefile_raster(self) -> None:
        """Called when the 'From Shapefile / Raster...' button is clicked.

        Runs MappingTableTool for the current page. The new dataset created by the tool is not sent back to XMS
        because the python process doesn't end with the tool. So, new datasets are kept in self._new_datasets and
        manually added to a copy of the project tree so that they show up in the tree when the user selects the
        button to choose a dataset.
        """
        # Warn and allow to cancel if there are existing IDs
        if self._ids_exist():
            if not message_box.message_with_ok_cancel(self, 'This will delete all existing IDs. Continue?', 'SMS'):
                return

        self._ensure_grid_has_been_read()
        mapping_table_name = self.lst_tables.currentItem().text()

        # Set up the tool
        tool = MappingTableTool()
        tool.mapping_table_name = mapping_table_name
        tool.co_grid = self._co_grid
        tool.ugrid = self._ugrid
        tool.ugrid_node = self._ugrid_node

        # Set input and output files. Use last output file as input file to run tool from history, if we can.
        if self._output_file:
            input_file = self._output_file
        else:
            input_file = self._make_tool_input_file()
            self._output_file = filesystem.temp_filename(suffix='.json')

        if not run_tool_dialog(self._query, input_file, self._output_file, self, tool):
            return

        # Remember the new dataset so _generate_ids() can read it before XMS knows about it
        dataset = tool.run_results.index_map_dataset
        self._new_datasets[dataset.uuid] = dataset
        tree_ugrid_node = tree_util.find_tree_node_by_uuid(self._project_tree, self._ugrid_node.uuid)

        # Set the generic model dataset parameter
        mapping_tables_group = self.section.group('mapping_tables')
        dataset_parameter_name = mapping_tables.dataset_param(mapping_table_name)
        dataset_parameter = mapping_tables_group.parameter(dataset_parameter_name)
        dataset_parameter.value = dataset.uuid

        # Add the dataset to our copy of the tree since it won't get added until the python process exits
        dataset_tree_node = TreeNode()
        dataset_tree_node.from_dict(self._get_dataset_tree_node_dict(dataset))
        tree_ugrid_node.add_child(dataset_tree_node)
        # Apparently we have to connect the data object parent/child too
        dataset_tree_node.data.SetParent(tree_ugrid_node.data)
        tree_ugrid_node.data.AddChild(dataset_tree_node.data)

        # Update the dataset label in the dialog
        widget_name = widget_name_from_strings(QLabel, 'mapping_tables', dataset_parameter_name)
        dataset_path = tree_util.build_tree_path(self._project_tree, dataset_parameter.value)
        self._qt_widgets_map[widget_name].setText(dataset_path)

        # Update the table
        table_with_tool_bar = self._get_current_table()
        table_with_tool_bar.setup(tool.run_results.table_definition, tool.run_results.df)

        # Set the generic model table parameter
        table_parameter = mapping_tables_group.parameter(mapping_tables.table_param(mapping_table_name))
        table_parameter.value = table_with_tool_bar.get_values().values.tolist()

    def _on_btn_model_control(self) -> None:
        """Launches Model Control dialog, then refreshes the page warnings."""
        self._sim_component.open_model_control(self._query, [], self)
        self._init_warnings()

    def _ids_exist(self) -> bool:
        """Returns true if the current table already contains IDs (i.e. has any rows)."""
        return self._get_current_table().ui.table.model().rowCount() > 0

    def _get_dataset_tree_node_dict(self, dataset: DatasetWriter) -> dict:
        """Returns a dict suitable for constructing a dataset TreeNode (see TreeNode.from_dict()).

        Args:
            dataset: The dataset (DatasetWriter).

        Returns:
            The dict.
        """
        return {
            'ctype': 'dataset',
            'name': dataset.name,
            'uuid': dataset.uuid,
            'item_typename': 'TI_SCALAR_DSET',
            'is_ptr_item': False,
            'is_active_item': False,
            'check_state': -1,
            'model_name': '',
            'coverage_type': '',
            'unique_name': '',
            'main_file': '',
            'num_cells': 0,
            'num_points': 0,
            'num_times': 1,
            'num_components': 1,
            'num_vals': dataset.num_values,
            'data_location': 'CELL' if dataset.location == 'cell' else 'NODE',
            'children': []
        }

    def _make_tool_input_file(self) -> str:
        """Creates the MappingTableTool input json file and returns its path."""
        desc = 'Creates an index map and mapping table from the UGrid, soil data, and land use data.'
        json_dict = {
            'class_name': 'MappingTableTool',
            'is_modal': True,
            'module_name': 'xms.gssha.tools.mapping_table_tool',
            'tool_description': desc,
            'tool_name': 'Mapping Table',
            'tool_uuid': 'aca7e927-f8e1-4ee9-920b-eb0742214736',
        }
        file_path = filesystem.temp_filename(suffix='.json')
        file_io_util.write_json_file(json_dict, file_path)
        return str(file_path)

    def _generate_ids(self, dataset_uuid: str, delete_existing: bool, table_with_tool_bar: TableWithToolBar) -> None:
        """Populates the table with a row for each unique integer in the dataset.

        Dataset values are rounded to the nearest integer and masked values are ignored. If there are more than
        MAX_UNIQUE_INTS unique values the user is asked to confirm.

        Args:
            dataset_uuid: Uuid of the dataset.
            delete_existing: If True, any existing data is deleted.
            table_with_tool_bar: The table.
        """
        # Get the dataset from the new ones created while in the dialog, or from the original ones via the query
        if dataset_uuid in self._new_datasets:
            dataset_writer = self._new_datasets[dataset_uuid]
            dataset = DatasetReader(dataset_writer.h5_filename, dataset_writer.name)
        else:
            dataset = self._query.item_with_uuid(dataset_uuid)

        # Get unique integers as Python ints (consistent with _on_rows_inserted and safe to hand to Qt)
        self._ensure_grid_has_been_read()
        on_off_cells = data_util.get_on_off_cells(self._co_grid, self._ugrid)
        masked_array = io_util.create_masked_array(dataset, on_off_cells)
        unique_ints = np.unique(masked_array.round(decimals=0)).compressed().astype(int).tolist()

        # If there's a lot, warn the user
        if len(unique_ints) > MAX_UNIQUE_INTS:
            msg = f'There are {len(unique_ints)} unique integer values in the dataset. Continue?'
            rv = message_box.message_with_n_buttons(self, msg, 'SMS', ['Yes', 'No'], default=1, escape=1)
            if rv == 1:  # No
                return

        # Delete existing data if necessary (skip a zero-row removal, which is an invalid range for Qt models)
        model = table_with_tool_bar.ui.table.model()
        if delete_existing and model.rowCount() > 0:
            model.removeRows(0, model.rowCount())

        if not unique_ints:
            return  # Nothing to add; inserting zero rows is an invalid range for Qt models

        # Add a row for each unique integer
        old_row_count = model.rowCount()
        model.insertRows(old_row_count, len(unique_ints))
        for offset, value in enumerate(unique_ints):
            model.setData(model.index(old_row_count + offset, 0), value)

    def eventFilter(self, obj, event):  # noqa N802 (function name 'eventFilter' should be lowercase)
        """Enables a page's 'Generate IDs' button only when its dataset label shows a selected dataset.

        Args:
            obj: Qt object associated with the event.
            event: The QEvent object.

        Returns:
            (bool): True if the event was handled.
        """
        if isinstance(obj, QLabel):
            btn_generate_ids = self._dataset_labels.get(obj)
            if btn_generate_ids is not None:
                btn_generate_ids.setEnabled(obj.text() != '(none selected)')
        # Never consume the event; let the base class chain process it normally
        return super().eventFilter(obj, event)
