"""Class for managing interprocess communication with XMS."""
__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules

# 2. Third party modules

# 3. Aquaveo modules
from xms.api.dmi import Query
from xms.api.tree import tree_util
from xms.constraint import read_grid_from_file
from xms.guipy.dialogs import treeitem_selector_datasets as ti_sel_ds

# 4. Local modules
from xms.rsm.file_io import util


class _CoverageGetter:
    """Retrieves the simulation's coverages of one type together with their Python components."""

    # unique_name -> (module path, class name). The component classes are imported lazily (at call time, not at
    # module load) to avoid circular imports.
    _COMPONENT_CLASSES = {
        'BasinComponent': ('xms.rsm.components.basin_component', 'BasinComponent'),
        'BcComponent': ('xms.rsm.components.bc_component', 'BcComponent'),
        'CanalComponent': ('xms.rsm.components.canal_component', 'CanalComponent'),
        'CellMonitorComponent': ('xms.rsm.components.cell_monitor_component', 'CellMonitorComponent'),
        'ImpoundmentComponent': ('xms.rsm.components.impoundment_component', 'ImpoundmentComponent'),
        'LakeComponent': ('xms.rsm.components.lake_component', 'LakeComponent'),
        'MeshDataComponent': ('xms.rsm.components.mesh_ds_component', 'MeshDataComponent'),
        'RuleCurveComponent': ('xms.rsm.components.rule_curve_component', 'RuleCurveComponent'),
        'WcdComponent': ('xms.rsm.components.wcd_component', 'WcdComponent'),
        'WaterMoverComponent': ('xms.rsm.components.water_mover_component', 'WaterMoverComponent'),
    }

    def __init__(self, xms_data, coverage_type, unique_name):
        """Constructor.

        Args:
            xms_data (XmsData): The XMS interprocess communication object
            coverage_type (str): The type of coverage to retrieve
            unique_name (str): The unique name of the coverage

        Raises:
            ValueError: If unique_name is not a known RSM coverage component
        """
        self._sim_tree_item = xms_data.sim_tree_item
        # Use the lazily-creating property so a Query always exists, independent of attribute access order.
        self._query = xms_data._query
        self._coverage_type = coverage_type
        self._unique_name = unique_name
        self._component_class = self._get_component_class()
        self.coverages = []
        self.components = []
        self._get_coverages_and_components()

    def _get_component_class(self):
        """Get the component class based on the unique name.

        Returns:
            (type): The Python component class registered for the unique name

        Raises:
            ValueError: If the unique name is not a known RSM coverage component
        """
        import importlib

        try:
            module_path, class_name = self._COMPONENT_CLASSES[self._unique_name]
        except KeyError:
            raise ValueError(f'Unknown RSM coverage component: {self._unique_name}') from None
        # Imported here rather than at module level to avoid circular imports.
        return getattr(importlib.import_module(module_path), class_name)

    def _get_coverages_and_components(self):
        """Fill `coverages` and `components` for each matching RSM coverage under the simulation."""
        cov_items = tree_util.descendants_of_type(
            tree_root=self._sim_tree_item,
            xms_types=['TI_COVER_PTR'],
            allow_pointers=True,
            only_first=False,
            coverage_type=self._coverage_type,
            model_name='RSM'
        )
        for cov_item in cov_items:
            cov = self._query.item_with_uuid(cov_item.uuid)
            do_comp = self._query.item_with_uuid(cov_item.uuid, unique_name=self._unique_name, model_name='RSM')
            py_comp = self._component_class(do_comp.main_file)
            # Load the component ids for points, arcs, and polygons into the Python component.
            self._query.load_component_ids(py_comp, points=True, arcs=True, polygons=True)
            self.coverages.append(cov)
            self.components.append(py_comp)

    def cov_list(self):
        """Returns a list of (coverage, component) tuples."""
        return list(zip(self.coverages, self.components))

    def cov_and_comp(self):
        """Returns the first (coverage, component) tuple, or (None, None) if no coverage was found."""
        if self.coverages:
            return self.coverages[0], self.components[0]
        return None, None


