"""Module for the `MapCoverageRunner` class."""

__copyright__ = "(C) Copyright Aquaveo 2023"
__license__ = "All rights reserved"
__all__ = ['MapCoverageRunner']

# 1. Standard Python modules
from collections import Counter, deque
from functools import cached_property
import itertools
from pathlib import Path
from typing import Any, cast, Optional

# 2. Third party modules

# 3. Aquaveo modules
from xms.api.dmi import Query
from xms.api.tree import tree_util, TreeNode
from xms.components.display.display_options_helper import DisplayOptionsHelper
from xms.constraint import read_grid_from_file, UGrid2d
from xms.constraint.ugrid_boundaries import UGridBoundaries
from xms.data_objects.parameters import Arc, Component, Coverage
from xms.gmi.components.gmi_component import UNINITIALIZED_COMP_ID
from xms.gmi.data_bases.coverage_base_data import CoverageBaseData
from xms.grid.geometry.geometry import point_in_polygon_2d
from xms.grid.ugrid import UGrid
from xms.guipy.data.target_type import TargetType
from xms.guipy.dialogs.feedback_thread import ExitError, ExpectedError, FeedbackThread
from xms.snap import SnapExteriorArc

# 4. Local modules
from xms.schism.components.coverage_component import CoverageComponent
from xms.schism.components.mapped_bc_component import MappedBcComponent
from xms.schism.data.mapped_bc_data import MappedBcData
from xms.schism.data.model import get_model
from xms.schism.external.crc import compute_crc


