"""Module for the MapUpwindSolverCoverageThread."""

__copyright__ = "(C) Copyright Aquaveo 2024"
__license__ = "All rights reserved"
__all__ = ['MapUpwindSolverCoverageThread']

# 1. Standard Python modules
from pathlib import Path
from typing import Optional

# 2. Third party modules
import numpy as np

# 3. Aquaveo modules
from xms.api.dmi import Query
from xms.api.tree import tree_util
from xms.components.display.display_options_helper import DisplayOptionsHelper as ComponentsDisplayOptionsHelper
from xms.constraint import read_grid_from_file
from xms.data_objects.parameters import Component, Coverage, Polygon, UGrid
from xms.gmi.components.utils import invert_comp_to_xms
from xms.guipy.data.target_type import TargetType
from xms.guipy.dialogs.feedback_thread import FeedbackThread

# 4. Local modules
from xms.schism.components.mapped_upwind_solver_coverage_component import MappedUpwindSolverCoverageComponent
from xms.schism.components.upwind_solver_coverage_component import UpwindSolverCoverageComponent
from xms.schism.data.mapped_upwind_solver_coverage_data import MappedUpwindSolverCoverageData
from xms.schism.data.sim_data import SimData
from xms.schism.external.crc import compute_crc
from xms.schism.external.polygon_mapper import get_cell_id_to_material_id_map
from xms.schism.feedback.display_options_helper import DisplayOptionsHelper, load_component_id_map


