"""Class for refining a quadtree grid around EWN coverage polygons."""
__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
import logging

# 2. Third party modules
import numpy as np

# 3. Aquaveo modules
from xms.constraint import QuadtreeGridBuilder
from xms.constraint.quadtree import quadtree_smoother, refinement
from xms.constraint.quadtree.mapper_polygon import MapperPolygon
from xms.tool.algorithms.geometry.geometry import run_parallel_points_in_polygon

# 4. Local modules
import xms.ewn.data.ewn_cov_data_consts as consts
from xms.ewn.tools.runners import runner_util


class QuadTreeRefiner:
    """Refine and smooth a quadtree UGrid around EWN polygons and update its cell elevations."""
    def __init__(self, ugrid, ewn_cov):
        """Initializes the helper class.

        Args:
            ugrid: The quadtree UGrid to refine.
            ewn_cov: The EWN coverage; ewn_cov[0][0] is used as the coverage geometry and
                ewn_cov[0][1] as the coverage component.
        """
        self._logger = logging.getLogger('xms.ewn')

        self._ugrid = ugrid
        self._cell_getter = self._ugrid.quad_grid_cells
        self._refinement = None
        self._ewn_cov = ewn_cov
        self._ewn_comp = ewn_cov[0][1]
        self._ewn_geom = ewn_cov[0][0]

        # Refined and smoothed grid produced by refine_quadtree(); None until it has run.
        self.out_ugrid = None
        self._cell_id_map = None

        # Polygon input dicts, sorted by decreasing refinement length in refine_quadtree().
        self._poly_data = None

    @staticmethod
    def _indices_in_polygon(locs_2d, poly):
        """Return the indices of the 2D locations that fall inside a polygon.

        Args:
            locs_2d (np.ndarray): (x, y) locations to test.
            poly (list): Polygon as a list of (x, y[, z]) points; only x and y are used.

        Returns:
            (list): Ascending indices into ``locs_2d`` of the locations inside ``poly``.
        """
        np_poly = np.asarray([(p[0], p[1]) for p in poly])
        flags = run_parallel_points_in_polygon(locs_2d, np_poly)
        return [i for i, flag in enumerate(flags) if flag]

    @staticmethod
    def _cell_centroids_2d(ug):
        """Return the (x, y) centroid of every cell of a UGrid.

        Args:
            ug: The UGrid whose cell centroids are computed.

        Returns:
            (np.ndarray): Array of shape (cell_count, 2).
        """
        centroids = [ug.get_cell_centroid(i)[1] for i in range(ug.cell_count)]
        return np.asarray([(p[0], p[1]) for p in centroids], dtype=float).reshape(-1, 2)

    def _get_ugrid_points_in_polygon(self, poly):
        """Get a list of ugrid point indexes that are inside the polygon.

        Args:
            poly (list): list of x, y, z coordinates

        Returns:
            (list): list of ugrid point indexes
        """
        locs_2d = np.asarray([(p[0], p[1]) for p in self._ugrid.locations])
        return self._indices_in_polygon(locs_2d, poly)

    def _get_cells_in_polygon(self, grid, poly, holes=None, centroids=None):
        """Get the ids of cells whose centroids lie inside a polygon but outside all of its holes.

        Args:
            grid: The quadtree grid; its ``ugrid`` attribute supplies the cells.
            poly (list): Outer polygon as a list of (x, y[, z]) points.
            holes (list | None): Hole polygons, each a list of (x, y[, z]) points. None means no holes.
            centroids (np.ndarray | None): Precomputed (n, 2) cell centroids of ``grid``, so callers
                testing many polygons against the same grid compute them only once. Computed from
                ``grid`` when not supplied.

        Returns:
            (set): Ids of the cells inside ``poly`` and not inside any hole.
        """
        # A polygon with fewer than three vertices encloses no area.
        if poly is None or len(poly) < 3:
            return set()
        if centroids is None:
            centroids = self._cell_centroids_2d(grid.ugrid)

        cell_ids = set(self._indices_in_polygon(centroids, poly))
        for hole in holes or []:
            if hole is not None and len(hole) >= 3:
                cell_ids.difference_update(self._indices_in_polygon(centroids, hole))
        return cell_ids

    @staticmethod
    def _apply_polygon_elevation(cell_elevations, cell_ids, elevation, is_constant):
        """Return a copy of the cell elevations with one polygon's elevation applied.

        Args:
            cell_elevations (list): Current elevation of each cell.
            cell_ids (set): Ids of the cells inside the polygon.
            elevation: The polygon's elevation value. When None, the cells are left unchanged.
            is_constant (bool): True to overwrite the cell elevation with ``elevation``;
                False to add ``elevation`` to the existing cell elevation.

        Returns:
            (list): The updated cell elevations. The input list is not modified.
        """
        updated = list(cell_elevations)
        if elevation is None:
            return updated
        count = len(updated)
        for i in cell_ids:
            # Ignore any id that has no corresponding entry in the elevation list.
            if 0 <= i < count:
                updated[i] = elevation if is_constant else updated[i] + elevation
        return updated

    def refine_quadtree(self):
        """Refine, smooth, and re-elevate the quadtree grid using the EWN coverage polygons.

        The resulting grid is stored in ``self.out_ugrid``.

        Raises:
            RuntimeError: If the coverage has no EWN polygons.
        """
        # Largest refinement length first, so finer polygons have their elevations applied last.
        self._poly_data = sorted(
            runner_util.get_ewn_polygon_input(self._ewn_geom, self._ewn_comp, inside_polys=True),
            key=lambda k: k['quadtree_refinement_length'],
            reverse=True
        )

        if not self._poly_data:
            msg = 'No EWN Polygons defined...'
            self._logger.error(msg)
            raise RuntimeError(msg)

        mapper_polygons = [
            MapperPolygon(
                outside_polygon=poly.get('polygon_outside_pts_ccw', []),
                inside_polygons=poly.get('polygon_inside_pts_cw', []),
                refine_size=poly['quadtree_refinement_length'],
                to_layer=-1,
                from_layer=-1,
            )
            for poly in self._poly_data
        ]

        # Refine the original cells; each entry of cell_map indexes the source cell of a refined cell.
        cells, cell_map = refinement.refine_quadtree_gridcells(self._ugrid, [], [], mapper_polygons)

        builder = QuadtreeGridBuilder()
        builder.origin = self._ugrid.origin
        builder.angle = self._ugrid.angle
        builder.numbering = self._ugrid.numbering
        builder.orientation = self._ugrid.orientation
        builder.locations_x = self._ugrid.locations_x
        builder.locations_y = self._ugrid.locations_y
        builder.quad_grid_cells = cells
        out_ugrid = builder.build_grid()
        old_cell_elevations = list(self._ugrid.cell_elevations)
        out_ugrid.cell_elevations = [old_cell_elevations[x] for x in cell_map]

        # Smooth the refined cells; smooth_cell_map indexes into the refined grid's cells.
        smooth_cells, smooth_cell_map = quadtree_smoother.refine_to_quadtree(
            out_ugrid.quad_grid_cells, smooth_between=True, smooth_corners=True
        )
        builder.quad_grid_cells = smooth_cells
        smoothed_out_ugrid = builder.build_grid()
        refined_elevations = list(out_ugrid.cell_elevations)
        new_cell_elevations = [refined_elevations[x] for x in smooth_cell_map]

        # Apply each polygon's elevation to the smoothed cells whose centroids it contains.
        # Centroids depend only on the smoothed grid, so compute them once for all polygons.
        centroids = self._cell_centroids_2d(smoothed_out_ugrid.ugrid)
        for poly in self._poly_data:
            cell_ids_in_poly = self._get_cells_in_polygon(
                smoothed_out_ugrid,
                poly.get('polygon_outside_pts_ccw', []),
                poly.get('polygon_inside_pts_cw', []),
                centroids=centroids,
            )
            is_constant = poly['elevation_method'] == consts.ELEVATION_METHOD_CONSTANT
            new_cell_elevations = self._apply_polygon_elevation(
                new_cell_elevations, cell_ids_in_poly, poly['elevation'], is_constant
            )

        smoothed_out_ugrid.cell_elevations = new_cell_elevations
        self.out_ugrid = smoothed_out_ugrid
