"""Module for a coverage snapper."""

__copyright__ = "(C) Copyright Aquaveo 2024"
__license__ = "All rights reserved"
__all__ = ['SnapMode', 'Snapper']

# 1. Standard Python modules
from enum import auto, Enum
from functools import cached_property
import logging
from logging import Logger
from typing import cast, Iterable, Mapping, Optional

# 2. Third party modules

# 3. Aquaveo modules
from xms.constraint import UGrid2d
from xms.constraint.grid import UGrid
from xms.data_objects.parameters import Arc, Coverage, FilterLocation, Point
from xms.guipy.dialogs.log_timer import Timer
from xms.snap import SnapExteriorArc, SnapInteriorArc, SnapPoint

# 4. Local modules
from xms.hydroas.file_io.errors import Messages as Msg, write_error


class SnapMode(Enum):
    """
    How an arc should be snapped to a UGrid.

    Members:
        INTERIOR: The arc is allowed to snap to the interior of the UGrid as well as to its exterior.
        EXTERIOR: The arc is only allowed to snap to the exterior of the UGrid.
        ENDPOINTS: Only the arc's first and last points are snapped, each to the nearest point in the UGrid; all of the
            arc's other points are discarded.
    """
    # Values are generated by auto(), so they follow declaration order (1, 2, 3). Don't reorder the members if anything
    # persists or compares the raw values.
    INTERIOR = auto()
    EXTERIOR = auto()
    ENDPOINTS = auto()