class MapUpwindSolverCoverageThread(FeedbackThread):
    """
    Feedback thread for mapping an upwind solver (TVD or WENO) coverage.

    The coverage's polygons are split by feature type, mapped onto the cells of the simulation's 2D mesh, and the
    result is stored in a new mapped component. That component is then linked to the simulation in place of the
    unmapped coverage and any previously mapped upwind solver components.
    """
    def __init__(self, query: Query, coverage_uuid: str):
        """
        Initialize the thread.

        Args:
            query: Interprocess communication object.
            coverage_uuid: UUID of the coverage being mapped.
        """
        super().__init__(query)
        self.display_text |= {
            'title': 'Mapping TVD or WENO coverage',
        }
        self._unmapped_coverage_uuid = coverage_uuid
        self._sim_data: Optional[SimData] = None
        self._sim_uuid: Optional[str] = None
        self._domain: Optional[UGrid] = None
        # CRC of the domain's cogrid file; stored on the mapped data as `domain_hash`.
        self._domain_hash: str = ''
        # Result of get_cell_id_to_material_id_map(); stored on the mapped data as `solver`.
        self._mapping: Optional[np.ndarray] = None
        self._mapped_component: Optional[MappedUpwindSolverCoverageComponent] = None
        self._mapped_component_directory: Optional[Path] = None
        # Index 0 holds 'upwind' polygons, index 1 holds all other polygons.
        self._unmapped_polygons: Optional[list[list[Polygon]]] = None
        self._unmapped_coverage: Optional[Coverage] = None
        self._unmapped_component: Optional[UpwindSolverCoverageComponent] = None
        self._polygon_id_to_flag: Optional[dict[int, int]] = None

    def _unlink_components(self):
        """Unlink the unmapped coverage and any existing mapped upwind solver components from the simulation."""
        self._query.unlink_item(self._sim_uuid, self._unmapped_coverage_uuid)

        project_tree = self._query.project_tree
        sim_node = tree_util.find_tree_node_by_uuid(project_tree, self._sim_uuid)
        existing_mapped_components = tree_util.descendants_of_type(sim_node, unique_name='MappedUpwindSolverCoverage')
        for component in existing_mapped_components:
            self._query.unlink_item(self._sim_uuid, component.uuid)

    def _get_sim_data(self):
        """Initialize `self._sim_data` from the main file of the current item."""
        self._log.info('Reading simulation...')
        main_file = self._query.current_item().main_file
        self._sim_data = SimData(main_file)

    def _get_grid(self):
        """
        Initialize `self._domain` and `self._domain_hash` from the 2D mesh under the simulation.

        Requires `self._sim_uuid` to have been set (done at the start of `_run`).
        """
        self._log.info('Reading domain...')
        project_tree = self._query.project_tree
        # Reuse the simulation UUID cached by _run, like _unlink_components does.
        sim_node = tree_util.find_tree_node_by_uuid(project_tree, self._sim_uuid)
        domain_node = tree_util.descendants_of_type(
            sim_node, xms_types=['TI_MESH2D_PTR'], only_first=True, allow_pointers=True
        )
        domain_uuid = domain_node.uuid
        domain = self._query.item_with_uuid(domain_uuid)
        self._domain = read_grid_from_file(domain.cogrid_file)
        self._domain_hash = compute_crc(domain.cogrid_file)

    def _get_coverage(self):
        """
        Get the existing coverage and its component.

        The component will have its component ID map loaded.
        """
        self._log.info('Retrieving coverage...')
        self._unmapped_coverage = self._query.item_with_uuid(self._unmapped_coverage_uuid)
        do_comp = self._query.item_with_uuid(
            self._unmapped_coverage_uuid, model_name='SCHISM', unique_name='UpwindSolverCoverageComponent'
        )
        self._unmapped_component = UpwindSolverCoverageComponent(do_comp.main_file)
        load_component_id_map(self._query, self._unmapped_component)

    def _get_polygons(self):
        """
        Initialize `self._unmapped_polygons`.

        Polygons whose feature type is 'upwind' go in list 0; every other polygon goes in list 1.
        """
        unmapped_polygons = self._unmapped_coverage.polygons
        # Unpacking asserts the component ID map was loaded for exactly one coverage.
        [uuid] = self._unmapped_component.comp_to_xms.keys()
        comp_to_xms = self._unmapped_component.comp_to_xms[uuid][TargetType.polygon]
        xms_to_comp = invert_comp_to_xms(comp_to_xms)
        self._unmapped_polygons = [[], []]

        for polygon in unmapped_polygons:
            comp_id = xms_to_comp[polygon.id]
            feature_type = self._unmapped_component.data.feature_type(TargetType.polygon, comp_id)
            if feature_type == 'upwind':
                self._unmapped_polygons[0].append(polygon)
            else:
                self._unmapped_polygons[1].append(polygon)

    def _build_mapping(self):
        """Initialize `self._mapping` with a mapping from UGrid cell index to feature index."""
        self._mapping = get_cell_id_to_material_id_map(self._unmapped_polygons, self._domain, default_material=1)

    def _make_mapped_component(self):
        """Create an empty mapped component and remember its directory."""
        self._log.info('Creating mapped coverage...')
        data = MappedUpwindSolverCoverageData()
        data.close()
        self._mapped_component = MappedUpwindSolverCoverageComponent(data.main_file)

        self._mapped_component_directory = Path(self._mapped_component.main_file).parent
        # The simulation's reference to this component is recorded in `mapped_upwind_solver_uuid` by
        # _link_component. It is intentionally not written to `mapped_coverage_uuid` here, since that would clobber
        # whatever that field refers to.

    def _add_data(self):
        """Store the cell-to-feature mapping and the domain hash in the mapped component's data."""
        self._log.info('Initializing mapped coverage...')
        self._mapped_component.data.solver = self._mapping
        self._mapped_component.data.domain_hash = self._domain_hash
        self._mapped_component.data.commit()

    def _write_display_options(self):
        """Apply the unmapped component's display options to the mapped component."""
        with ComponentsDisplayOptionsHelper(self._unmapped_component.main_file) as helper:
            helper.apply_to_drawing(self._mapped_component.main_file)

    def _write_locations(self):
        """Register the polygon feature types and draw the domain's cells, classified by the mapping."""
        helper = DisplayOptionsHelper(self._mapped_component.main_file)
        # Order matches the polygon lists built in _get_polygons: index 0 is upwind, index 1 is everything else.
        helper.add_feature_types(TargetType.polygon, ['Upwind', 'Higher order'])
        helper.draw_ugrid(self._domain.ugrid, self._mapping)

    def _link_component(self):
        """Add the new component to XMS, record it in the simulation data, and link it to the simulation."""
        self._log.info('Sending data...')
        name = f'{self._unmapped_coverage.name} (applied)'
        do_component = Component(
            main_file=self._mapped_component.main_file,
            name=name,
            module_name=self._mapped_component.module_name,
            class_name=self._mapped_component.class_name,
            comp_uuid=self._mapped_component.uuid
        )

        self._sim_data.mapped_upwind_solver_uuid = self._mapped_component.uuid
        self._sim_data.commit()

        self._query.add_component(
            do_component=do_component, actions=[self._mapped_component.get_display_options_action()]
        )
        self._query.link_item(taker_uuid=self._sim_uuid, taken_uuid=self._mapped_component.uuid)

    def _run(self):
        """Map the coverage."""
        # Cached first; _unlink_components, _get_grid and _link_component all rely on it.
        self._sim_uuid = self._query.parent_item_uuid()

        self._unlink_components()
        self._get_sim_data()
        self._get_grid()
        self._get_coverage()

        self._make_mapped_component()
        self._get_polygons()
        self._build_mapping()

        self._add_data()
        self._write_display_options()
        self._write_locations()

        self._link_component()

        self._log.info('Done!')
