"""Class for managing interprocess communication with XMS."""

__copyright__ = "(C) Copyright Aquaveo 2021"
__license__ = "All rights reserved"

# 1. Standard Python modules
from functools import cached_property
import logging
from typing import Optional

# 2. Third party modules

# 3. Aquaveo modules
from xms.api.dmi import Query
from xms.api.tree import tree_util
from xms.constraint import read_grid_from_file
from xms.data_objects.parameters import Coverage
from xms.gmi.data_bases.coverage_base_data import CoverageBaseData

# 4. Local modules
from xms.wavewatch3.data.model import get_model
from xms.wavewatch3.data.output_fields import get_output_field_model
from xms.wavewatch3.data.sim_data import SimData

# Coverage type name -> unique name of the XMS component registered for that coverage type.
COVERAGE_TYPE_TO_UNIQUE_NAME = {'Output Points': 'OutputPointsComponent'}

# XMS tree item types searched for (pointers allowed) when locating the grid linked to the simulation.
LINKABLE_GRIDS = [
    'TI_MESH2D_PTR',
    'TI_UGRID_PTR',
]


class XmsData:
    """Class for managing interprocess communication with XMS."""
    def __init__(self, query=None):
        """Constructor.

        Args:
            query: Optional xmsapi Query object. If None, one is created lazily on first use of the `query` property.
        """
        self._query = query
        self._logger = logging.getLogger('xms.wavewatch3')
        self._sim_uuid = None
        self._sim_item = None
        self._sim_data = None
        self._global_time = None
        self._do_ugrid = None
        self._cogrid = None
        # Not read by this class; kept for compatibility with any external users.
        self._coverages = {
            'Output Points': None,
        }

    @property
    def query(self):
        """Returns the xmsapi interprocess communication object.

        Notes:
            If constructed in a component runner script, Query should be passed in on construction. If constructed in
            an export or import script, creation of the Query can happen here. Only one Query/XmsData object should be
            constructed per script.
        """
        if self._query is None:
            self._query = Query()
        return self._query

    @property
    def sim_uuid(self):
        """Returns the simulation's UUID (not its component's UUID)."""
        if self._sim_uuid is None:
            sim_item = self.sim_item
            self._sim_uuid = sim_item.uuid
        return self._sim_uuid

    @property
    def sim_item(self):
        """Returns the simulation's TreeNode.

        The current item's UUID is looked up in the project tree first; if it is not found there, the parent item's
        UUID is tried instead (presumably because the current item may be the simulation's component).
        """
        if self._sim_item is None:
            sim_uuid = self.query.current_item_uuid()
            sim_item = tree_util.find_tree_node_by_uuid(self.query.project_tree, sim_uuid)
            if sim_item is None:
                sim_uuid = self.query.parent_item_uuid()
                sim_item = tree_util.find_tree_node_by_uuid(self.query.project_tree, sim_uuid)
            self._sim_item = sim_item
        return self._sim_item

    @property
    def sim_data(self) -> SimData:
        """Returns the simulation component's SimData.

        Raises:
            RuntimeError: If the simulation component cannot be retrieved from SMS.
        """
        if self._sim_data is None:
            self._logger.info('Retrieving simulation data from SMS...')
            do_comp = self.query.item_with_uuid(self.sim_uuid, model_name='WaveWatch3', unique_name='SimComponent')
            if not do_comp:
                raise RuntimeError('Unable to retrieve simulation data from SMS.')
            self._sim_data = SimData(do_comp.main_file)
        return self._sim_data

    @property
    def sim_data_model_control(self):
        """Returns the model control (global) parameters, restored from the simulation data.

        A new parameter group is built on every access.
        """
        model = get_model()
        values = self.sim_data.global_values
        model.global_parameters.restore_values(values)
        return model.global_parameters

    @property
    def sim_data_output_fields(self):
        """Returns the output field parameters, restored from the simulation data.

        A new parameter group is built on every access.
        """
        model = get_output_field_model()
        values = self.sim_data.output_field_values
        model.model_parameters.restore_values(values)
        return model.model_parameters

    @property
    def global_time(self):
        """Returns the current SMS zero time datetime object."""
        if self._global_time is None:
            self._global_time = self.query.global_time
        return self._global_time

    def _linked_grid_node(self):
        """Returns the tree node of the first grid (mesh or UGrid pointer) under the simulation, or None."""
        return tree_util.descendants_of_type(
            self.sim_item, allow_pointers=True, xms_types=LINKABLE_GRIDS, only_first=True
        )

    @property
    def do_ugrid(self):
        """Retrieves the simulation's data_objects UGrid.

        Returns:
            The data_objects UGrid, or None if no grid is linked to the simulation.

        Raises:
            RuntimeError: If a grid is linked but cannot be retrieved from SMS.
        """
        if self._do_ugrid is None:
            self._logger.info('Retrieving domain mesh from SMS...')
            tree_node = self._linked_grid_node()
            if tree_node is None:
                return None
            self._do_ugrid = self.query.item_with_uuid(tree_node.uuid)
            if not self._do_ugrid:
                raise RuntimeError('Unable to retrieve domain mesh from SMS.')
        return self._do_ugrid

    @property
    def cogrid(self):
        """Retrieves the simulation's domain CoGrid, or None if no grid is linked to the simulation."""
        if self._cogrid is None:
            do_ugrid = self.do_ugrid
            if do_ugrid is None:
                return None
            self._cogrid = read_grid_from_file(do_ugrid.cogrid_file)
        return self._cogrid

    @cached_property
    def spectral_coverages(self):
        """Returns a list of the simulation's spectral (generic) coverages.

        Returns None only if the tree query itself returns None.
        """
        tree_nodes = tree_util.descendants_of_type(
            self.sim_item,
            coverage_type='SPECTRAL',
            model_name='',
            recurse=False,
            allow_pointers=True,
            only_first=False
        )
        if tree_nodes is None:
            return None
        return [self.query.item_with_uuid(node.uuid, generic_coverage=True) for node in tree_nodes]

    def _load_coverage_and_data(self, coverage_type, unique_name, component_class, **load_id_kwargs):
        """Finds a WaveWatch3 coverage under the simulation and loads its component data.

        Args:
            coverage_type (str): The coverage type to search for directly under the simulation.
            unique_name (str): The unique name of the coverage's WaveWatch3 component.
            component_class: The component class constructed from the component's main file.
            **load_id_kwargs: Keyword arguments forwarded to `Query.load_component_ids` (e.g. points=True).

        Returns:
            A (coverage, component data) tuple, or (None, None) if the simulation has no such coverage.

        Raises:
            RuntimeError: If the coverage exists but its component cannot be retrieved from SMS.
        """
        tree_node = tree_util.descendants_of_type(
            self.sim_item,
            coverage_type=coverage_type,
            model_name='WaveWatch3',
            recurse=False,
            allow_pointers=True,
            only_first=True
        )
        if tree_node is None:
            return None, None
        coverage = self.query.item_with_uuid(tree_node.uuid)
        do_component = self.query.item_with_uuid(tree_node.uuid, model_name='WaveWatch3', unique_name=unique_name)
        if not do_component:
            # Fail with a descriptive error instead of an AttributeError on `main_file`.
            raise RuntimeError(f'Unable to retrieve {coverage_type} coverage data from SMS.')
        component = component_class(do_component.main_file)
        self.query.load_component_ids(component, **load_id_kwargs)
        return coverage, component.data

    @cached_property
    def _output_points_coverage(self):
        """Returns the simulation's Output Points data_objects coverage and component data, or (None, None)."""
        # Imported here rather than at module level (presumably to avoid a circular import).
        from xms.wavewatch3.components.output_points_component import OutputPointsComponent
        return self._load_coverage_and_data('Output Points', 'OutputPointsComponent', OutputPointsComponent, points=True)

    @property
    def output_points_coverage(self):
        """Helper to get the output points coverage."""
        coverage, _data = self._output_points_coverage
        return coverage

    @property
    def output_points_data(self):
        """Helper to get the output points data."""
        _coverage, data = self._output_points_coverage
        return data

    @cached_property
    def _bc_coverage(self) -> tuple[Optional[Coverage], Optional[CoverageBaseData]]:
        """Returns the simulation's boundary condition coverage and component data, or (None, None)."""
        # Imported here rather than at module level (presumably to avoid a circular import).
        from xms.wavewatch3.components.bc_component import BcComponent
        return self._load_coverage_and_data('Boundary Conditions', 'Bc_Component', BcComponent, arcs=True)

    @property
    def bc_coverage(self):
        """Helper to get the bc coverage."""
        coverage, _data = self._bc_coverage
        return coverage

    @property
    def bc_data(self):
        """Helper to get the bc coverage data."""
        _coverage, data = self._bc_coverage
        return data

    def _grid_dataset_tree(self, xms_type):
        """Returns a copy of the linked grid's project tree node trimmed to items of a single type.

        Args:
            xms_type (str): The XMS tree item type to keep (e.g. 'TI_SFUNC' or 'TI_VFUNC').

        Returns:
            The trimmed tree, or None if no grid is linked or the grid is not found in the project tree.
        """
        grid_ptr_node = self._linked_grid_node()
        if grid_ptr_node is None:
            return None
        grid_node = tree_util.find_tree_node_by_uuid(self.query.copy_project_tree(), grid_ptr_node.uuid)
        if grid_node is None:
            return None
        return tree_util.trim_tree_to_items_of_type(grid_node, xms_types=[xms_type])

    @property
    def scalar_dataset_tree(self):
        """Returns a filtered tree of scalar datasets from the simulation's attached grid."""
        return self._grid_dataset_tree('TI_SFUNC')

    @property
    def vector_dataset_tree(self):
        """Returns a filtered tree of vector datasets from the simulation's attached grid."""
        return self._grid_dataset_tree('TI_VFUNC')

    def unlink_item(self, child_uuid):
        """Unlink an item taken by a WaveWatch3 simulation.

        Args:
            child_uuid (:obj:`str`): The UUID of the item to unlink from the simulation.
        """
        self.query.unlink_item(self.sim_uuid, child_uuid)