class MapCoverageRunner(FeedbackThread):
    """
    Worker thread for mapping a SCHISM boundary condition coverage.

    Mapping snaps the coverage's arcs to the simulation's domain mesh and validates them. It then fills any remaining
    parts of the mesh boundary with implicit land boundaries. The result is stored in a new `MappedBcComponent`,
    which is linked to the simulation after previously mapped components are unlinked.
    """
    def __init__(self, coverage_uuid: str, query: Query):
        """
        Initialize the mapper.

        Args:
            coverage_uuid: UUID of the coverage to map.
            query: Interprocess communication object.
        """
        super().__init__(query)

        self.display_text = {
            'title': 'SCHISM Map BC Coverage',
            'working_prompt': 'Mapping BC coverage, please wait...',
            'warning_prompt': 'Warning(s) encountered while mapping coverage. Review log output for more details.',
            'error_prompt': 'Error(s) encountered while mapping coverage. Review log output for more details.',
            'success_prompt': 'Successfully mapped coverage',
            'note':
                (
                    "Applied data generated by this operation is valid with the simulation's current domain mesh and "
                    "source Boundary Conditions coverage. Any further editing of the domain mesh or source coverage "
                    "may invalidate the applied boundary conditions. To ensure that the data exported by SMS is up to "
                    "date, reapply the source Boundary Conditions coverage after editing the mesh, coverage geometry, "
                    "or boundary condition arc attributes."
                ),
            'auto_load': 'Close this dialog automatically when generation is finished.'
        }

        self._coverage_uuid: str = coverage_uuid

        self._mapped_component: Optional[MappedBcComponent] = None
        self._mapped_component_directory: Optional[Path] = None
        self._exterior_snapper: Optional[SnapExteriorArc] = None
        # Snapshot of the project tree. It is taken before anything is unlinked and used for all tree lookups.
        self._project_tree: TreeNode = query.copy_project_tree()

        # NOTE(review): `_node` is not referenced anywhere in this class.
        self._node: list[int] = []
        # One entry per coverage arc.
        self._arc: list[Arc] = []
        # `_nodestring`, `_value` and `_is_island` are parallel: entry i of each describes the same boundary
        # segment. The first len(_arc) entries come from coverage arcs, and any later entries are implicit land
        # boundaries (with an empty value).
        self._nodestring: list[list[int]] = []
        self._value: list[str] = []
        self._is_island: list[bool] = []

    def _run(self):
        """
        Map the coverage.

        Each step consumes state produced by the steps before it, so the order below matters.
        """
        self._unlink_old_components()

        self._make_mapped_component()
        self._get_arcs()
        self._initialize_snapper()
        self._snap_arcs()
        self._check_arcs()
        self._check_for_touching_islands()

        self._insert_land_boundaries()
        self._identify_arcs()
        self._add_data()
        self._write_display_options()

        self._link_component()

        self._log.info('Done!')

    @cached_property
    def _sim_uuid(self) -> str:
        """The UUID of the simulation."""
        return self._query.parent_item_uuid()

    @cached_property
    def _sim_tree_node(self) -> TreeNode:
        """
        The tree node for the simulation.

        (not the component)

        The node is looked up in the project tree snapshot taken in `__init__`, rather than requesting another copy
        of the whole tree.
        """
        return tree_util.find_tree_node_by_uuid(self._project_tree, self._sim_uuid)

    @cached_property
    def _domain_and_hash(self) -> tuple[str, UGrid2d]:
        """
        The domain and hash.

        Returns:
            Tuple of (hash, domain). The hash is a CRC of the domain's cogrid file and is stored with the mapped data
            in `_add_data`.

        Raises:
            ExpectedError: If the simulation has no linked mesh, or the mesh has no points.
        """
        self._log.info('Reading domain...')
        sim_node = self._sim_tree_node
        domain_node = tree_util.descendants_of_type(
            sim_node, xms_types=['TI_MESH2D_PTR'], only_first=True, allow_pointers=True
        )
        if not domain_node:
            raise ExpectedError('Mapping a coverage requires a mapped mesh.')
        domain_uuid = domain_node.uuid
        domain = self._query.item_with_uuid(domain_uuid)
        # `domain` is rebound below, from the XMS data object to the parsed grid, so hash the file first.
        domain_hash = compute_crc(domain.cogrid_file)
        domain = read_grid_from_file(domain.cogrid_file)

        # NOTE(review): the message mentions elements, but only the point count is checked.
        if domain.ugrid.point_count == 0:
            raise ExpectedError('Mapping a coverage requires a mesh with elements.')

        return domain_hash, domain

    @cached_property
    def _domain(self) -> UGrid2d:
        """The domain."""
        _hash, domain = self._domain_and_hash
        return domain

    @cached_property
    def _hash(self) -> str:
        """The hash of the domain."""
        domain_hash, _domain = self._domain_and_hash
        return domain_hash

    @cached_property
    def _coverage(self) -> tuple[Coverage, CoverageBaseData]:
        """
        The coverage being mapped and its attribute data.

        Returns:
            Tuple of (coverage, data). The component IDs of the coverage's arcs have been loaded into the data.
        """
        self._log.info('Retrieving coverage...')
        coverage = self._query.item_with_uuid(self._coverage_uuid)
        do_comp = self._query.item_with_uuid(self._coverage_uuid, model_name='SCHISM', unique_name='CoverageComponent')
        component = CoverageComponent(do_comp.main_file)
        self._query.load_component_ids(component, arcs=True)
        return coverage, component.data

    @cached_property
    def _exterior_boundary_nodes(self) -> set[int]:
        """
        The nodes that are on the exterior boundary.

        Nodes on interior boundaries (islands, or holes) are not included.
        """
        loops = self._boundary_loops
        boundary_nodes: set[int] = set()

        inside = 1

        # UGridBoundaries finds us all the boundary loops, including those that are islands. SCHISM cares about the
        # difference between the exterior land boundary and the interior island ones, so we need to separate them.
        # Additionally, there may be multiple exterior land boundaries, as in the case of a disjoint mesh, which could
        # happen if you're building some land feature across a lake or something.
        #
        # Our algorithm is to pick a representative point from each polygon and check if it is in any other polygon.
        # Then we assume all the other points are the same. Any polygon that is inside another is considered an island.
        #
        # That last point means that a pond on an island is treated as an island, which might be incorrect, but that's
        # probably unusual enough that nobody will care.
        for left_loop in loops:
            test_point = left_loop['location'][0]
            # The test point is a vertex of its own loop. Skip that comparison rather than relying on how a point
            # lying exactly on the polygon is classified.
            is_inside = any(
                point_in_polygon_2d(right_loop['location'], test_point) == inside
                for right_loop in loops
                if right_loop is not left_loop
            )
            if not is_inside:
                boundary_nodes.update(left_loop['id'])

        return boundary_nodes

    @cached_property
    def _boundary_loops(self) -> list[dict[str, Any]]:
        """
        List of boundary loops.

        Each loop is a dict where loop['id'] is a list of IDs defining the loop, and loop['location'] is a list of
        (x, y, z) tuples defining the loop.
        """
        boundary_finder = UGridBoundaries(self._domain)
        finder_loops = boundary_finder.get_loops()
        loops = [finder_loops[key] for key in finder_loops]
        for loop in loops:
            # SCHISM has a slight preference for CCW order. These are numpy arrays and tuples, so no .reverse() method.
            loop['location'] = loop['location'][::-1]
            loop['id'] = loop['id'][::-1]
        return loops

    @cached_property
    def _boundary_id_loops(self) -> list[list[int]]:
        """List of boundary loops, by IDs."""
        return [loop['id'] for loop in self._boundary_loops]

    @cached_property
    def _all_boundary_nodes(self) -> set[int]:
        """
        The nodes that are on any boundary.

        Unlike self._exterior_boundary_nodes, this includes nodes on island/hole boundaries.
        """
        all_loops = (loop['id'] for loop in self._boundary_loops)
        return set(itertools.chain.from_iterable(all_loops))

    def _make_mapped_component(self):
        """Create an empty mapped component and remember the directory it lives in."""
        self._log.info('Creating mapped coverage...')
        data = MappedBcData()
        # Presumably flushes the new data file so the component below can open it. TODO: confirm.
        data.close()
        self._mapped_component = MappedBcComponent(data.main_file)

        self._mapped_component_directory = Path(self._mapped_component.main_file).parent

    def _get_arcs(self):
        """Get arcs and their values from the coverage being mapped."""
        self._log.info('Retrieving arcs...')
        coverage, data = self._coverage
        arcs = coverage.arcs
        arc_section = get_model().arc_parameters
        arc_section.group('open').is_active = True
        default_values = arc_section.extract_values()

        for arc in arcs:
            component_id = data.component_id_map[TargetType.arc].get(arc.id, UNINITIALIZED_COMP_ID)
            # If the user creates an arc and doesn't assign it, then its values will come back empty. We want all arcs
            # to be implicitly open, so we'll override unassigned arcs here.
            value = data.feature_values(TargetType.arc, component_id) or default_values
            self._arc.append(arc)
            self._value.append(value)

    def _initialize_snapper(self):
        """Initialize the snapper."""
        self._log.info('Initializing exterior arc snapper...')
        self._exterior_snapper = SnapExteriorArc()
        self._exterior_snapper.set_grid(cast(UGrid, self._domain), False)

    def _snap_arcs(self):
        """Snap each coverage arc to the domain and record the resulting node IDs in self._nodestring."""
        self._log.info('Snapping arcs...')
        for arc in self._arc:
            node_ids = self._exterior_snapper.get_snapped_points(arc)['id']
            if len(node_ids) > 1 and node_ids[0] == node_ids[-1]:
                # Arcs that make a full loop around a hole can end up with repeated points at the nodes, but SCHISM
                # doesn't like when points are written twice, so we'll discard the last one.
                node_ids = node_ids[:-1]  # Tuples have no .pop().
            self._nodestring.append(node_ids)

    def _check_arcs(self):
        """
        Check that the arcs in the coverage are all valid.

        Arcs are invalid if they are degenerate (fewer than two nodes after snapping), or snap to the same node as
        another arc. Every problem is logged before giving up.

        Raises:
            ExitError: If any arc is invalid.
        """
        valid = True
        boundary_nodes = {node: 0 for node in self._all_boundary_nodes}
        for arc, nodestring in zip(self._arc, self._nodestring):
            if len(nodestring) < 2:
                self._log.error(f'Arc {arc.id} was degenerate after snapping.')
                valid = False
                continue

            other_arc, node = _intersects_other_arc(arc.id, nodestring, boundary_nodes)
            if other_arc != -1:
                # Node IDs are 0-based internally; report them 1-based like the rest of the user-facing output.
                self._log.error(f'Arcs {arc.id} and {other_arc} both snap to node {node + 1}.')
                valid = False

        if not valid:
            raise ExitError()

    def _check_for_touching_islands(self):
        """Log a warning for every node that appears more than once across the boundary loops."""
        counter = Counter(itertools.chain.from_iterable(self._boundary_id_loops))
        for node, count in counter.items():
            if count > 1:
                self._log.warning(f'Multiple islands meet at node {node + 1}. Conveyance will be blocked.')

    def _insert_land_boundaries(self):
        """Make any parts of the mesh boundary that have no arc be land boundaries."""
        self._log.info('Inserting closed boundary loops')

        # At this point self._nodestring only holds the snapped coverage arcs.
        open_nodes = set(itertools.chain.from_iterable(self._nodestring))
        loops = self._boundary_loops

        for loop in loops:
            self._insert_land_boundaries_for_loop(open_nodes, loop['id'])

    def _insert_land_boundaries_for_loop(self, open_nodes: set[int], loop: list[int]):
        """
        Insert the implicit land boundaries for a single loop.

        Each maximal run of consecutive loop nodes that are not in `open_nodes` is appended to self._nodestring as
        a closed nodestring with an empty value. Open nodes themselves are not repeated in the closed nodestrings.

        Args:
            open_nodes: Set of all open nodes.
            loop: Node IDs describing the loop to insert boundaries for.
        """
        assert loop[0] != loop[-1]

        if all(node in open_nodes for node in loop):  # The entire loop is used. Nothing to insert.
            return

        if all(node not in open_nodes for node in loop):  # The entire loop is unused. Insert everything.
            self._nodestring.append(loop)
            self._value.append('')
            return

        # From here on, the loop is partially used.
        loop = deque(loop)

        # Rotate until the first and last nodes differ in openness, so no closed run wraps around the loop's end.
        while (loop[0] in open_nodes) == (loop[-1] in open_nodes):
            loop.append(loop.popleft())

        arcs = []
        arc = []
        while loop:
            if arc and loop[0] in open_nodes:  # Switching from closed to open
                arcs.append(arc)
                arc = []
                loop.popleft()
            elif loop[0] in open_nodes:  # Continuing an open one
                loop.popleft()
            else:
                arc.append(loop.popleft())
        if arc:
            arcs.append(arc)

        self._nodestring.extend(arcs)
        self._value.extend([''] * len(arcs))

    def _identify_arcs(self):
        """Initialize self._is_island with flags for each land boundary."""
        self._log.info('Identifying arcs...')
        boundary_nodes = self._exterior_boundary_nodes
        # A nodestring counts as an island if its first node is not on an exterior boundary loop.
        self._is_island.extend(nodestring[0] not in boundary_nodes for nodestring in self._nodestring)

    def _add_data(self):
        """Put the arc data in self._mapped_component."""
        self._log.info('Initializing mapped coverage...')

        data: MappedBcData = self._mapped_component.data

        for nodestring, value, is_island in zip(self._nodestring, self._value, self._is_island):
            nodestring = [i + 1 for i in nodestring]  # convert 0-based to 1-based
            if value:
                data.add_open_arc(nodestring, value)
            else:
                data.add_closed_arc(nodestring, is_island)

        data.domain_hash = self._hash
        data.commit()

    def _write_display_options(self):
        """
        Write the component's new display options.

        Only nodestrings with a value (the coverage arcs) are drawn; implicit land boundaries are skipped.
        """
        locations = self._domain.ugrid.locations
        lines = [
            [locations[index] for index in nodestring]
            for nodestring, value in zip(self._nodestring, self._value)
            if value
        ]
        _coverage, data = self._coverage
        with DisplayOptionsHelper(data.main_file) as helper:
            helper.apply_to_drawing(self._mapped_component.main_file)
            helper.draw_lines('Open', lines)

    def _unlink_old_components(self):
        """
        Unlink any existing components that need to be removed.

        The source coverage and any mapped BC or tidal components under this simulation are unlinked from it. Only
        this simulation's subtree is searched, so other simulations' mapped components are left alone.
        """
        self._query.unlink_item(self._sim_uuid, self._coverage_uuid)

        sim_node = self._sim_tree_node

        existing_mapped_coverages = tree_util.descendants_of_type(sim_node, unique_name='MappedBcComponent')
        for coverage in existing_mapped_coverages:
            self._query.unlink_item(self._sim_uuid, coverage.uuid)

        existing_mapped_tides = tree_util.descendants_of_type(sim_node, unique_name='MappedTidalComponent')
        if existing_mapped_tides:
            self._log.warning(
                'Existing mapped tides were removed. '
                'Mapped tides are only valid for the domain and coverage they were mapped to.'
            )
        for tide_sim in existing_mapped_tides:
            self._query.unlink_item(self._sim_uuid, tide_sim.uuid)

    def _link_component(self):
        """Send the new mapped component to XMS and link it to the simulation."""
        self._log.info('Sending data...')
        coverage, _component = self._coverage
        name = f'{coverage.name} (applied)'
        do_component = Component(
            main_file=self._mapped_component.main_file,
            name=name,
            model_name='SCHISM',
            unique_name='MappedBcComponent',
            comp_uuid=self._mapped_component.uuid
        )

        self._query.add_component(
            do_component=do_component, actions=[self._mapped_component.get_display_options_action()]
        )
        self._query.link_item(taker_uuid=self._sim_uuid, taken_uuid=self._mapped_component.uuid)


def _intersects_other_arc(arc_id: int, nodestring: list[int], boundary_nodes: dict[int, int]) -> tuple[int, int]:
    """
    Check if one arc snaps to the same node as any other arc.

    Args:
        arc_id: ID of the arc being checked. Used to update `boundary_nodes`.
        nodestring: List of nodes comprising the arc to check.
        boundary_nodes: Dictionary of `(node_id -> arc_id)`. Should be initialized with all node IDs of interest mapping
           to 0s. Will be updated so that nodes on the provided arc now map to the provided arc.

    Returns:
        Tuple of `(arc_id, node_id)`, where `arc_id` is the ID of another arc that has a node in common with the input,
        and `node_id` is one of the nodes they have in common. If no such arc exists, then `(-1, -1)` is returned.
    """
    for node in nodestring:
        if boundary_nodes[node] != 0:
            return boundary_nodes[node], node
        boundary_nodes[node] = arc_id
    return -1, -1
