"""VoronoiUGridFromUGridTool class."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
import math

# 2. Third party modules
import numpy as np
from scipy.spatial import Voronoi
import shapely
from shapely.geometry import MultiPolygon, Polygon

# 3. Aquaveo modules
from xms.constraint.grid import GridType
from xms.constraint.ugrid_boundaries import UGridBoundaries
from xms.constraint.ugrid_builder import UGridBuilder
from xms.extractor.ugrid_2d_data_extractor import UGrid2dDataExtractor
from xms.grid.geometry.geometry import polygon_area_2d
from xms.grid.triangulate import Tin
from xms.grid.ugrid import UGrid as XmUGrid
from xms.tool_core import IoDirection, Tool

# 4. Local modules

ARG_INPUT_GRID = 0
ARG_VORONOI_GRID = 1


class VoronoiUGridFromUGridTool(Tool):
    """Tool to convert UGrids to Voronoi grids.

    The output is built in two passes:

    1. Build a Voronoi grid from the input grid's points and clip it to the input grid's boundary.
    2. Re-triangulate that intermediate grid's cell centroids.
    3. Build the final Voronoi grid from the re-triangulated points.

    Point elevations are extracted from the input grid only on the final pass.
    """
    HECRAS_FACE_SPLIT_RATIO = 1.2  # From HEC-RAS developers, subject to change
    MAXIMUM_CELL_EDGES = 8  # Maximum number of edges per cell in HEC-RAS (see the warning in run())
    FIRST_PASS = 0
    FIRST_PASS_RETRY = 1
    FINAL_PASS = 2
    # Exceptions that signal a GEOS topology failure while clipping. Shapely 1.x raises TopologicalError, while
    # Shapely 2.x raises GEOSException for the same failures. Catch whichever of them this Shapely version defines.
    _TOPOLOGY_ERRORS = tuple(
        getattr(shapely.errors, name) for name in ('TopologicalError', 'GEOSException')
        if hasattr(shapely.errors, name)
    )

    def __init__(self):
        """Initializes the class."""
        super().__init__(name='Voronoi UGrid from UGrid')
        self._input_cogrid = None
        self._original_geometry = None  # The XmUGrid
        self._input_elevations = None
        self._extractor = None
        self._original_loops = None
        self._pt_map = {}
        self._pass_number = self.FIRST_PASS
        self._cells_with_too_many_edges = []  # Only care if polygon with more than 8 sides on the last pass.

    def initial_arguments(self):
        """Get initial arguments for tool.

        Must override.

        Returns:
            (list): A list of the initial tool arguments.
        """
        arguments = [
            self.grid_argument(name='input_grid', description='Input grid'),
            # self.string_argument(name='precision_correction', description='Precision correction', value='None',
            #                      choices=['None', 'Re-triangulate', 'Adaptive correction']),
            self.grid_argument(name='voronoi_grid', description='Output grid name', io_direction=IoDirection.OUTPUT),
        ]
        return arguments

    def validate_arguments(self, arguments):
        """Called to determine if arguments are valid.

        Args:
            arguments (list): The tool arguments.

        Returns:
            (dict): Dictionary of errors for arguments.
        """
        # Validate the input grid is specified, 2D, contiguous and has no disjoint points.
        arg_errors = {}
        self._validate_input_grid(arg_errors, arguments[ARG_INPUT_GRID])
        return arg_errors

    def _validate_input_grid(self, errors, argument):
        """Validate grid is specified and 2D.

        Also caches the grid in self._input_cogrid for use by run().

        Args:
            errors (dict): Dictionary of errors keyed by argument name.
            argument (GridArgument): The grid argument.
        """
        key = argument.name
        self._input_cogrid = self.get_input_grid(argument.text_value)
        if not self._input_cogrid:
            errors[key] = 'Could not read grid.'
        else:
            # If we have a grid, perform all the checks
            grid_errors = []
            # check_all_cells_2d() did not work here, so reject the 3D grid types explicitly instead.
            if self._input_cogrid.grid_type in [GridType.ugrid_3d, GridType.rectilinear_3d, GridType.quadtree_3d]:
                grid_errors.append('Must have all 2D cells.')
            if not self._input_cogrid.check_contiguous_cells_2d():
                grid_errors.append('Input grid cannot have disjoint regions.')
            if self._input_cogrid.check_has_disjoint_points():
                grid_errors.append('Input grid cannot have disjoint points.')
            error = '\n'.join(grid_errors)
            if error:
                errors[key] = error

    def _finite_polygons_from_voronoi_diagram(self, vor):
        """Construct finite regions from infinite Voronoi regions in a 2D diagram.

        https://stackoverflow.com/questions/36063533/clipping-a-voronoi-diagram-python

        Args:
            vor (Voronoi): Input diagram generated by the scipy.spatial.Voronoi class

        Returns:
            (tuple(list, np.ndarray)): regions, vertices

                regions: For each input point (in input order), the indices of the vertices of its revised, finite
                Voronoi region, sorted counter-clockwise when the region had to be reconstructed.

                vertices: Coordinates for revised Voronoi vertices. Same as coordinates of input vertices, plus
                'points at infinity'.
        """
        new_regions = []
        new_vertices = vor.vertices.tolist()
        center = vor.points.mean(axis=0)
        # Compute the length along an infinite edge to create new points. np.ptp() (not the ndarray.ptp method, which
        # was removed in NumPy 2.0) over the flattened points, same value as before.
        radius = np.ptp(vor.points).max() * 2

        # Construct a map containing all ridges for a given point
        all_ridges = {}
        for (p1, p2), (v1, v2) in zip(vor.ridge_points, vor.ridge_vertices):
            all_ridges.setdefault(p1, []).append((p2, v1, v2))
            all_ridges.setdefault(p2, []).append((p1, v1, v2))

        # Build a list of new, finite regions
        for p1, region in enumerate(vor.point_region):
            vertices = vor.regions[region]
            if all(v >= 0 for v in vertices):
                # Finite region
                new_regions.append(vertices)
                continue

            # Reconstruct an infinite region
            ridges = all_ridges[p1]
            new_region = [v for v in vertices if v >= 0]
            for p2, v1, v2 in ridges:
                if v2 < 0:
                    v1, v2 = v2, v1    # pragma no cover - likely unnecessary - scipy seems to always make it ordered
                if v1 >= 0:
                    # Finite ridge, already in the region
                    continue

                # Compute the missing endpoint of an infinite ridge
                t = vor.points[p2] - vor.points[p1]  # tangent
                t /= np.linalg.norm(t)
                n = np.array([-t[1], t[0]])  # normal
                midpoint = vor.points[[p1, p2]].mean(axis=0)
                # Point the normal away from the centroid of all points, i.e. outward toward infinity.
                direction = np.sign(np.dot(midpoint - center, n)) * n
                far_point = vor.vertices[v2] + direction * radius
                new_region.append(len(new_vertices))
                new_vertices.append(far_point.tolist())

            # Sort region CCW
            vs = np.asarray([new_vertices[v] for v in new_region])
            c = vs.mean(axis=0)
            angles = np.arctan2(vs[:, 1] - c[1], vs[:, 0] - c[0])
            new_region = np.array(new_region)[np.argsort(angles)]
            new_regions.append(new_region.tolist())
        return new_regions, np.asarray(new_vertices)

    def _set_point_elevations(self):
        """Extract elevations from the input UGrid at the output UGrid points.

        Returns:
            list: The (x, y, z) tuples of every point in self._pt_map, in point-index order.
        """
        all_active = [1] * self._original_geometry.point_count
        self._extractor.set_grid_point_scalars(self._input_elevations, all_active, 'points')
        output_points = list(self._pt_map.keys())  # Insertion order == point index order
        self._extractor.extract_locations = output_points
        elevations = self._extractor.extract_data()
        return [(xy[0], xy[1], z_coord) for xy, z_coord in zip(output_points, elevations)]

    def remove_boundary_midpoints_internal(self, voronoi_ugrid):
        """Remove co-linear midpoints on the boundary of the output UGrid.

        This gets called by mock to avoid infinite recursion.

        Args:
            voronoi_ugrid (xms.grid.Grid): The output UGrid

        Returns:
            xms.grid.Grid: The output UGrid with boundary midpoints removed
        """
        delete_points, update_cells = self._find_points_to_remove(voronoi_ugrid)

        if not delete_points:
            # Didn't delete any points, so there is no need to rebuild the XmUGrid. _rebuild_cellstream (which flags
            # cells with too many edges on the final pass) is skipped, so check the edge counts here instead.
            self._record_cells_with_too_many_edges(voronoi_ugrid.cellstream)
            return voronoi_ugrid

        # Build the new point list, dropping deleted points and any point not used by a cell.
        new_points = []
        old_to_new_point_idxs = {}
        for pt_idx, pt in enumerate(voronoi_ugrid.locations):
            if pt_idx not in delete_points and voronoi_ugrid.get_point_adjacent_cells(pt_idx):
                old_to_new_point_idxs[pt_idx] = len(new_points)
                new_points.append(pt)

        new_cellstream = self._rebuild_cellstream(delete_points, old_to_new_point_idxs, update_cells,
                                                  voronoi_ugrid.cellstream)
        return XmUGrid(new_points, new_cellstream)

    def remove_boundary_midpoints(self, voronoi_ugrid):
        """Remove co-linear midpoints on the boundary of the output UGrid (for test mocking).

        This gets mocked by tests to avoid infinite recursion.

        Args:
            voronoi_ugrid (xms.grid.Grid): The output UGrid

        Returns:
            xms.grid.Grid: The output UGrid with boundary midpoints removed
        """
        return self.remove_boundary_midpoints_internal(voronoi_ugrid)

    def _find_points_to_remove(self, voronoi_ugrid):
        """Find mid-side points along the boundary that should be removed.

        Notes:
            Find all points on the boundary that are only connected to one cell, they are a candidate for removal. Only
            do this on the exterior boundary. Messes things up when we remove midpoints from interior holes. This is
            mainly a step to make HEC-RAS happy, and it doesn't allow holes anyway.

            This is the logic we were told HEC-RAS uses for determining whether to keep a mid-side point.

            1) Compute A1 - the cell polygon area

            2) Compute A2 - the cell polygon area without any mid-side points

            3) If A1/A2 > 1.2, we may not want to remove it. If A1/A2 <= 1.2, remove it and we are done.
              3a) Loop through each mid-side point on this face and recompute A2 with this one additional point

              3b) Keep the point that adds the most area and remove all other mid-side points on this face

        Args:
            voronoi_ugrid (xms.grid.Grid): The output UGrid

        Returns:
            tuple(set, set): Indices of the points to remove, and indices of the cells that need to be updated
        """
        self.logger.info('Finding boundary of the output grid...')
        boundary = UGridBoundaries(voronoi_ugrid)
        loops = boundary.get_loops()
        delete_points = set()
        update_cells = set()
        exterior_idx = 0
        for pt_idx in loops[exterior_idx]['id']:
            adjacent_cells = voronoi_ugrid.get_point_adjacent_cells(pt_idx)
            if len(adjacent_cells) != 1:
                continue  # Corner point shared by several cells; never a removal candidate.

            # Compute the area of the polygon with all points.
            point_indices = voronoi_ugrid.get_cell_points(adjacent_cells[0])
            point_locations = voronoi_ugrid.get_cell_locations(adjacent_cells[0])
            polygon_area = polygon_area_2d(point_locations)
            # Compute the area of the polygon without mid-side points.
            corner_points = [
                location for i, location in enumerate(point_locations)
                if len(voronoi_ugrid.get_point_adjacent_cells(point_indices[i])) > 1
            ]
            corner_polygon_area = polygon_area_2d(corner_points)

            # If the corner-only polygon has no area, or the mid-side points add little area (ratio at or below the
            # HEC-RAS limit), delete this point and mark its cell for rebuilding.
            if not corner_polygon_area or polygon_area / corner_polygon_area <= self.HECRAS_FACE_SPLIT_RATIO:
                delete_points.add(pt_idx)
                update_cells.add(adjacent_cells[0])
                continue

            face_midpoints = self._find_face_midpoints(voronoi_ugrid, pt_idx)

            if len(face_midpoints) == 1:
                # We are the only mid-side point on this face. Keep this point.
                continue

            # Compute the area of the polygon with only corner nodes and one of the mid-side nodes. The one that
            # adds the most area is the one we keep.
            my_polygon = [
                location for i, location in enumerate(point_locations)
                if point_indices[i] == pt_idx or point_indices[i] not in face_midpoints
            ]
            my_area = polygon_area_2d(my_polygon)
            for midpoint in face_midpoints:
                if midpoint == pt_idx:
                    continue  # Skip ourselves
                midpoint_polygon = [
                    location for i, location in enumerate(point_locations)
                    if point_indices[i] == midpoint or point_indices[i] not in face_midpoints
                ]
                midpoint_polygon_area = polygon_area_2d(midpoint_polygon)
                if midpoint_polygon_area > my_area:
                    # There is another midpoint on this face that adds more area, delete this point.
                    delete_points.add(pt_idx)
                    update_cells.add(adjacent_cells[0])
                    break

        return delete_points, update_cells

    def _find_face_midpoints(self, voronoi_ugrid, face_midpoint_idx):
        """Find all the mid-side points on a face given one mid-side point on the face.

        A mid-side point is a point adjacent to exactly one cell. Starting from face_midpoint_idx, march along the
        boundary in both directions until a point adjacent to more than one cell (a face end point) is reached.

        Args:
            voronoi_ugrid (xms.grid.Grid): The output UGrid
            face_midpoint_idx (int): Index of a mid-side point on the face

        Returns:
            set: The indices of the mid-side points on the same face as face_midpoint_idx
        """
        face_midpoints = {face_midpoint_idx}
        # 2 neighbors - left and right. March along the face in both directions.
        for neighbor in voronoi_ugrid.get_point_adjacent_points(face_midpoint_idx):
            current_point = neighbor
            # Keep moving while the current point is another mid-side point; stop at a face end point.
            while len(voronoi_ugrid.get_point_adjacent_cells(current_point)) == 1:
                face_midpoints.add(current_point)
                # Find the adjacent point that we have not processed yet (should only be two).
                next_point = next(
                    (pt for pt in voronoi_ugrid.get_point_adjacent_points(current_point) if pt not in face_midpoints),
                    None,
                )
                if next_point is None:
                    # Every neighbor was already visited (e.g. no point of the loop touches another cell). Without this
                    # break the loop would never advance.
                    break
                current_point = next_point
        return face_midpoints

    def _rebuild_cellstream(self, deleted_points, old_to_new_point_idxs, cells_to_update, old_cellstream):
        """Rebuild the Voronoi UGrid's cellstream after removing points.

        Cells left with fewer than 3 points are dropped. On the final pass, the 1-based (old) index of every cell with
        more than MAXIMUM_CELL_EDGES points is appended to self._cells_with_too_many_edges.

        Args:
            deleted_points (set): Indices of the points that were removed
            old_to_new_point_idxs (dict): Mapping of old point indices to new point indices
            cells_to_update (set): Indices of the cells that had a point removed
            old_cellstream (list): The UGrid's cellstream before points were removed

        Returns:
            list: The new cellstream for the UGrid
        """
        new_cellstream = []
        stream_idx = 1  # Skip the cell type
        cell_idx = 0
        while stream_idx < len(old_cellstream):
            num_pts = old_cellstream[stream_idx]
            stream_idx += 1
            if cell_idx in cells_to_update:  # Cell had point removed
                pt_ids = [
                    old_to_new_point_idxs[old_cellstream[stream_idx + i]] for i in range(num_pts) if
                    old_cellstream[stream_idx + i] not in deleted_points
                ]
            else:
                pt_ids = [old_to_new_point_idxs[old_cellstream[stream_idx + i]] for i in range(num_pts)]
            stream_idx = stream_idx + num_pts + 1  # Skip the cell type
            cell_idx += 1
            num_new_pts = len(pt_ids)
            if num_new_pts < 3:
                continue
            if self._pass_number == self.FINAL_PASS and num_new_pts > self.MAXIMUM_CELL_EDGES:
                self._cells_with_too_many_edges.append(cell_idx)  # Store the 1-based index for warning message
            new_cellstream.extend([7, num_new_pts])  # XMU_POLYGON = 7
            new_cellstream.extend(pt_ids)
        return new_cellstream

    def _record_cells_with_too_many_edges(self, cellstream):
        """On the final pass, record the 1-based index of every cell with more than MAXIMUM_CELL_EDGES points.

        Args:
            cellstream (list): A UGrid cellstream ([cell type, point count, point ids..., cell type, ...])
        """
        if self._pass_number != self.FINAL_PASS:
            return
        stream_idx = 0
        cell_id = 0
        while stream_idx + 1 < len(cellstream):
            num_pts = int(cellstream[stream_idx + 1])
            cell_id += 1
            if num_pts > self.MAXIMUM_CELL_EDGES:
                self._cells_with_too_many_edges.append(cell_id)
            stream_idx += num_pts + 2  # Skip cell type, point count, and the point ids

    def outer_polygon(self, exterior_idx):
        """Get the outer boundary loop of the input grid.

        Args:
            exterior_idx (int): The exterior index.

        Returns:
            (np.ndarray): The x,y coordinates of the outer loop (Z stripped).
        """
        outer = self._original_loops[exterior_idx]['location'][:, :-1]  # Strip Z-coordinates off for now
        return outer

    def inner_polygons(self, exterior_idx):
        """Get the inner (hole) boundary loops of the input grid.

        Args:
            exterior_idx (int): The exterior index, which is skipped.

        Returns:
            (list[np.ndarray]): The x,y coordinates of every loop other than the exterior loop (Z stripped).
        """
        # Iterating self._original_loops yields the loop keys, which are used to index the loops.
        inner = [self._original_loops[i]['location'][:, :-1] for i in self._original_loops if i != exterior_idx]
        return inner

    def _retriangulate_grid(self, intermediate_grid):
        """After converting to a Voronoi grid, retriangulate and reconvert for a more HEC-RAS stable result.

        Args:
            intermediate_grid (xms.grid.Grid): The input grid geometry after it has been through one Voronoi translation

        Returns:
            xms.grid.Grid: The intermediate grid retriangulated at cell centers
        """
        tin = Tin(points=[intermediate_grid.get_cell_centroid(i)[1] for i in range(intermediate_grid.cell_count)])
        tin.triangulate()

        tin_tris = tin.triangles
        triangles = [  # Unstack the flat triangle array into (a, b, c) triples for convenience
            (tin_tris[idx], tin_tris[idx + 1], tin_tris[idx + 2]) for idx in range(0, len(tin_tris), 3)
        ]
        cellstream = []
        for tri in triangles:  # 5 = XMU_TRIANGLE
            cellstream.extend([5, 3, tri[0], tri[1], tri[2]])
        return XmUGrid(tin.points, cellstream)

    def _generate_voronoi_grid(self, input_ugrid):
        """Generate a Voronoi grid from the points of input_ugrid, clipped to the boundary of the original input grid.

        On a topology failure during the first pass, retries once with joggled input. Any other failure calls
        self.fail().

        Args:
            input_ugrid (xms.grid.Grid): The input grid geometry

        Returns:
            xms.grid.Grid: The output Voronoi grid, or None if generation failed
        """
        # Bounding box of the input grid, padded by its diagonal length. Adding its corners keeps the Voronoi regions
        # of the real input points bounded; the four padding regions are discarded below (regions[:-4]).
        extents = input_ugrid.extents
        dx = extents[0][0] - extents[1][0]
        dy = extents[0][1] - extents[1][1]
        dist = math.sqrt(dx**2 + dy**2)
        pts = [(extents[0][0] - dist, extents[0][1] - dist, 0.0),
               (extents[1][0] + dist, extents[0][1] - dist, 0.0),
               (extents[1][0] + dist, extents[1][1] + dist, 0.0),
               (extents[0][0] - dist, extents[1][1] + dist, 0.0)]
        original_locations = np.asarray(input_ugrid.locations)
        if self._pass_number == self.FIRST_PASS:
            # Elevations of the input grid's own points only; the padding points are not part of the grid that the
            # extractor interpolates over.
            self._input_elevations = original_locations[:, 2]
        # Generate the Voronoi diagram, and convert it to finite polygons
        locations = np.concatenate((original_locations, pts))
        xy_locations = locations[:, :-1]  # Drop Z-coordinate

        # Start every attempt with an empty point map. Otherwise points from the previous pass (or a failed attempt)
        # would remain in the map and end up as unreferenced points in this pass's grid.
        self._pt_map = {}

        # Joggle the input locations on retry to avoid precision errors.
        qhull_options = 'QJ' if self._pass_number == self.FIRST_PASS_RETRY else 'Qbb Qc Qz'
        try:
            voronoi = Voronoi(xy_locations, qhull_options=qhull_options)
            self.logger.info('Constructing finite polygons from diagram...')
            regions, vertices = self._finite_polygons_from_voronoi_diagram(voronoi)

            # Clip to the boundary of the input UGrid
            self.logger.info('Clipping to boundary of the input grid...')
            exterior_idx = 0
            outer = self.outer_polygon(exterior_idx)
            inner = self.inner_polygons(exterior_idx)
            cellstream = clip_output_to_boundary(outer, inner, regions[:-4], vertices, self._pt_map)

            # Set point elevations on the output ugrid
            self.logger.info('Computing point elevations of output grid...')
            if self._pass_number == self.FINAL_PASS:
                elev_pts = self._set_point_elevations()
            else:
                elev_pts = list(self._pt_map.keys())
            xmugrid = XmUGrid(elev_pts, cellstream)

            # Remove midpoints on boundary if enabled
            self.logger.info('Removing co-linear midpoints from boundaries...')
            return self.remove_boundary_midpoints(xmugrid)
        except self._TOPOLOGY_ERRORS:
            if self._pass_number == self.FIRST_PASS:
                self.logger.warning(
                    'First pass of generating the Voronoi grid failed. Attempting to resolve precision errors...'
                )
                self._pass_number = self.FIRST_PASS_RETRY
                return self._generate_voronoi_grid(input_ugrid)
            elif self._pass_number == self.FIRST_PASS_RETRY:
                self.logger.warning('Retry of first pass failed. Verify mesh integrity for best results.')
                self.fail('Failed to generate voronoi grid.')
            else:
                self.logger.warning(
                    'Second pass of generating the Voronoi grid failed. Verify mesh integrity for best results.')
                self.fail('Failed to generate voronoi grid.')

    def run(self, arguments):
        """Override to run the tool.

        Args:
            arguments (list): The tool arguments.
        """
        # Reset per-run state so a tool instance that is run more than once starts again from the first pass.
        self._pass_number = self.FIRST_PASS
        self._cells_with_too_many_edges = []

        # Initialize the expensive stuff we can.
        self.logger.info('Initializing data extractor...')
        self._original_geometry = self._input_cogrid.ugrid
        ugrid = self._original_geometry
        self._extractor = UGrid2dDataExtractor(ugrid=ugrid)
        self.logger.info('Initializing boundary definition for input geometry...')
        boundary = UGridBoundaries(self._original_geometry)
        self._original_loops = boundary.get_loops()

        # Run the first Voronoi generation pass
        self.logger.info('Generating Voronoi diagram from input grid (pass 1 of 2)...')
        ugrid = self._generate_voronoi_grid(ugrid)

        # Retriangulate the cell centers of output UGrid, makes HEC-RAS like it better.
        self.logger.info('Retriangulating output for second pass...')
        ugrid = self._retriangulate_grid(ugrid)

        # Run the final Voronoi generation
        self.logger.info('Generating Voronoi diagram from input grid (pass 2 of 2)...')
        self._pass_number = self.FINAL_PASS
        ugrid = self._generate_voronoi_grid(ugrid)

        # Warn if any polygons in the grid have more than 8 edges.
        if self._cells_with_too_many_edges:
            bad_cells = '\n'.join([str(cell_id) for cell_id in self._cells_with_too_many_edges])
            self.logger.warning(
                f'The following cell(s) have more than eight edges:\n{bad_cells}\nThe maximum number of edges per '
                'cell in HEC-RAS is eight. Reduce nodal connectivity in the input UGrid for best results.'
            )

        # Build the UGrid to send back to XMS
        co_builder = UGridBuilder()
        co_builder.set_unconstrained()
        co_builder.set_ugrid(ugrid)
        cogrid = co_builder.build_grid()
        self.set_output_grid(cogrid, arguments[ARG_VORONOI_GRID])


def clip_output_to_boundary(outer, inner, regions, vertices, pt_map):
    """Clip finite polygon regions to the original boundary of the input mesh.

    Each region is intersected with the boundary polygon (outer loop with the inner loops as holes). Every polygonal
    piece of an intersection becomes one XMU_POLYGON cell whose points are ordered counter-clockwise.

    Some intersections produce no cells: empty ones, and points or lines where a region only touches the boundary.
    Holes inside a clipped piece are not represented, because only the exterior ring is written.

    Args:
        outer (np.ndarray): The outer boundary loop as x,y coordinates.
        inner (list[np.ndarray]): The inner boundary loops, each as x,y coordinates.
        regions (list): The finite polygon regions of the Voronoi diagram, each a list of indices into vertices.
        vertices (np.ndarray): The x,y locations of the Voronoi vertices.
        pt_map (dict[tuple[float, float], int]): Mapping from point to index. Updated in place: every new point is
            assigned the next free index (len(pt_map)).

    Returns:
        list: The output grid cell stream.

    Raises:
        RuntimeError: If an intersection is non-empty but contains no polygonal part.
    """
    boundaries = Polygon(outer, inner)  # Intersect all cells with the boundaries
    cellstream = []
    for region in regions:
        clipped = Polygon(vertices[region]).intersection(boundaries)
        for polygon in _polygon_parts(clipped):
            ring = polygon.exterior
            # Don't repeat the last node of the closed ring.
            cellstream_points = [pt_map.setdefault(p, len(pt_map)) for p in ring.coords[:-1]]
            if not ring.is_ccw:
                cellstream_points.reverse()  # XmUGrid cells are CCW
            cellstream.extend([7, len(cellstream_points)])  # XMU_POLYGON = 7
            cellstream.extend(cellstream_points)
    return cellstream


def _polygon_parts(geometry):
    """Return the non-empty Polygon pieces of a clipping result.

    Args:
        geometry: The shapely geometry produced by intersecting a region with the boundary.

    Returns:
        list[Polygon]: The polygons in geometry. The list is empty when geometry is empty.

    Raises:
        RuntimeError: If geometry is non-empty but contains no polygonal part.
    """
    if geometry.is_empty:
        return []
    if isinstance(geometry, Polygon):
        return [geometry]
    if isinstance(geometry, MultiPolygon):
        return [part for part in geometry.geoms if not part.is_empty]
    if geometry.geom_type == 'GeometryCollection':
        # Keep the areal pieces; points/lines where the region merely touches the boundary have no area.
        polygons = []
        for part in geometry.geoms:
            if isinstance(part, (Polygon, MultiPolygon)):
                polygons.extend(_polygon_parts(part))
        if polygons:
            return polygons
    raise RuntimeError('Region intersection must be a polygon.')
