"""Merges polygons based on a classification."""

__copyright__ = "(C) Copyright Aquaveo 2020"
__license__ = "All rights reserved"

# 1. Standard Python modules
import logging
from typing import Optional, Sequence, TypeAlias

# 2. Third party modules

# 3. Aquaveo modules
from xms.constraint import UGrid2d
from xms.data_objects.parameters import Coverage, Projection
from xms.grid.ugrid import UGrid

# 4. Local modules
from xms.coverage.grid.polygon_coverage_builder import PolygonCoverageBuilder
from xms.coverage.polygons.polygon_orienteer import poly_area_x2

# The index of a point in the UGrid.
PointIndex: TypeAlias = int
# A boundary of a polygon. Either the outer boundary or an inner hole. Points may be clockwise or counterclockwise.
# The builder behaves the same way regardless.
Ring: TypeAlias = Sequence[PointIndex]
# An entire polygon. The first ring is required and defines the outer boundary. The other rings, if present, define the
# inner holes of the polygon.
Polygon: TypeAlias = Sequence[Ring]
# One or more polygons covering all the cells that have the same dataset value. The builder only creates Regions for
# dataset values it finds, so this will normally have at least one element. There may be multiple polygons if the cells
# define disjoint regions.
Region: TypeAlias = Sequence[Polygon]
# A mapping from dataset value to region defined by that dataset value.
RegionMap: TypeAlias = dict[int, list[Region]]


