"""SimComponent class."""

__copyright__ = "(C) Copyright Aquaveo 2021"
__license__ = "All rights reserved"

# 1. Standard Python modules
from functools import cached_property

# 2. Third party modules
from PySide2.QtWidgets import QWidget

# 3. Aquaveo modules
from xms.api.dmi import Query
from xms.api.tree import tree_util, TreeNode
from xms.components.bases.component_with_menus_base import MessagesAndRequests
from xms.gmi.component_bases.sim_component_base import SimComponentBase
from xms.gmi.data.generic_model import Parameter, Section
from xms.gmi.gui.dataset_callback import DatasetRequest
from xms.gmi.gui.section_dialog import SectionDialog

# 4. Local modules
from xms.wavewatch3.data.model import get_model
from xms.wavewatch3.data.sim_data import SimData
from xms.wavewatch3.dmi.xms_data import LINKABLE_GRIDS, XmsData
from xms.wavewatch3.util.util import calc_timestep_values, populate_spectral_values

# Name of the simulation component's main data file (netCDF, judging by the extension).
# NOTE(review): not referenced in the visible code of this module — presumably used by other modules; confirm.
SIM_DATA_MAINFILE: str = 'sim_comp.nc'


class SimComponent(SimComponentBase):
    """A hidden Dynamic Model Interface (DMI) component for the WaveWatch3 model simulation."""

    # Unique-names of grid items. `item_linked` ignores links to any item whose unique-name is not listed here.
    _GRID_UNIQUE_NAMES = frozenset({'uGrid', 'uGridMesh', 'cGrid'})

    # Dataset parameters whose dataset picker should show the scalar dataset tree.
    # NOTE(review): 'ice_concentration' lacks the '_dataset' suffix every other name uses — confirm it matches the
    # parameter name defined in the model.
    _SCALAR_DATASET_PARAMETERS = frozenset({
        'water_level_dataset', 'air_density_dataset', 'atm_momentum_dataset', 'ice_concentration',
        'ice_param_1_dataset', 'ice_param_2_dataset', 'ice_param_3_dataset', 'ice_param_4_dataset',
        'ice_param_5_dataset', 'mud_density_dataset', 'mud_thickness_dataset', 'mud_viscosity_dataset',
    })

    # Dataset parameters whose dataset picker should show the vector dataset tree.
    _VECTOR_DATASET_PARAMETERS = frozenset({'currents_dataset', 'winds_dataset'})

    def __init__(self, main_file):
        """Initializes the component class.

        Args:
            main_file: The main file associated with this component.
        """
        super().__init__(main_file)
        # Keyword options passed to SectionDialog (see _open_output_field_parameters; presumably the base class's
        # dialogs use them too).
        self._section_dialog_keyword_args['enable_unchecked_groups'] = True
        self._section_dialog_keyword_args['hide_checkboxes'] = True
        # Each command is inserted at the front of the list, so the final order is the reverse of the insertion
        # order below: Populate Spectral..., Calculate Default Timestep..., Output Field Parameters..., followed by
        # whatever commands were already in the list.
        self.tree_commands.insert(0, ('Output Field Parameters...', self._open_output_field_parameters))
        self.tree_commands.insert(0, ('Calculate Default Timestep...', self._calc_timestep_defaults))
        self.tree_commands.insert(0, ('Populate Spectral Parameters...', self._populate_spectral_defaults))
        # _internal_open_model_control is inherited (not defined in this class).
        self.default_tree_command = self._internal_open_model_control

    def _get_global_parameter_section(self) -> Section:
        """Get the section that defines the model's global parameters.

        Returns:
            The `global_parameters` section of the WaveWatch3 generic model.
        """
        return get_model().global_parameters

    @cached_property
    def data(self) -> SimData:
        """The component's data manager.

        Created from the component's main file on first access and cached on the instance afterwards.
        """
        return SimData(self.main_file)

    def _open_output_field_parameters(self, query: Query, _params: list[dict], parent: QWidget) -> MessagesAndRequests:
        """Run the output fields dialog.

        If the user accepts the dialog, the edited values are stored on the component's data and committed.

        Args:
            query: Interprocess communication object. Unused here; presumably required by the tree-command signature.
            _params: Unused; presumably required by the tree-command signature.
            parent: Parent widget for the dialog.

        Returns:
            Empty lists of messages and requests.
        """
        section = self.data.output_fields_data

        dlg = SectionDialog(
            parent=parent,
            section=section,
            dlg_name=f'{self.module_name}.output_field',
            window_title='Output Field',
            **self._section_dialog_keyword_args,
        )
        if dlg.exec():
            values = dlg.section.extract_values()
            self.data.output_field_values = values
            self.data.commit()
        return [], []

    def item_linked(self, query: Query, linked_uuid: str, unique_name: str, lock_state: bool, parent: QWidget):
        """Handle when a new item was linked to the simulation.

        Only grids are handled: the simulation keeps at most one linked grid, so every other linkable grid under the
        simulation is unlinked.

        Args:
            query: Interprocess communication object.
            linked_uuid: UUID of the item that was just linked.
            unique_name: The unique-name of the item being linked, assuming the item's XML was designed the way this
                component expects.

                The sim component's `<component>` tag should have one `<takes>` tag inside for each type of thing that
                can be linked to the simulation. Each `<takes>` tag should have a `<declare_parameter>` tag inside it.
                The `unique_name` passed to this method will be the value of the `<declare_parameter>` tag for the item
                being linked. The value of this tag should probably be the item's unique-name (typically its class name)
                since it's basically the same thing.
            lock_state: Whether the item is currently locked for editing. Currently only makes sense in GMS.
            parent: Parent widget.

        Raises:
            AssertionError: If the just-linked grid cannot be found among the simulation's linkable grids. This most
                likely means LINKABLE_GRIDS is missing the type string for that grid type.
        """
        if unique_name not in self._GRID_UNIQUE_NAMES:
            return
        xms_data = XmsData(query)
        sim_item = xms_data.sim_item

        grid_nodes = tree_util.descendants_of_type(sim_item, xms_types=LINKABLE_GRIDS, allow_pointers=True)
        # Use a default so a missing node reaches the diagnostic below instead of leaking a bare StopIteration.
        # The search above is filtered by LINKABLE_GRIDS, so an unknown grid type shows up here as "not found".
        linked_node = next((node for node in grid_nodes if node.uuid == linked_uuid), None)
        if linked_node is None or linked_node.item_typename not in LINKABLE_GRIDS:
            # If this crashes it's because we didn't know what string to put into LINKABLE_GRIDS for this grid type
            # (e.g. cGrids).
            raise AssertionError(  # pragma nocover
                f'Linked item {linked_uuid} ({unique_name}) was not found among the linkable grids of the '
                f'simulation. LINKABLE_GRIDS may be missing its type string.'
            )

        uuids_to_unlink = [node.uuid for node in grid_nodes if node.uuid != linked_uuid]

        for uuid_to_unlink in uuids_to_unlink:
            xms_data.unlink_item(uuid_to_unlink)

    @staticmethod
    def _calc_timestep_defaults(
        query: Query, _params: list[dict] | None, _parent: QWidget | None
    ) -> MessagesAndRequests:
        """Calculate default timestep data for the model control.

        Args:
            query: Interprocess communication object.
            _params: Unused.
            _parent: Unused.

        Returns:
            An ERROR message if no grid is linked to the simulation, otherwise an INFO message; no requests.
        """
        xms_data = XmsData(query)
        if not xms_data.do_ugrid:
            return [('ERROR', 'No grid found linked to the simulation.')], []

        calc_timestep_values(xms_data)

        return [('INFO', 'Default timestep values saved in model control.')], []

    def _populate_spectral_defaults(self, query: Query, _params: list[dict], _parent: QWidget) -> MessagesAndRequests:
        """Calculate spectral default parameters for the model control.

        Args:
            query: Interprocess communication object.
            _params: Unused.
            _parent: Unused.

        Returns:
            An ERROR message if no spectral coverage is linked to the simulation, otherwise an INFO message;
            no requests.
        """
        xms_data = XmsData(query)
        if not xms_data.spectral_coverages:
            return [('ERROR', 'No spectral coverage found linked to the simulation.')], []
        # We are just using the first spectral coverage for now...
        spectral_coverage = xms_data.spectral_coverages[0]
        populate_spectral_values(xms_data, spectral_coverage)
        return [('INFO', 'Default spectral parameters saved in model control.')], []

    def _dataset_callback(self, request: DatasetRequest, parameter: Parameter) -> str | TreeNode:  # pragma: no cover
        """Handle a request for information when picking a dataset. This is an override from xms.gmi.

        For tree requests on this model's known scalar or vector dataset parameters, return the matching dataset
        tree. Defer everything else to the base class.

        Args:
            request: The kind of information being requested.
            parameter: The parameter whose dataset is being picked.

        Returns:
            The scalar or vector dataset tree, or whatever the base class returns for other requests/parameters.
        """
        if request == DatasetRequest.GetTree:
            name = parameter.parameter_name
            if name in self._SCALAR_DATASET_PARAMETERS:
                return XmsData(query=self._query).scalar_dataset_tree
            if name in self._VECTOR_DATASET_PARAMETERS:
                return XmsData(query=self._query).vector_dataset_tree
        return super()._dataset_callback(request, parameter)
