"""Class for managing interprocess communication with XMS."""
# 1. Standard python modules
import logging
import os

# 2. Third party modules

# 3. Aquaveo modules
from xms.api.dmi import Query
from xms.api.tree import tree_util
from xms.constraint import read_grid_from_file
from xms.data_objects.parameters import Coverage

# 4. Local modules
from xms.tuflowfv.components.bc_component import BcComponent
from xms.tuflowfv.components.material_component import MaterialComponent
from xms.tuflowfv.components.structure_component import StructureComponent
from xms.tuflowfv.data import sim_data as smd


class XmsData:
    """Class for managing interprocess communication with XMS.

    Data is retrieved lazily from XMS the first time the corresponding property is accessed and cached on the
    instance afterwards, so repeated accesses do not repeat interprocess queries.
    """

    def __init__(self, query=None, at_sim=True):
        """Constructor.

        Args:
            query (Query): The XMS interprocess communication object. If None, one is created on first use.
            at_sim (bool): True if the Context is at the simulation level, False if it is directly below the simulation.
        """
        self._query = query
        self._logger = logging.getLogger('xms.tuflowfv')
        self._at_sim = at_sim
        self._sim_uuid = None
        self._sim_name = None
        self._sim_data = None
        self._do_ugrid = None
        self._xmugrid = None
        self._cogrid = None
        self._ugrid_item = None
        self._structure_cov = None
        self._structure_component = None
        # True once the structure coverage lookup has run, so a missing coverage is not re-queried on every access.
        self._structure_retrieved = False
        self._bc_covs = None  # [Coverage]
        self._bc_comps = None  # [python Component]
        self._mat_covs = None  # [Coverage]
        self._mat_comps = None  # [python Component]
        self._output_points_covs = {}  # {cov_uuid: Coverage}
        self._output_points_components = {}  # {cov_uuid: Component}
        self._zlines = None  # [(line_filename/Coverage, [point_filenames/Coverages])] - Can be shapefiles or Coverages
        self._wind_covs = None

    def _retrieve_bc_data(self):
        """Query for the simulation's BC coverages and components.

        Populates the parallel lists ``self._bc_covs`` and ``self._bc_comps`` in tree order.
        """
        self._logger.info('Retrieving Boundary Condition coverages from SMS...')
        self._bc_covs = []
        self._bc_comps = []
        sim_item = tree_util.find_tree_node_by_uuid(self.query.project_tree, self.sim_uuid)
        bc_items = tree_util.descendants_of_type(tree_root=sim_item, xms_types=['TI_COVER_PTR'],
                                                 allow_pointers=True, only_first=False,
                                                 coverage_type='Boundary Conditions', model_name='TUFLOWFV')
        for bc_item in bc_items:
            self._bc_covs.append(self.query.item_with_uuid(bc_item.uuid))
            do_comp = self.query.item_with_uuid(bc_item.uuid, unique_name='BcComponent', model_name='TUFLOWFV')
            py_comp = BcComponent(do_comp.main_file)
            # BC attributes can live on points, arcs, and polygons, so load component ids for all three.
            self.query.load_component_ids(py_comp, points=True, arcs=True, polygons=True)
            self._bc_comps.append(py_comp)

    def _retrieve_material_data(self):
        """Query for the simulation's Material coverages and components.

        Populates the parallel lists ``self._mat_covs`` and ``self._mat_comps`` in tree order.
        """
        self._logger.info('Retrieving Material coverages from SMS...')
        self._mat_covs = []
        self._mat_comps = []
        sim_item = tree_util.find_tree_node_by_uuid(self.query.project_tree, self.sim_uuid)
        mat_items = tree_util.descendants_of_type(tree_root=sim_item, xms_types=['TI_COVER_PTR'],
                                                  allow_pointers=True, only_first=False,
                                                  coverage_type='Materials', model_name='TUFLOWFV')
        for mat_item in mat_items:
            self._mat_covs.append(self.query.item_with_uuid(mat_item.uuid))
            do_comp = self.query.item_with_uuid(mat_item.uuid, unique_name='MaterialComponent',
                                                model_name='TUFLOWFV')
            py_comp = MaterialComponent(do_comp.main_file)
            # Materials are only assigned to polygons.
            self.query.load_component_ids(py_comp, polygons=True)
            self._mat_comps.append(py_comp)

    def _retrieve_zlines(self):
        """Retrieve all the Z line and point shapefiles and/or Coverages from XMS.

        Populates ``self._zlines`` with ``(line_z, [point_z, ...])`` tuples, one per Z line modification in the
        simulation data. Each item is whatever ``Query.item_with_uuid`` returns: a Coverage, or an existing shapefile
        path. Z lines whose line input cannot be retrieved are skipped with a warning; missing point inputs are
        dropped with a warning.
        """
        self._zlines = []
        # NOTE(review): z_modifications appears to be an xarray-like Dataset (where/sizes/item); confirm in sim_data.
        z_lines = self.sim_data.z_modifications.where(
            self.sim_data.z_modifications.type == smd.ELEV_TYPE_ZLINE, drop=True
        )
        index_name = next(iter(z_lines.sizes.keys()))
        indices = z_lines[index_name].data.tolist()
        point_variables = smd.get_point_variable_names()
        for index in indices:
            zline = z_lines.where(z_lines[index_name] == index, drop=True)
            line_z = self.query.item_with_uuid(zline.uuid.item())
            if not isinstance(line_z, Coverage) and (not line_z or not os.path.isfile(line_z)):
                self._logger.warning('Unable to retrieve selected Z line shapefile/coverage. Ensure that all '
                                     'referenced inputs still exist')
                continue
            point_zs = []
            for point_variable in point_variables:
                if zline[point_variable].item():  # Empty value means no point input was selected for this slot.
                    point_z = self.query.item_with_uuid(zline[point_variable].item())
                    if not isinstance(point_z, Coverage) and (not point_z or not os.path.isfile(point_z)):
                        self._logger.warning('Unable to retrieve selected Z point shapefile/coverage. Ensure that all '
                                             'referenced inputs still exist')
                    else:
                        point_zs.append(point_z)
            self._zlines.append((line_z, point_zs))

    def _retrieve_wind_covs(self):
        """Query for the simulation's linked Holland wind boundary coverages.

        Populates ``self._wind_covs`` with the generic coverage dumps of every linked 'WIND' coverage.
        """
        self._logger.info('Retrieving Holland wind boundary coverages from SMS...')
        self._wind_covs = []
        sim_item = tree_util.find_tree_node_by_uuid(self.query.project_tree, self.sim_uuid)
        wind_items = tree_util.descendants_of_type(tree_root=sim_item, xms_types=['TI_COVER_PTR'],
                                                   allow_pointers=True, only_first=False, coverage_type='WIND')
        for wind_item in wind_items:
            self._wind_covs.append(self.query.item_with_uuid(wind_item.uuid, generic_coverage=True))

    def _retrieve_structure_data(self):
        """Retrieve the structure coverage and component if it exists.

        Only the first linked Structures coverage is used. If none is linked, the cached coverage and component stay
        None, and the lookup is marked as done so it is not repeated.
        """
        sim_item = tree_util.find_tree_node_by_uuid(self.query.project_tree, self.sim_uuid)
        structure_item = tree_util.descendants_of_type(tree_root=sim_item, xms_types=['TI_COVER_PTR'],
                                                       allow_pointers=True, only_first=True, coverage_type='Structures',
                                                       model_name='TUFLOWFV')
        if structure_item is not None:
            self._structure_cov = self.query.item_with_uuid(structure_item.uuid)
            do_comp = self.query.item_with_uuid(structure_item.uuid, unique_name='StructureComponent',
                                                model_name='TUFLOWFV')
            py_comp = StructureComponent(do_comp.main_file)
            self.query.load_component_ids(py_comp, arcs=True, polygons=True)
            self._structure_component = py_comp
        self._structure_retrieved = True

    @property
    def query(self):
        """Returns the xmsapi interprocess communication object.

        Notes:
            If constructed in a component runner script, Query should be passed in on construction. If constructed in
            an export or import script, creation of the Query can happen here. Only one Query/XmsData object should be
            constructed per script.
        """
        if self._query is None:
            self._query = Query()
        return self._query

    def set_sim(self, sim_uuid, sim_name):
        """Set the simulation that data will be retrieved for.

        Args:
            sim_uuid (str): The UUID of the simulation
            sim_name (str): The name of the simulation
        """
        self._sim_uuid = sim_uuid
        self._sim_name = sim_name

    @property
    def sim_uuid(self):
        """Returns the simulation's UUID (not its component's UUID)."""
        if self._sim_uuid is None:
            self._sim_uuid = self.query.current_item_uuid() if self._at_sim else self.query.parent_item_uuid()
        return self._sim_uuid

    @property
    def sim_name(self):
        """Returns the name of the simulation, or 'Sim' if it cannot be determined."""
        if self._sim_name is None:
            sim_item = self.query.current_item() if self._at_sim else self.query.parent_item()
            if sim_item and sim_item.name:
                self._sim_name = sim_item.name
            else:
                self._sim_name = 'Sim'
        return self._sim_name

    @property
    def sim_data(self):
        """Returns the simulation component's SimData.

        Raises:
            RuntimeError: If the simulation component cannot be retrieved from SMS.
        """
        if self._sim_data is None:
            self._logger.info('Retrieving simulation data from SMS...')
            do_comp = self.query.item_with_uuid(self.sim_uuid, model_name='TUFLOWFV', unique_name='SimComponent')
            if not do_comp:
                raise RuntimeError('Unable to retrieve simulation data from SMS.')
            self._sim_data = smd.SimData(do_comp.main_file)
        return self._sim_data

    @property
    def do_ugrid(self):
        """Retrieves the simulation's data_objects UGrid (the linked domain mesh)."""
        if self._do_ugrid is None:
            self._logger.info('Retrieving domain mesh from SMS...')
            self._do_ugrid = self.query.item_with_uuid(self.sim_data.info.attrs['domain_uuid'])
        return self._do_ugrid

    @property
    def xmugrid(self):
        """Returns the XmUGrid of the simulation's domain CoGrid."""
        if self._xmugrid is None:  # Cache this because it is easy to accidentally recreate it in loops.
            self._xmugrid = self.cogrid.ugrid
        return self._xmugrid

    @property
    def cogrid(self):
        """Retrieves the simulation's domain CoGrid, read from the linked mesh's CoGrid file."""
        if self._cogrid is None:
            self._cogrid = read_grid_from_file(self.do_ugrid.cogrid_file)
        return self._cogrid

    @property
    def ugrid_item(self):
        """Returns the linked UGrid tree item.

        Raises:
            RuntimeError: If the linked mesh cannot be found in the project tree.
        """
        if self._ugrid_item is None:
            self._ugrid_item = tree_util.find_tree_node_by_uuid(self.query.project_tree,
                                                                self.sim_data.info.attrs['domain_uuid'])
            if not self._ugrid_item:
                raise RuntimeError('Unable to retrieve linked mesh geometry.')
        return self._ugrid_item

    @property
    def bc_covs(self):
        """Retrieves a list of the simulation's BC coverages in tree order (parallel with bc_comps)."""
        if self._bc_covs is None:
            self._retrieve_bc_data()
        return self._bc_covs

    @property
    def bc_comps(self):
        """Retrieves a list of the simulation's BC coverage components (parallel with bc_covs)."""
        if self._bc_comps is None:
            self._retrieve_bc_data()
        return self._bc_comps

    @property
    def mat_covs(self):
        """Retrieves a list of the simulation's Material coverages in tree order (parallel with mat_comps)."""
        if self._mat_covs is None:
            self._retrieve_material_data()
        return self._mat_covs

    @property
    def mat_comps(self):
        """Retrieves a list of the simulation's Material coverage components (parallel with mat_covs)."""
        if self._mat_comps is None:
            self._retrieve_material_data()
        return self._mat_comps

    @property
    def zlines(self):
        """Retrieve all the Z line and point inputs as (line, [points]) tuples (shapefile paths or Coverages)."""
        if self._zlines is None:
            self._retrieve_zlines()
        return self._zlines

    @property
    def wind_covs(self):
        """Retrieve WindCoverage dumps for all the linked Holland wind boundaries."""
        if self._wind_covs is None:
            self._retrieve_wind_covs()
        return self._wind_covs

    @property
    def structure_cov(self):
        """Retrieve the structures coverage (should only be one), or None if there is none."""
        if self._structure_cov is None and not self._structure_retrieved:
            self._retrieve_structure_data()
        return self._structure_cov

    @property
    def structure_comp(self):
        """Retrieves the structures coverage component, or None if there is no structures coverage."""
        if self._structure_component is None and not self._structure_retrieved:
            self._retrieve_structure_data()
        return self._structure_component

    def get_tree_item(self, item_uuid):
        """Get a project explorer item from UUID.

        Args:
            item_uuid (str): UUID of the item to retrieve

        Returns:
            TreeNode: The tree item with specified UUID or None if not found
        """
        return tree_util.find_tree_node_by_uuid(self.query.project_tree, item_uuid)

    def get_output_points_cov(self, cov_uuid):
        """Retrieves a points output block's coverage geometry (may be any coverage type).

        The result is cached per coverage UUID, so XMS is only queried the first time a UUID is requested.

        Args:
            cov_uuid (str): UUID of the feature coverage geometry

        Returns:
            data_objects.parameters.Coverage: The output points block's data_objects Coverage geometry, None if unable
                to retrieve
        """
        # Only query on a cache miss. dict.setdefault() would evaluate the query argument on every call.
        if cov_uuid not in self._output_points_covs:
            self._output_points_covs[cov_uuid] = self.query.item_with_uuid(cov_uuid)
        do_cov = self._output_points_covs[cov_uuid]
        if not do_cov:
            self._logger.error('Unable to retrieve referenced output points feature coverage geometry. Ensure all '
                               'specified inputs still exist in SMS.')
        return do_cov

    def get_output_points_component(self, cov_uuid):
        """Retrieves a points output block's coverage component.

        The result is cached per coverage UUID, so XMS is only queried the first time a UUID is requested.

        Args:
            cov_uuid (str): UUID of the feature coverage geometry

        Returns:
            data_objects.parameters.Component: The output points block's data_objects Component, None if unable to
                retrieve. Note that the output coverage does not need to be a TUFLOWFV Output Points type.
        """
        # Only query on a cache miss. dict.setdefault() would evaluate the query argument on every call.
        if cov_uuid not in self._output_points_components:
            self._output_points_components[cov_uuid] = self.query.item_with_uuid(
                item_uuid=cov_uuid, model_name='TUFLOWFV', unique_name='OutputPointsComponent'
            )
        return self._output_points_components[cov_uuid]
