"""Module for the XmsData class."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"
__all__ = ['MISSING', 'XmsData']

# 1. Standard Python modules
from functools import cached_property
from pathlib import Path
from typing import Optional, TypeAlias

# 2. Third party modules
import xarray as xr

# 3. Aquaveo modules
from xms.api.dmi import ModelCheckError, Query
from xms.api.tree import tree_util, TreeNode
from xms.constraint import QuadtreeGrid2d, read_grid_from_file, RectilinearGrid2d, UGrid2d
from xms.data_objects.parameters import Coverage, Projection, UGrid as DoUGrid
from xms.datasets.dataset_reader import DatasetReader
from xms.guipy.data.target_type import TargetType

# 4. Local modules
from xms.cmsflow.components.save_points_component import SavePointsComponent
from xms.cmsflow.data.bc_data import BCData
from xms.cmsflow.data.simulation_data import SimulationData

MISSING = object()

BCComponent: TypeAlias = 'xms.cmsflow.components.bc_component.BCComponent'  # noqa: F821  Dealing with a circular import


class XmsData:
    """
    Class for wrapping `Query` with a higher-level interface and caching.

    Most attributes are `cached_property`s: the first access asks `Query` for the value and later accesses reuse it.
    Lookups keyed by UUID (`tree_path`, `get_ugrid`, `get_dataset`, `item_exists`) are memoized per UUID. Tests can
    bypass `Query` by passing override values to the constructor.
    """

    # All-zero UUID. `item_exists` reports it (and the empty string) as never existing without consulting `Query`.
    _NIL_UUID = '00000000-0000-0000-0000-000000000000'

    # PyCharm doesn't know that cached properties are settable.
    # noinspection PyPropertyAccess
    def __init__(
        self,
        query: Optional[Query] = None,
        sim_data: Optional[SimulationData] = None,
        ugrid: Optional[QuadtreeGrid2d | RectilinearGrid2d] = None,
        ugrid_projection: Optional[Projection] = None,
        native_ugrid_projection: Optional[Projection] = None,
        activity_coverage: Optional[Coverage] = None,
        bc_coverage: Optional[tuple[Coverage, 'BCComponent']] = None,
        save_points_coverage: Optional[tuple[Coverage, SavePointsComponent]] = None,
        project_path: str | Path = '',
        ugrids: Optional[dict[str, Optional[UGrid2d]]] = None,
        datasets: Optional[dict[str, Optional[DatasetReader]]] = None,
        tree_paths: Optional[dict[str, str]] = None,
        items_that_exist: Optional[dict[str, bool]] = None,
    ):
        """
        Initialize the class.

        When running in XMS, this should be initialized like `XmsData(query)` with a valid, initialized Query. When
        running under tests, `query` should be None and any other necessary parameters should be filled in instead.

        Most parameters override the attribute with the same name, forcing it to return the provided value instead of
        whatever `query` would have returned. When an overridden value is provided, `query` will not be consulted for
        it. This allows tests to inject their own values without a `Query`.

        Because None means "not overridden", some parameters also accept the module-level `MISSING` sentinel to force
        an empty value. For `ugrid`, `ugrid_projection`, `native_ugrid_projection`, and `activity_coverage`, this
        forces the attribute to None. For `bc_coverage` and `save_points_coverage`, it forces `(None, None)`. This lets
        tests simulate an item that isn't linked to the simulation.

        Args:
            query: Interprocess communication object.
            sim_data: If provided, overrides the value of `self.sim_data`.
            ugrid: If provided, overrides the value of `self.ugrid`. Accepts `MISSING`.
            ugrid_projection: If provided, overrides the value of `self.ugrid_projection`. Accepts `MISSING`.
            native_ugrid_projection: If provided, overrides the value of `self.native_ugrid_projection`. Accepts
                `MISSING`.
            activity_coverage: If provided, overrides the value of `self.activity_coverage`. Accepts `MISSING`.
            bc_coverage: If provided, overrides the value of `self.bc_coverage`. Accepts `MISSING`.
            save_points_coverage: If provided, overrides the value of `self.save_points_coverage`. Accepts `MISSING`.
            project_path: If provided (non-empty), overrides the value of `self.project_path`.
            ugrids: If provided, overrides the internal UGrid cache. `self.get_ugrid()` consults this to find UGrids to
                return. Keys should be the UGrid's UUID, and values should be the UGrid the UUID matches to. `query`
                *will* be consulted for any UUIDs that are not found in this cache, *even if this is provided*. Tests
                that want to cover cases where a UGrid is not found should ensure the missing UGrid's key is in this
                dict and has a value of None.
            datasets: If provided, overrides the internal dataset cache. `self.get_dataset()` consults this to find
                datasets to return. Keys should be the dataset's UUID, and values should be the dataset the UUID matches
                to. `query` *will* be consulted for any UUIDs that are not found in this cache,
                *even if this is provided*. Tests that want to cover cases where a dataset is not found should ensure
                the missing dataset's key is in this dict and has a value of None.
            tree_paths: If provided, overrides the internal tree path cache. `self.tree_path()` consults this to find
                paths to return. Keys should be the item's UUID, and values should be the path that the UUID matches to.
                `query` *will* be consulted for any UUIDs that are not found in this cache, *even if this is provided*.
                Tests that want to cover cases where a path does not exist should ensure the missing item's key is in
                this dict and has a value of an empty string.
            items_that_exist: Dictionary of uuid->exists. If provided, `self.item_exists` will look for items in this
                dictionary before using Query to find them.
        """
        self._query = query
        if sim_data is not None:
            self.sim_data = sim_data
        if ugrid is MISSING:
            self.ugrid = None
        elif ugrid is not None:
            self.ugrid = ugrid
        if ugrid_projection is MISSING:
            self.ugrid_projection = None
        elif ugrid_projection is not None:
            self.ugrid_projection = ugrid_projection
        if native_ugrid_projection is MISSING:
            self.native_ugrid_projection = None
        elif native_ugrid_projection is not None:
            self.native_ugrid_projection = native_ugrid_projection
        if activity_coverage is MISSING:
            self.activity_coverage = None
        elif activity_coverage is not None:
            self.activity_coverage = activity_coverage
        if bc_coverage is MISSING:
            self.bc_coverage = (None, None)
        elif bc_coverage is not None:
            self.bc_coverage = bc_coverage
        self._ugrids = ugrids or {}
        self._datasets = datasets or {}
        if project_path:
            self.project_path = Path(project_path)
        if save_points_coverage is MISSING:
            self.save_points_coverage = (None, None)
        elif save_points_coverage is not None:
            self.save_points_coverage = save_points_coverage
        self._tree_paths = tree_paths or {}
        self.model_check_errors: list[ModelCheckError] = []
        self._items_that_exist = items_that_exist or {}

    @cached_property
    def _project_tree(self) -> TreeNode:
        """
        A copy of the project explorer tree.

        Copied from `Query` once, on first use, and shared by every tree lookup. This treats the tree as a snapshot.
        That is consistent with the per-UUID caches and `_sim_tree_node`, which never refresh either.
        """
        return self._query.copy_project_tree()

    @cached_property
    def _sim_uuid(self) -> str:
        """The UUID of the simulation."""
        return self._query.current_item_uuid()

    @cached_property
    def _sim_tree_node(self) -> TreeNode:
        """The tree node for the simulation."""
        return tree_util.find_tree_node_by_uuid(self._project_tree, self._sim_uuid)

    @cached_property
    def project_path(self) -> Path:
        """The project's path."""
        return Path(self._query.xms_project_path)

    def tree_path(self, item_uuid: str) -> str:
        """
        Get the tree path for an item.

        Results are cached per UUID.

        Args:
            item_uuid: The UUID of the item to get the tree path for.

        Returns:
            The path in the tree to the item with the given UUID, or an empty string if no such item is in the tree.
        """
        if item_uuid not in self._tree_paths:
            node = tree_util.find_tree_node_by_uuid(self._project_tree, item_uuid)
            # An empty string is the "item does not exist" value documented for the `tree_paths` constructor argument;
            # don't hand a None node to tree_util.tree_path.
            self._tree_paths[item_uuid] = tree_util.tree_path(node) if node is not None else ''
        return self._tree_paths[item_uuid]

    @cached_property
    def sim_data(self) -> SimulationData:
        """The data manager for the simulation."""
        sim_comp = self._query.item_with_uuid(self._sim_uuid, model_name='CMS-Flow', unique_name='Simulation_Component')
        return SimulationData(sim_comp.main_file)

    @cached_property
    def ugrid(self) -> Optional[QuadtreeGrid2d | RectilinearGrid2d]:
        """
        The UGrid that is linked to the simulation.

        Will be None if no UGrid is linked.
        """
        ugrid, _projection, _native = self._ugrid_and_projections
        return ugrid

    @cached_property
    def ugrid_projection(self) -> Optional[Projection]:
        """
        The projection of the UGrid that is linked to the simulation.

        Will be None if no UGrid is linked.
        """
        _ugrid, projection, _native = self._ugrid_and_projections
        return projection

    @cached_property
    def native_ugrid_projection(self) -> Optional[Projection]:
        """
        The native projection of the UGrid that is linked to the simulation.

        Will be None if no UGrid is linked.
        """
        _ugrid, _projection, native = self._ugrid_and_projections
        return native

    @cached_property
    def _ugrid_and_projections(self) -> tuple[Optional[UGrid2d], Optional[Projection], Optional[Projection]]:
        """
        Helper to get the UGrid linked to the simulation and its projection and native projection.

        Returns:
            `(grid, projection, native_projection)`, or `(None, None, None)` if no UGrid is linked or `Query` can't
            resolve the linked UUID.
        """
        ugrid_item: TreeNode = tree_util.descendants_of_type(
            self._sim_tree_node,
            xms_types=['TI_UGRID_PTR', 'TI_CGRID2D_PTR'],
            recurse=False,
            allow_pointers=True,
            only_first=True
        )
        if not ugrid_item:
            return None, None, None

        do_ugrid: Optional[DoUGrid] = self._query.item_with_uuid(ugrid_item.uuid)
        if do_ugrid is None:
            # Same treatment as get_ugrid(): an unresolvable UUID is reported as "no UGrid".
            return None, None, None

        co_grid: UGrid2d = read_grid_from_file(do_ugrid.cogrid_file)
        return co_grid, do_ugrid.projection, do_ugrid.native_projection

    @cached_property
    def activity_coverage(self) -> Optional[Coverage]:
        """
        The activity coverage linked to the simulation.

        Will be None if no activity coverage is linked.
        """
        activity_item = tree_util.descendants_of_type(
            self._sim_tree_node,
            xms_types=['TI_COVER_PTR'],
            recurse=False,
            allow_pointers=True,
            only_first=True,
            coverage_type='ACTIVITY_CLASSIFICATION'
        )
        if not activity_item:
            return None

        return self._query.item_with_uuid(activity_item.uuid, generic_coverage=True)

    @cached_property
    def save_points_coverage(self) -> tuple[Optional[Coverage], Optional[SavePointsComponent]]:
        """
        The linked save points coverage and its component.

        Will be (None, None) if no save points coverage is linked, or if `Query` returns nothing for either the
        coverage or its component.
        """
        save_points_item = tree_util.descendants_of_type(
            self._sim_tree_node,
            xms_types=['TI_COVER_PTR'],
            allow_pointers=True,
            recurse=False,
            only_first=True,
            coverage_type='Save Points'
        )
        if not save_points_item:
            return None, None

        coverage = self._query.item_with_uuid(save_points_item.uuid)
        do_component = self._query.item_with_uuid(
            save_points_item.uuid, model_name='CMS-Flow', unique_name='Save_Points_Component'
        )
        if not coverage or not do_component:
            return None, None

        component = SavePointsComponent(do_component.main_file)
        self._query.load_component_ids(component, points=True, arcs=True, polygons=True)
        return coverage, component

    @cached_property
    def bc_coverage(self) -> tuple[Optional[Coverage], Optional['BCComponent']]:
        """
        The BC coverage linked to the simulation and its component.

        Will be (None, None) if no BC coverage is linked. When a component is returned, its
        `comp_to_xms[cov_uuid][TargetType.arc]` entry is guaranteed to exist (possibly empty).
        """
        # BCComponent imports this for its dialog, so we have to import internally to break the dependency.
        from xms.cmsflow.components.bc_component import BCComponent

        bc_item = tree_util.descendants_of_type(
            self._sim_tree_node,
            xms_types=['TI_COVER_PTR'],
            recurse=False,
            allow_pointers=True,
            only_first=True,
            coverage_type='Boundary Conditions'
        )
        if not bc_item:
            return None, None

        bc_cov = self._query.item_with_uuid(bc_item.uuid)
        bc_do_comp = self._query.item_with_uuid(bc_item.uuid, model_name='CMS-Flow', unique_name='BC_Component')
        if not bc_cov or not bc_do_comp:  # pragma: nocover
            # There's a mismatch in the XML.
            raise AssertionError('Unexpected coverage type')

        bc_comp = BCComponent(bc_do_comp.main_file)
        self._query.load_component_ids(bc_comp, arcs=True)
        # Make sure comp_to_xms[cov_uuid][TargetType.arc] exists so consumers can index it directly. The mapping may be
        # non-empty yet still lack this coverage's entry, so fill in each missing level rather than only the empty case.
        if not bc_comp.comp_to_xms:
            bc_comp.comp_to_xms = {}
        bc_comp.comp_to_xms.setdefault(bc_comp.cov_uuid, {}).setdefault(TargetType.arc, {})

        return bc_cov, bc_comp

    @cached_property
    def bc_arc_attributes(self) -> dict[int, xr.Dataset]:
        """
        A dictionary mapping arc feature IDs to the arc's attributes.

        Empty if no BC coverage is linked to the simulation.
        """
        _coverage, component = self.bc_coverage
        if component is None:
            return {}

        data: BCData = component.data
        arc_component_id_map = component.comp_to_xms[component.cov_uuid][TargetType.arc]

        all_attrs: dict[int, xr.Dataset] = {}
        for component_id, feature_ids in arc_component_id_map.items():
            # Feature arcs mapped to the same component ID share one attribute selection.
            arc_attrs = data.arcs.loc[{'comp_id': component_id}]
            for feature_id in feature_ids:
                all_attrs[feature_id] = arc_attrs
        return all_attrs

    def get_ugrid(self, ugrid_uuid: str) -> Optional[UGrid2d]:
        """
        Get a UGrid.

        The returned ugrid is cached. Getting the same UUID multiple times will receive the same instance.

        Args:
            ugrid_uuid: UUID of the UGrid to get.

        Returns:
            The UGrid, or None if no UGrid exists with the given UUID.
        """
        if ugrid_uuid not in self._ugrids:
            do_grid: Optional[DoUGrid] = self._query.item_with_uuid(ugrid_uuid)
            self._ugrids[ugrid_uuid] = None if do_grid is None else read_grid_from_file(do_grid.cogrid_file)

        return self._ugrids[ugrid_uuid]

    def get_dataset(self, dataset_uuid: str) -> Optional[DatasetReader]:
        """
        Get a dataset.

        The returned dataset is cached. Getting the same UUID multiple times will receive the same instance.

        Args:
            dataset_uuid: UUID of the dataset to get.

        Returns:
            Whatever `Query.item_with_uuid` returns for the UUID: the dataset reader, or None if no dataset exists with
            the given UUID.
        """
        if dataset_uuid not in self._datasets:
            self._datasets[dataset_uuid] = self._query.item_with_uuid(dataset_uuid)

        return self._datasets[dataset_uuid]

    def item_exists(self, item_uuid: str) -> bool:
        """
        Check whether an item exists in the project explorer tree.

        Results are cached per UUID. Empty and all-zero UUIDs never exist.

        Args:
            item_uuid: UUID of the item to look for.

        Returns:
            Whether the item exists.
        """
        if not item_uuid or item_uuid == self._NIL_UUID:
            return False

        if item_uuid not in self._items_that_exist:
            node = tree_util.find_tree_node_by_uuid(self._project_tree, item_uuid)
            self._items_that_exist[item_uuid] = node is not None

        return self._items_that_exist[item_uuid]

    def add_model_check_errors(self, errors: list[ModelCheckError]):
        """
        Send a list of model check errors to XMS.

        Without a `Query` (under tests), the errors are appended to `self.model_check_errors` instead.

        Args:
            errors: The errors to report.
        """
        if self._query is not None:
            self._query.add_model_check_errors(errors)
        else:
            self.model_check_errors.extend(errors)