class Snapper:
    """
    Class for snapping coverages to UGrids.

    The point, interior-arc, and exterior-arc snappers are built lazily (via `cached_property`) the first time each one
    is needed and are then reused, so one instance can snap several coverages against the same geometry.
    """
    def __init__(self, geometry: UGrid2d, logger: Optional[Logger] = None):
        """
        Initialize the snapper.

        Args:
            geometry: The geometry to snap to.
            logger: Where to log progress messages to. If not provided, the `xms.hydroas.snapper` logger is used.
        """
        self._geometry = geometry  # A UGrid2d is a UGrid in C++, but that's not exposed in Python.
        self._log = logger or logging.getLogger('xms.hydroas.snapper')
        # Throttles progress messages in `arcs`; it is restarted whenever a snapper is constructed.
        self._timer = Timer()

    def points(self, coverage: Coverage) -> Iterable[tuple[int, int]]:
        """
        Snap the disjoint points in a coverage.

        Only the points returned by `coverage.get_points(FilterLocation.LOC_DISJOINT)` are snapped.

        Args:
            coverage: Coverage containing the points to snap.

        Returns:
            Iterable of points. Each point is a tuple of `(feature_id, node_index)`, where `feature_id` is the `id`
            attribute on the point that was snapped, and `node_index` is an index into the UGrid's point list.
        """
        points = coverage.get_points(FilterLocation.LOC_DISJOINT)
        return self._get_snapped_points(points)

    def _get_snapped_points(self, points: list[Point]) -> list[tuple[int, int]]:
        """
        Snap some points to the mesh.

        Args:
            points: Points to snap.

        Returns:
            List of tuples of (id, index) where id is the id attribute of a point in the input, and index is the
            index of the node in the mesh that it snapped to. Returned list is in the same order as the input.
            An empty input yields an empty list without touching the point snapper.
        """
        if not points:
            # Bail out before calling into native code. Some native xms routines (e.g. InterpIdw.interpolate_to_points)
            # kill the interpreter on empty lists; presumably SnapPoint needs the same guard — confirm.
            return []

        ids = [point.id for point in points]
        node_indexes = self._point_snapper.get_snapped_points(points)['id']
        return list(zip(ids, node_indexes))

    def arcs(self, coverage: Coverage, snap_modes: Mapping[int, SnapMode]) -> Iterable[tuple[int, tuple[int, ...]]]:
        """
        Snap all the arcs in a coverage.

        It is possible that an arc will fail to snap, e.g. because it is too short, or because it jumps across disjoint
        parts of the domain and doesn't use an endpoint snap mode. An arc that snaps to fewer than two nodes is
        treated as a failure: the error built by `write_error(Msg.degenerate_arc, ...)` is raised, located at the
        arc's start node, and no further arcs are processed.

        Args:
            coverage: Coverage containing arcs to snap.
            snap_modes: Mapping describing how to snap each arc. Keys are arc IDs, and values are how to snap. Every
                arc in the coverage must have an entry.

        Returns:
            Iterable of snapped arcs. Each arc is a tuple of `(feature_id, node_string)`, where `feature_id` matches
            the `id` attribute of the arc that was snapped, and `node_string` is a tuple of indices of nodes in the
            UGrid that the arc snapped to.

        Raises:
            KeyError: If an arc's ID has no entry in `snap_modes`.
        """
        snapped_arcs: list[tuple[int, tuple[int, ...]]] = []
        # It's normal for one instance of this class to be reused for multiple coverages. When that happens, the second
        # coverage will skip anything that initialized the class previously, so the timer may be arbitrarily out of date
        # at that point.
        self._timer = Timer()

        for number, arc in enumerate(coverage.arcs, start=1):
            snapped_arc = self._snap_arc(arc, snap_modes[arc.id])
            # A node string needs at least two nodes; fewer means the arc collapsed or failed to snap.
            if len(snapped_arc) < 2:
                location = (arc.start_node.x, arc.start_node.y)
                raise write_error(Msg.degenerate_arc, location)

            snapped_arcs.append((arc.id, snapped_arc))
            # Report after snapping so the count reflects arcs that have actually been snapped.
            if self._timer.report_due:
                self._log.info(f'Snapped {number} arcs...')

        return snapped_arcs

    @cached_property
    def _point_snapper(self) -> SnapPoint:
        """Point snapper bound to the geometry's points; built on first use and cached."""
        self._log.info('Constructing point snapper...')
        snapper = SnapPoint()
        geometry = cast(UGrid, self._geometry)  # UGrid2d really is a UGrid, but it's only visible on the C++ side.
        snapper.set_grid(geometry, target_cells=False)
        self._log.info('Constructed.')
        # Construction may take a while; restart the progress timer so a report isn't immediately due afterwards.
        self._timer = Timer()
        return snapper

    @cached_property
    def _interior_arc_snapper(self) -> SnapInteriorArc:
        """Interior arc snapper bound to the geometry's points; built on first use and cached."""
        self._log.info('Constructing interior arc snapper...')
        snapper = SnapInteriorArc()
        geometry = cast(UGrid, self._geometry)  # UGrid2d really is a UGrid, but it's only visible on the C++ side.
        snapper.set_grid(geometry, target_cells=False)
        self._log.info('Constructed.')
        # Construction may take a while; restart the progress timer so a report isn't immediately due afterwards.
        self._timer = Timer()
        return snapper

    @cached_property
    def _exterior_arc_snapper(self) -> SnapExteriorArc:
        """Exterior arc snapper bound to the geometry's points; built on first use and cached."""
        self._log.info('Constructing exterior arc snapper...')
        snapper = SnapExteriorArc()
        geometry = cast(UGrid, self._geometry)  # UGrid2d really is a UGrid, but it's only visible on the C++ side.
        snapper.set_grid(geometry, target_cells=False)
        self._log.info('Constructed.')
        # Construction may take a while; restart the progress timer so a report isn't immediately due afterwards.
        self._timer = Timer()
        return snapper

    def _snap_arc(self, arc: Arc, snap_mode: SnapMode) -> tuple[int, ...]:
        """
        Snap an arc.

        Args:
            arc: The arc to snap.
            snap_mode: How to snap it.

        Returns:
            Indices of nodes in the UGrid, as a tuple.

        Raises:
            AssertionError: If `snap_mode` is not a recognized SnapMode.
        """
        if snap_mode == SnapMode.INTERIOR:
            # Normalize to a tuple so the result matches the annotated return type whatever container the native
            # snapper uses.
            return tuple(self._interior_arc_snapper.get_snapped_points(arc)['id'])
        elif snap_mode == SnapMode.EXTERIOR:
            return tuple(self._exterior_arc_snapper.get_snapped_points(arc)['id'])
        elif snap_mode == SnapMode.ENDPOINTS:
            # Only the two end nodes are snapped; the arc's other vertices are ignored.
            # NOTE(review): if both ends snap to the same node this yields a 2-element node string that `arcs` does not
            # flag as degenerate — confirm whether that case should be rejected.
            ids_and_indexes = self._get_snapped_points([arc.start_node, arc.end_node])
            return tuple(index for _id, index in ids_and_indexes)
        else:  # pragma: nocover
            raise AssertionError('Unknown SnapMode')
