"""MapActivityToUGridTool class."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules

# 2. Third party modules
from geopandas import GeoDataFrame
import numpy as np

# 3. Aquaveo modules
from xms.constraint import copy_grid
from xms.tool_core import ALLOW_ONLY_COVERAGE_TYPE, IoDirection, Tool

# 4. Local modules
from xms.tool.algorithms.geometry.geometry import run_parallel_points_in_polygon
from xms.tool.utilities.coverage_conversion import parallel_polygon_perimeters

# Positions of the user-facing arguments in the list built by MapActivityToUGridTool.initial_arguments().
ARG_INPUT_GRID, ARG_INPUT_ACTIVITY_COVERAGE, ARG_INPUT_NAME = range(3)


class MapActivityToUGridTool(Tool):
    """Map activity from a coverage to a new ugrid.

    The output grid is a copy of the input grid whose model on/off cell flags come from an optional activity
    coverage: every cell starts active, and cells whose centroids fall inside a coverage polygon with a zero
    'Activity' attribute (and outside that polygon's holes) are turned off.
    """

    def __init__(self):
        """Initializes the class."""
        super().__init__(name='Map Activity to UGrid')

    def initial_arguments(self):
        """Get initial arguments for tool.

        Must override.

        Returns:
            (list): A list of the initial tool arguments. The first three are indexed by ARG_INPUT_GRID,
            ARG_INPUT_ACTIVITY_COVERAGE and ARG_INPUT_NAME; the last is a hidden output grid argument.
        """
        activity_filter = {ALLOW_ONLY_COVERAGE_TYPE: 'ACTIVITY_CLASSIFICATION'}
        arguments = [
            self.grid_argument(name='input_grid', description='Input grid'),
            self.coverage_argument(name='input_activity_coverage', description='Activity coverage',
                                   optional=True, filters=activity_filter),
            self.string_argument(name='ugrid_name', description='Name of the output ugrid', value='',
                                 optional=True),
            self.grid_argument(name='output_ugrid', description='The ugrid', hide=True, optional=True,
                               io_direction=IoDirection.OUTPUT)
        ]
        return arguments

    def validate_arguments(self, arguments):
        """Called to determine if arguments are valid.

        Args:
            arguments (list): The tool arguments.

        Returns:
            (dict): Dictionary of errors for arguments, keyed by argument name.
        """
        errors = {}

        # Only the input grid is validated; the activity coverage and output name are optional.
        self._validate_input_grid(errors, arguments[ARG_INPUT_GRID])

        return errors

    def _validate_input_grid(self, errors, argument):
        """Validate that the input grid is specified and can be read.

        Args:
            errors (dict): Dictionary of errors keyed by argument name. Updated in place.
            argument (GridArgument): The grid argument.
        """
        grid = self.get_input_grid(argument.text_value)
        if not grid:
            errors[argument.name] = 'Could not read grid.'
        # NOTE: cells are deliberately not checked for being 2D (e.g. via grid.check_all_cells_2d()) — no 3D
        # grids are handled right now.

    def _set_activity_from_coverage(self, activity_cov: GeoDataFrame):
        """Turn off cells whose centroids lie inside inactive polygons of an activity coverage.

        A polygon is inactive when its 'Activity' attribute equals 0. A cell is marked inactive in
        self._cell_activity when its centroid is inside the polygon's outer boundary and not inside any of the
        polygon's holes. Polygons without an entry in the attribute table are left alone.

        Args:
            activity_cov (GeoDataFrame): An activity coverage. Its attrs['attributes'] must provide 'Id' and
                'Activity' columns.
        """
        atts = activity_cov.attrs['attributes']
        # True means the polygon is active (a non-zero Activity value).
        activity = {poly_id: val != 0 for poly_id, val in zip(atts['Id'], atts['Activity'])}

        polygons = activity_cov[activity_cov['geometry_types'] == 'Polygon']
        inactive_ids = [poly.id for poly in polygons.itertuples() if poly.id in activity and not activity[poly.id]]
        ugrid = self._input_xm_grid
        if not inactive_ids or ugrid.cell_count == 0:
            return  # Nothing can be turned off; skip the per-cell centroid work.

        all_polys = parallel_polygon_perimeters(activity_cov)
        # get_cell_centroid returns a pair whose second item is the centroid; only x and y are used.
        centroids = (ugrid.get_cell_centroid(cell_idx)[1] for cell_idx in range(ugrid.cell_count))
        centroid_locs = np.asarray([(pt[0], pt[1]) for pt in centroids])

        for poly_id in inactive_ids:
            loops = all_polys[poly_id]  # loops[0] is the outer boundary, loops[1:] are holes
            # NOTE(review): assumes run_parallel_points_in_polygon yields one value per point, equal to 1 when
            # the point is inside — the same == 1 / == 0 tests the previous implementation used.
            inside = np.asarray(run_parallel_points_in_polygon(centroid_locs, loops[0])) == 1
            for hole in loops[1:]:
                # A centroid inside a hole is not part of this polygon.
                inside &= np.asarray(run_parallel_points_in_polygon(centroid_locs, hole)) == 0
            for cell_idx in np.flatnonzero(inside).tolist():
                self._cell_activity[cell_idx] = False  # mark as inactive

    def run(self, arguments):
        """Override to run the tool.

        Copies the input grid and sets its model on/off cells: all cells start active and are then turned off
        according to the activity coverage, when one is given and non-empty.

        Args:
            arguments (list): The tool arguments.
        """
        self._input_ugrid = self.get_input_grid(arguments[ARG_INPUT_GRID].text_value)
        coverage_arg = arguments[ARG_INPUT_ACTIVITY_COVERAGE]
        activity_cov = self.get_input_coverage(coverage_arg.value) if coverage_arg.text_value else None
        self._input_xm_grid = self._input_ugrid.ugrid

        self._cell_activity = [True] * self._input_xm_grid.cell_count

        if activity_cov is not None and not activity_cov.empty:  # Activity defined on a coverage (13.1 and earlier)
            self._set_activity_from_coverage(activity_cov)

        # Make a copy of the input grid
        self.logger.info('Creating output grid...')
        new_grid = copy_grid(self._input_ugrid)
        new_grid.model_on_off_cells = self._cell_activity
        # NOTE(review): the ugrid_name argument (not the hidden output_ugrid argument) is passed here; presumably
        # set_output_grid only reads its value as the output name — confirm against xms.tool_core.
        self.set_output_grid(new_grid, arguments[ARG_INPUT_NAME], None)