class XmsData:
    """Class for managing interprocess communication with XMS.

    Simulation, UGrid, and coverage data are retrieved from XMS on first access and cached on the instance.
    """
    def __init__(self, query=None, sim_component=None):
        """Constructor.

        Args:
            query (Query): The XMS interprocess communication object
            sim_component (SimComponent): the simulation component
        """
        self.sim_global_parameters = None
        self._xms_query = query
        self._logger = util.get_logger()
        self._sim_comp = sim_component
        self._sim_item = None
        self._sim_tree_item = None
        self._sim_uuid = None
        self._sim_name = None
        self._do_ugrid = None
        self._xmugrid = None
        self._cogrid = None
        self._ugrid_item = None
        self._extractor = None
        self._ug_polygon_snap = None
        # Coverage caches. None means "not retrieved yet"; an empty list means "retrieved, none found". This keeps a
        # simulation without a given coverage type from re-querying XMS on every property access.
        self._basin_covs = None
        self._bc_covs = None
        self._canal_cov = None
        self._canal_comp = None
        self._canal_retrieved = False  # the canal pair may legitimately be (None, None) after retrieval
        self._cell_monitor_covs = None
        self._impoundment_covs = None
        self._lake_covs = None
        self._mesh_datasets = None
        self._rule_curve_covs = None
        self._wcd_covs = None
        self._water_mover_covs = None
        self._waterbody_start_id = 100_000

    @property
    def _query(self):
        """Returns the xmsapi interprocess communication object.

        Notes:
            If constructed in a component runner script, Query should be passed in on construction. If constructed in
            an export or import script, creation of the Query can happen here. Only one Query/XmsData object should be
            constructed per script.
        """
        if self._xms_query is None:
            self._xms_query = Query()
        return self._xms_query

    def copy_project_tree(self):
        """Returns a copy of the project explorer tree."""
        return self._query.copy_project_tree()

    @property
    def display_wkt(self):
        """Returns the display projection well known text, or '' if no Query has been created."""
        if self._xms_query:
            return self._xms_query.display_projection.well_known_text
        return ''

    def _get_sim_data(self):
        """Get the simulation data from the XMS project tree.

        Populates the simulation item, UUID, name, tree item and (if not supplied on construction) the simulation
        component. If XMS returns no simulation item, these stay None and are looked up again on the next call.
        """
        if self._sim_item is None:
            if self._sim_comp:
                self._sim_item = self._query.parent_item()
            else:
                self._sim_item = self._query.current_item()
            if self._sim_item:
                self._sim_uuid = self._sim_item.uuid
                self._sim_name = self._sim_item.name
                if self._sim_comp is None:
                    do_comp = self._query.item_with_uuid(self._sim_uuid, model_name='RSM', unique_name='SimComponent')
                    from xms.rsm.components.sim_component import SimComponent  # avoid circular import
                    self._sim_comp = SimComponent(do_comp.main_file)
                self._sim_tree_item = self.get_tree_item(self._sim_uuid)

    @property
    def sim_component(self):
        """Returns the simulation component."""
        if self._sim_comp is None:
            self._get_sim_data()
        return self._sim_comp

    @property
    def sim_tree_item(self):
        """Returns the simulation's project explorer tree item."""
        if self._sim_tree_item is None:
            self._get_sim_data()
        return self._sim_tree_item

    @property
    def sim_uuid(self):
        """Returns the simulation's UUID (not its component's UUID)."""
        if self._sim_uuid is None:
            self._get_sim_data()
        return self._sim_uuid

    @property
    def sim_name(self):
        """Returns the name of the simulation."""
        if self._sim_name is None:
            self._get_sim_data()
        return self._sim_name

    @property
    def do_ugrid(self):
        """Retrieves the simulation's data_objects UGrid, or None if the simulation has no linked UGrid."""
        if self._do_ugrid is None:
            ugrid_item = self.ugrid_item
            if ugrid_item is None:
                # Nothing is linked; don't issue a query for an empty UUID.
                return None
            self._logger.info('Retrieving UGrid from SMS...')
            self._do_ugrid = self._query.item_with_uuid(ugrid_item.uuid)
            if self._do_ugrid is not None:
                self._cogrid = read_grid_from_file(self._do_ugrid.cogrid_file)
        return self._do_ugrid

    @property
    def xmugrid(self):
        """Retrieves the simulation's XmUGrid (the CoGrid's ugrid)."""
        if self._xmugrid is None and self.cogrid:
            self._xmugrid = self.cogrid.ugrid
        return self._xmugrid

    @property
    def cogrid(self):
        """Retrieves the simulation's domain CoGrid."""
        if self._cogrid is None:
            _ = self.do_ugrid  # reading do_ugrid loads the CoGrid as a side effect
        return self._cogrid

    @property
    def ugrid_item(self):
        """Returns the linked UGrid tree item, or None if there is none."""
        if self._ugrid_item is None and self.sim_tree_item is not None:
            ugrid_ptr = tree_util.descendants_of_type(
                tree_root=self.sim_tree_item, only_first=True, xms_types=['TI_UGRID_PTR'], allow_pointers=True
            )
            if ugrid_ptr is not None:
                self._ugrid_item = self.get_tree_item(ugrid_ptr.uuid)
        return self._ugrid_item

    @property
    def ugrid_extractor(self):
        """Returns the UGrid extractor for the simulation, or None if there is no grid."""
        if self._extractor is not None:
            return self._extractor
        if self.xmugrid is None:
            return None
        self._extractor = _extractor_from_ugrid(self.xmugrid)
        return self._extractor

    @property
    def waterbody_start_id(self):
        """Returns the starting id for water bodies.

        With a grid this is the smallest multiple of 100,000 strictly greater than the grid's cell count, so water
        body ids cannot collide with cell ids. Without a grid, the default 100,000 is returned.
        """
        if self.xmugrid is not None:
            cnt = (self.xmugrid.cell_count // 100_000) + 1
            self._waterbody_start_id = 100_000 * cnt
        return self._waterbody_start_id

    @property
    def basin_coverages(self):
        """Returns the basin coverages as a list of (coverage, component) tuples."""
        if self._basin_covs is None:
            self._retrieve_basin_coverages()
        return self._basin_covs

    @property
    def bc_coverages(self):
        """Returns the boundary condition coverages as a list of (coverage, component) tuples."""
        if self._bc_covs is None:
            self._retrieve_bc_coverages()
        return self._bc_covs

    @property
    def canal_coverage_and_component(self):
        """Returns the (coverage, component) tuple for the canal coverage; (None, None) if there is none."""
        if self._canal_cov is None and not self._canal_retrieved:
            self._retrieve_canal_data()
        return self._canal_cov, self._canal_comp

    @property
    def cell_monitor_coverages(self):
        """Returns the cell monitor coverages as a list of (coverage, component) tuples."""
        if self._cell_monitor_covs is None:
            self._retrieve_cell_monitor_coverages()
        return self._cell_monitor_covs

    @property
    def impoundment_coverages(self):
        """Returns the impoundment coverages as a list of (coverage, component) tuples."""
        if self._impoundment_covs is None:
            self._retrieve_impoundment_coverages()
        return self._impoundment_covs

    @property
    def lake_coverages(self):
        """Returns the lake coverages as a list of (coverage, component) tuples."""
        if self._lake_covs is None:
            self._retrieve_lake_coverages()
        return self._lake_covs

    @property
    def mesh_dataset_coverages(self):
        """Returns the mesh dataset coverages as a list of (coverage, component) tuples."""
        if self._mesh_datasets is None:
            self._retrieve_mesh_datasets_coverages()
        return self._mesh_datasets

    @property
    def rule_curve_coverages(self):
        """Returns the rule curve coverages as a list of (coverage, component) tuples."""
        if self._rule_curve_covs is None:
            self._retrieve_rule_curve_coverages()
        return self._rule_curve_covs

    @property
    def wcd_coverages(self):
        """Returns the water control district coverages as a list of (coverage, component) tuples."""
        if self._wcd_covs is None:
            self._retrieve_wcd_coverages()
        return self._wcd_covs

    @property
    def water_mover_coverages(self):
        """Returns the water mover coverages as a list of (coverage, component) tuples."""
        if self._water_mover_covs is None:
            self._retrieve_water_mover_coverages()
        return self._water_mover_covs

    def get_tree_item(self, item_uuid):
        """Get a project explorer item from UUID.

        Args:
            item_uuid (str): UUID of the item to retrieve

        Returns:
            (TreeNode): The tree item with specified UUID or None if not found
        """
        return tree_util.find_tree_node_by_uuid(self._query.project_tree, item_uuid)

    def remove_existing_snapped_components(self):
        """Delete every TI_COMPONENT item found under the simulation's tree item."""
        items = tree_util.descendants_of_type(self.sim_tree_item, xms_types=['TI_COMPONENT'])
        for item in items:
            self._query.delete_item(item.uuid)

    def add_component_to_xms(self, do_component, actions):
        """Add a component to XMS and link it to the simulation.

        Args:
            do_component (xms.data_objects.parameters.Component): The component to add
            actions (list): The actions to perform
        """
        self._query.add_component(do_component=do_component, actions=actions)
        self._query.link_item(taker_uuid=self.sim_uuid, taken_uuid=do_component.uuid)

    def dset_values_from_uuid_ts(self, uuid_ts):
        """Get the values of one dataset time step.

        Args:
            uuid_ts (str): String encoding the dataset uuid and time step index (as produced by the dataset selector)

        Returns:
            (list): dataset values for the time step, or [] if the dataset is not found
        """
        uuid, ts_idx = ti_sel_ds.uuid_and_time_step_index_from_string(uuid_ts)
        dset = self._query.item_with_uuid(uuid)
        vals = []
        if dset is not None:
            vals = dset.values[ts_idx]
        return vals

    def dataset_is_child_of_simulation_ugrid(self, dataset_uuid):
        """Check if the dataset is a child of the simulation's UGrid.

        Args:
            dataset_uuid (str): The dataset uuid to check

        Returns:
            (bool): True if the dataset is a child of the simulation's UGrid, False otherwise (including when the
            dataset is not found, is not under a UGrid, or the simulation has no linked UGrid)
        """
        ds_item = self.get_tree_item(dataset_uuid)
        sim_ugrid = self.ugrid_item
        if ds_item is None or sim_ugrid is None:
            return False
        ds_ugrid = tree_util.ancestor_of_type(ds_item, xms_types=['TI_UGRID_SMS'])
        # Compare by UUID so the check does not depend on tree node object identity.
        return ds_ugrid is not None and ds_ugrid.uuid == sim_ugrid.uuid

    # Backward-compatible alias for the original (misspelled) method name.
    dataset_is_child_of_simluation_ugrid = dataset_is_child_of_simulation_ugrid

    def _retrieve_coverage_list(self, message, cov_type, unique_name):
        """Log a progress message and query the simulation's coverages of one type.

        Args:
            message (str): Progress message to log
            cov_type (str): The coverage type to retrieve
            unique_name (str): The unique name of the coverage's component

        Returns:
            (list): (coverage, component) tuples
        """
        self._logger.info(message)
        return _CoverageGetter(self, cov_type, unique_name).cov_list()

    def _retrieve_basin_coverages(self):
        """Query for the simulation's basin coverages and components."""
        self._basin_covs = self._retrieve_coverage_list(
            'Retrieving basin coverages from SMS...', 'Basins', 'BasinComponent'
        )

    def _retrieve_bc_coverages(self):
        """Query for the simulation's bc coverages and components."""
        self._bc_covs = self._retrieve_coverage_list(
            'Retrieving BC coverages from SMS...', 'Boundary Conditions', 'BcComponent'
        )

    def _retrieve_canal_data(self):
        """Query for the simulation's Canal coverage and component (only the first canal coverage is kept)."""
        self._logger.info('Retrieving Canal coverage from SMS...')
        getter = _CoverageGetter(self, 'Canals', 'CanalComponent')
        self._canal_cov, self._canal_comp = getter.cov_and_comp()
        self._canal_retrieved = True

    def _retrieve_cell_monitor_coverages(self):
        """Query for the simulation's cell monitor coverages and components."""
        self._cell_monitor_covs = self._retrieve_coverage_list(
            'Retrieving Cell Segment Monitor coverage from SMS...', 'Cell Monitor', 'CellMonitorComponent'
        )

    def _retrieve_impoundment_coverages(self):
        """Query for the simulation's impoundment coverages and components."""
        self._impoundment_covs = self._retrieve_coverage_list(
            'Retrieving impoundment coverages from SMS...', 'Impoundments', 'ImpoundmentComponent'
        )

    def _retrieve_lake_coverages(self):
        """Query for the simulation's lake coverages and components."""
        self._lake_covs = self._retrieve_coverage_list(
            'Retrieving lake coverages from SMS...', 'Lakes', 'LakeComponent'
        )

    def _retrieve_mesh_datasets_coverages(self):
        """Query for the simulation's mesh_dataset coverages and components."""
        self._mesh_datasets = self._retrieve_coverage_list(
            'Retrieving Mesh Dataset coverages from SMS...', 'Mesh datasets', 'MeshDataComponent'
        )

    def _retrieve_rule_curve_coverages(self):
        """Query for the simulation's rule curve coverages and components."""
        self._rule_curve_covs = self._retrieve_coverage_list(
            'Retrieving rule curve coverages from SMS...', 'Rule Curves', 'RuleCurveComponent'
        )

    def _retrieve_wcd_coverages(self):
        """Query for the simulation's wcd coverages and components."""
        self._wcd_covs = self._retrieve_coverage_list(
            'Retrieving WCD coverages from SMS...', 'Water Control Districts', 'WcdComponent'
        )

    def _retrieve_water_mover_coverages(self):
        """Query for the simulation's water mover coverages and components."""
        self._water_mover_covs = self._retrieve_coverage_list(
            'Retrieving water mover coverages from SMS...', 'Water Mover', 'WaterMoverComponent'
        )


def _extractor_from_ugrid(ugrid):
    """Build a 2D data extractor over the given UGrid.

    Each grid point is given the scalar value 1.0 and the companion flag 1, and the values are registered with the
    'points' location.

    Args:
        ugrid (UGrid): The UGrid the extractor is built on

    Returns:
        (UGrid2dDataExtractor): The configured extractor
    """
    from xms.extractor.ugrid_2d_data_extractor import UGrid2dDataExtractor
    extractor = UGrid2dDataExtractor(ugrid)
    num_points = ugrid.point_count
    point_values = [1.0 for _ in range(num_points)]
    # NOTE(review): the second argument is presumably a per-point activity flag list — confirm against the extractor API.
    point_flags = [1 for _ in range(num_points)]
    extractor.set_grid_point_scalars(point_values, point_flags, 'points')
    return extractor