class GridCellToPolygonCoverageBuilder:
    """Given a grid and integer cell dataset, computes polygonal boundaries of areas with the same dataset value.

    - output is a data_objects.parameters.Coverage
    - create_polygons_and_build_coverage: Builds the data_objects Coverage.
    - find_polygons: Just finds the polygons
    """
    def __init__(
        self,
        co_grid: UGrid2d,
        dataset_values: Sequence[int],
        projection: Projection,
        coverage_name: str,
        null_value: Optional[int] = None,
        logger: Optional[logging.Logger] = None
    ):
        """Initializes the class.

        Args:
            co_grid (xms.constraint.ugrid2d.UGrid2d): The grid.
            dataset_values: Cell dataset of integers. This list is parallel to the cell list in the UGrid
                and identifies the dataset value to associate with that cell.
            projection: The map projection.
            coverage_name: Name to be given to the new coverage.
            null_value (Optional[int]): A dataset value to exclude from polygon building. Any grid cells with this
                assigned value will not be merged into the polygons of the output coverage.
            logger (Optional[Logger]): The logger to use. If not provided will be the xms.coverage logger.
        """
        self._ugrid: UGrid = co_grid.ugrid
        self._dataset_values = dataset_values
        self._projection = projection
        self._coverage_name = coverage_name
        self._null_value = null_value
        self._logger = logger if logger is not None else logging.getLogger('xms.coverage')

        # Stuff for extracting polygons (list of point ids)
        self._unvisited_cells: set[int] = set()  # Set of unvisited cells (cell indexes)
        self._ugrid_pt_locs = self._ugrid.locations  # xyz locations of grid points

        # Stuff for building a coverage
        self.dataset_polygon_ids = {}  # Dict of dataset value -> list of Polygon IDs
        self._next_pt_id = 1
        self._next_arc_id = 1
        self._next_poly_id = 1
        self._pt_hash = {}  # Dict of point idx -> Point
        self._arc_hash = {}  # Dict of (node id1, node id2) -> Arc

    def create_polygons_and_build_coverage(self) -> Coverage:
        """Creates a Coverage with Polygons made from regions where cells have the same dataset value.

        Cells whose dataset value equals the null value (when one was given) do not contribute polygons.

        Returns:
            data_objects.parameters.Coverage: See description
        """
        self._logger.info('Creating polygons from grid cell assignments...')
        regions = self.find_polygons()
        if self._null_value is not None:
            # Drop the whole region of the excluded value; it is fine if that value never occurred.
            regions.pop(self._null_value, None)
        return self._build_coverage(regions)

    def find_polygons(self) -> RegionMap:
        """Finds boundaries of areas where cells have the same dataset value.

        Every group of edge-connected cells sharing a dataset value yields one polygon. The null value is NOT
        filtered out here; that is done by create_polygons_and_build_coverage.

        Returns:
            Mapping from dataset value to polygons defining the region covered by that value.
        """
        self._logger.info('Finding contiguous cells with same category assignments...')
        self._unvisited_cells = set(range(self._ugrid.cell_count))
        regions = {}  # Dict of dataset value -> list of polygons

        # Visit all cells in the grid
        while self._unvisited_cells:
            # Pick an arbitrary unvisited cell as the seed. It is put straight back because the flood fill below
            # treats "still in the unvisited set" as the signal that a cell has yet to be processed.
            start_cell = self._unvisited_cells.pop()
            self._unvisited_cells.add(start_cell)
            value = self._dataset_values[start_cell]

            # Aggregate all adjacent cells with same dataset value and add edges to pt_map
            pt_map = {}  # Dict of point index -> set of edge adjacent point indices
            stack = [start_cell]  # Stack of cell indexes used instead of recursing on adjacent cells
            while stack:
                cell = stack.pop()
                if cell not in self._unvisited_cells:
                    continue  # The same cell can be pushed more than once before it is processed
                self._add_needed_adjacent_edges(cell, value, stack, pt_map)
                self._mark_cell_as_visited(cell)

            # Get polygons from edges and save them
            if value not in regions:
                self._logger.info(f'Found cells with category assignment: {value}')
                regions[value] = []
            regions[value].append(self._polygon_from_edges(pt_map))

        return regions

    def _add_needed_adjacent_edges(self, cell, value, stack, pt_map):
        """Adds outer edges to pt_map and edges between adjacent cells with a different dataset value.

        Also may grow the stack if more adjacent cells with the same dataset value are found.

        Args:
            cell (int): Index of a cell.
            value (int): Current dataset value of region we are aggregating.
            stack (list of int): Stack of adjacent cells sharing the same dataset value.
            pt_map ({pt_idx: {adjacent_point_indices}}): Mapping of point indices to set of adjacent point indices
        """
        for edge_idx, edge in enumerate(self._ugrid.get_cell_edges(cell)):
            adj_cell = self._ugrid.get_cell_2d_edge_adjacent_cell(cell, edge_idx)
            if adj_cell >= 0 and self._dataset_values[adj_cell] == value:
                # Interior edge of the region: not part of the boundary. Queue the neighbor if still unvisited.
                if adj_cell in self._unvisited_cells:
                    stack.append(adj_cell)
                continue

            # Boundary edge: the adjacent cell doesn't exist or its value doesn't match. Record both directions.
            node0 = edge[0]
            node1 = edge[1]
            pt_map.setdefault(node0, set()).add(node1)
            pt_map.setdefault(node1, set()).add(node0)

    def _mark_cell_as_visited(self, cell):
        """Marks the cell as visited by removing it from the set of unvisited cells.

        Args:
            cell (int): Index of a cell.
        """
        self._unvisited_cells.discard(cell)

    def _polygon_from_edges(self, pt_map) -> Polygon:
        """Returns the polygon defined by the pt_map, with the outer ring first.

        Args:
            pt_map (Dict of point index -> set of edge adjacent point indices): Boundary edges of one group of
                cells. It is emptied as the rings are extracted.

        Returns:
            polygon
        """
        polygon = self._polygon_from_point_map(pt_map)
        self._sort_outer_ring_first(polygon)
        return polygon

    def _sort_outer_ring_first(self, polygon: Polygon):
        """Finds the outer ring and puts it first in the polygon.

        Outer ring is the one with the greatest area. The polygon is modified in place by swapping the largest
        ring with the first one; nothing is returned. Polygons with fewer than two rings are left untouched.

        Args:
            polygon: The polygon to sort the rings in.
        """
        if len(polygon) < 2:
            return

        max_area = -1.0
        max_ring = 0
        for i, ring in enumerate(polygon):
            point_locations = [self._ugrid_pt_locs[point] for point in ring]
            # Ring orientation is arbitrary, so compare magnitudes only.
            area = abs(poly_area_x2(point_locations))
            if area > max_area:
                max_area = area
                max_ring = i

        if max_ring != 0:
            polygon[max_ring], polygon[0] = polygon[0], polygon[max_ring]

    def _polygon_from_point_map(self, pt_map) -> Polygon:
        """Returns the polygon defined by the pt_map.

        Walks the boundary edges, consuming each edge once, and splits the walk into closed rings. Every ring
        repeats its first point as its last point. The rings are returned in the order they were closed, which is
        not necessarily outer ring first.

        Args:
            pt_map (Dict of point index -> set of edge adjacent point indices): Boundary edges. This mapping is
                consumed: edges are removed as they are used and it is empty when the method returns.

        Returns:
            The polygon.
        """
        polygon = []
        while pt_map:
            # Start ring with arbitrary first point in map
            start_point = next(iter(pt_map))
            ring = [start_point]  # List of point indexes forming a loop
            ring_index = {start_point: 0}  # Position of each point currently in ring, to detect loops

            while True:
                last_point = ring[-1]
                next_point = self._next_point(last_point, pt_map)

                # Consume the edge in both directions and drop points that have no edges left
                pt_map[last_point].remove(next_point)
                if not pt_map[last_point]:
                    del pt_map[last_point]
                pt_map[next_point].remove(last_point)
                if not pt_map[next_point]:
                    del pt_map[next_point]

                # Stop if we're back to the beginning
                if next_point == start_point:
                    ring.append(next_point)
                    break

                loop_start = ring_index.get(next_point)
                if loop_start is not None:
                    # We've looped back on ourselves. Cut off loop to form a ring
                    stripped = ring[loop_start:]
                    polygon.append(stripped + [next_point])  # Repeat first point as the last point
                    # Forget the positions of the stripped points. They are no longer in the current ring, and a
                    # stale position would mis-slice the ring if one of them is reached again through another edge.
                    for point in stripped:
                        del ring_index[point]
                    del ring[loop_start:]  # strip the loop from the current ring

                ring.append(next_point)
                ring_index[next_point] = len(ring) - 1

            polygon.append(ring)
        return polygon

    @staticmethod
    def _next_point(last_point, pt_map) -> int:
        """Returns the next point we will visit starting from the last point using the pt_map.

        There may be more than one next point to choose from. We just pick the next one in
        the set.

        Args:
            last_point (int): Point index.
            pt_map (Dict of point index -> set of edge adjacent point indices):

        Returns:
            next_point (int): Point index, or -1 if last_point has no adjacent point other than itself.
        """
        return next((point for point in pt_map[last_point] if point != last_point), -1)

    def _build_coverage(self, region_map: RegionMap) -> Coverage:
        """Creates a data_objects Coverage with the Polygons and Arcs.

        Also stores the builder's dataset value -> polygon ID mapping in self.dataset_polygon_ids.

        Args:
            region_map: Mapping from dataset value to polygons covering that value.

        Returns:
            data_objects Coverage.
        """
        self._logger.info('Merging contiguous areas with same category assignment into coverage polygons...')
        builder = PolygonCoverageBuilder(self._ugrid_pt_locs, self._projection, self._coverage_name, self._logger)
        coverage = builder.build_coverage(region_map)
        self.dataset_polygon_ids = builder.dataset_polygon_ids
        return coverage
