"""Class for managing interprocess communication with XMS."""
__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
import logging
from typing import Optional

# 2. Third party modules

# 3. Aquaveo modules
from xms.api.dmi import Query
from xms.api.tree import tree_util
from xms.constraint import read_grid_from_file
from xms.datasets.dataset_reader import DatasetReader

# 4. Local modules
from xms.funwave.components.station_component import StationComponent
from xms.funwave.data.model import get_model
from xms.funwave.data.sim_data import SimData


# XMS project-tree item types that XmsData.do_ugrid searches for beneath the simulation to find its domain grid.
LINKABLE_GRIDS = [
    'TI_CGRID2D_PTR',
]


class XmsData:
    """Class for managing interprocess communication with XMS.

    Most accessors are lazy: the first access queries XMS and caches the result on the instance, so repeated
    accesses do not repeat the interprocess round trip.
    """

    def __init__(self, query=None):
        """Constructor.

        Args:
            query (:obj:`Query`): The XMS interprocess communication object. If None, one is created on first
                access of the ``query`` property.
        """
        self._query = query
        self._logger = logging.getLogger('xms.funwave')
        self._sim_uuid = None
        self._sim_item = None
        self._sim_name = None
        self._sim_data = None
        self._do_ugrid = None
        self._xmugrid = None  # read this only once
        self._cogrid = None  # read this only once
        self._ugrid_item = None
        self._output_points_covs = {}  # {cov_uuid: Coverage}
        self._so_covs = None  # Output Stations coverages, filled by retrieve_output_station_data()
        self._so_comps = None  # StationComponents matching self._so_covs index-for-index

    def retrieve_output_station_data(self):
        """Query for the simulation's Output Station coverages and components.

        The coverages are cached in ``self._so_covs`` and their loaded :obj:`StationComponent` objects in
        ``self._so_comps``, in matching order.

        Returns:
            (:obj:`list`): The Output Stations coverages found under the simulation.

        Raises:
            RuntimeError: If the StationComponent of an Output Stations coverage cannot be retrieved from SMS.
        """
        self._logger.info('Retrieving Output Station coverages from SMS...')
        self._so_covs = []
        self._so_comps = []
        # Go through the property so a Query is created if none was passed to the constructor.
        query = self.query
        sim_item = tree_util.find_tree_node_by_uuid(query.project_tree, self.sim_uuid)
        output_station_items = tree_util.descendants_of_type(tree_root=sim_item, xms_types=['TI_COVER_PTR'],
                                                             allow_pointers=True, only_first=False,
                                                             coverage_type='Output Stations', model_name='FUNWAVE')
        for cov_item in output_station_items:
            coverage = query.item_with_uuid(cov_item.uuid)
            do_comp = query.item_with_uuid(cov_item.uuid, unique_name='StationComponent', model_name='FUNWAVE')
            if not do_comp:
                raise RuntimeError(
                    f'Unable to retrieve Output Stations data for coverage "{cov_item.name}" from SMS.'
                )
            py_comp = StationComponent(do_comp.main_file)
            query.load_component_ids(py_comp, points=True, arcs=True)
            # Append both only once both are known good so the two lists stay aligned.
            self._so_covs.append(coverage)
            self._so_comps.append(py_comp)
        return self._so_covs

    @property
    def query(self):
        """Returns the xmsapi interprocess communication object.

        Notes:
            If constructed in a component runner script, Query should be passed in on construction. If constructed in
            an export or import script, creation of the Query can happen here. Only one Query/XmsData object should be
            constructed per script.
        """
        if self._query is None:
            self._query = Query()
        return self._query

    def set_sim(self, sim_uuid, sim_name):
        """Set the simulation that data will be retrieved for.

        Any data already cached for a previously selected simulation is discarded, so it is re-queried for the
        new one on next access.

        Args:
            sim_uuid (:obj:`str`): The UUID of the simulation
            sim_name (:obj:`str`): The name of the simulation
        """
        self._sim_uuid = sim_uuid
        self._sim_name = sim_name
        # Everything below was derived from the previously selected simulation; keeping it would return stale data.
        self._sim_item = None
        self._sim_data = None
        self._do_ugrid = None
        self._xmugrid = None
        self._cogrid = None
        self._ugrid_item = None
        self._so_covs = None
        self._so_comps = None

    @property
    def sim_uuid(self):
        """Returns the simulation's UUID (not its component's UUID).

        Uses the UUID given to ``set_sim()`` if any, otherwise the UUID of the item XMS reports as current.
        """
        if self._sim_uuid is None:
            self._sim_uuid = self.query.current_item_uuid()
        return self._sim_uuid

    @property
    def sim_name(self):
        """Returns the name of simulation, or 'Sim' if XMS does not provide one."""
        if self._sim_name is None:
            sim_item = self.query.current_item()
            if sim_item and sim_item.name:
                self._sim_name = sim_item.name
            else:
                self._sim_name = 'Sim'
        return self._sim_name

    @property
    def sim_item(self):
        """Returns the simulation's TreeNode.

        Uses the simulation selected with ``set_sim()`` when there is one, otherwise the current XMS item. If that
        UUID is not found in the project tree, the parent item's UUID is tried instead.
        """
        if self._sim_item is None:
            project_tree = self.query.project_tree
            sim_uuid = self._sim_uuid if self._sim_uuid is not None else self.query.current_item_uuid()
            sim_item = tree_util.find_tree_node_by_uuid(project_tree, sim_uuid)
            if sim_item is None:
                # Presumably the current item is a child of the simulation (not itself a tree node); try its parent.
                sim_item = tree_util.find_tree_node_by_uuid(project_tree, self.query.parent_item_uuid())
            self._sim_item = sim_item
        return self._sim_item

    @property
    def sim_data(self) -> SimData:
        """Returns the simulation component's SimData.

        Raises:
            RuntimeError: If the simulation component cannot be retrieved from SMS.
        """
        if self._sim_data is None:
            self._logger.info('Retrieving simulation data from SMS...')
            do_comp = self.query.item_with_uuid(self.sim_uuid, model_name='FUNWAVE', unique_name='SimComponent')
            if not do_comp:
                raise RuntimeError('Unable to retrieve simulation data from SMS.')
            self._sim_data = SimData(do_comp.main_file)
        return self._sim_data

    @property
    def sim_data_model_control(self):
        """Returns the model control (global parameter) data, populated from the simulation's stored values."""
        model = get_model()
        values = self.sim_data.global_values
        model.global_parameters.restore_values(values)
        return model.global_parameters

    @property
    def do_ugrid(self):
        """Retrieves the simulation's data_objects UGrid.

        Returns:
            The first linked grid item (of a type in LINKABLE_GRIDS) under the simulation, or None if none is linked.

        Raises:
            RuntimeError: If a linked grid exists in the tree but cannot be retrieved from SMS.
        """
        if self._do_ugrid is None:
            self._logger.info('Retrieving domain mesh from SMS...')
            tree_node = tree_util.descendants_of_type(
                self.sim_item, allow_pointers=True, xms_types=LINKABLE_GRIDS, only_first=True
            )
            if tree_node is None:
                return None
            self._do_ugrid = self.query.item_with_uuid(tree_node.uuid)
            if not self._do_ugrid:
                raise RuntimeError('Unable to retrieve domain mesh from SMS.')
        return self._do_ugrid

    @property
    def cogrid(self):
        """Retrieves the simulation's domain CoGrid, or None if no grid is linked to the simulation."""
        if self._cogrid is None:
            do_ugrid = self.do_ugrid
            if do_ugrid is None:
                return None
            self._cogrid = read_grid_from_file(do_ugrid.cogrid_file)
        return self._cogrid

    @property
    def ugrid_item(self):
        """Returns the linked UGrid tree item.

        Raises:
            RuntimeError: If the item referenced by the simulation's 'domain_uuid' attribute is not in the project tree.
        """
        if self._ugrid_item is None:
            self._ugrid_item = tree_util.find_tree_node_by_uuid(self.query.project_tree,
                                                                self.sim_data.info.attrs['domain_uuid'])
            if not self._ugrid_item:
                raise RuntimeError('Unable to retrieve linked mesh geometry.')
        return self._ugrid_item

    def get_dataset_reader(self, dataset_uuid: str) -> Optional[DatasetReader]:
        """Get a dataset reader from XMS.

        Args:
            dataset_uuid: The dataset UUID.

        Returns:
            The DatasetReader, or whatever XMS returns for an unknown UUID (presumably None — verify).
        """
        return self.query.item_with_uuid(dataset_uuid)

    def get_tree_item(self, item_uuid):
        """Get a project explorer item from UUID.

        Args:
            item_uuid (:obj:`str`): UUID of the item to retrieve

        Returns:
            (:obj:`TreeNode`): The tree item with specified UUID or None if not found
        """
        # Use the property so a Query is created if none was passed to the constructor.
        return tree_util.find_tree_node_by_uuid(self.query.project_tree, item_uuid)
