"""FillHolesInUGridTool class."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
import os
import time
import uuid

# 2. Third party modules

# 3. Aquaveo modules
from xms.constraint.grid import GridType
from xms.constraint.ugrid_builder import UGridBuilder
from xms.grid.ugrid import UGrid as XmUGrid
import xms.mesher
from xms.mesher import meshing
from xms.tool_core import IoDirection, Tool

# 4. Local modules
from xms.tool.algorithms.coverage.grid_cell_to_polygon_coverage_builder import GridCellToPolygonCoverageBuilder

# Positions of the tool's arguments in the list returned by initial_arguments().
ARG_INPUT_UGRID, ARG_OUTPUT_GRID = 0, 1


class FillHolesInUGridTool(Tool):
    """Tool to fill in holes in 2D grids.

    The input UGrid's cells are traced into polygons. Every inner ring (hole) of those
    polygons is triangulated with ``xms.mesher.generate_mesh`` using only the hole's
    existing boundary points. The resulting triangles are appended to the original
    cell stream to produce a new UGrid.
    """

    def __init__(self, name='Fill Holes in UGrid'):
        """Initializes the class.

        Args:
            name (str): Display name of the tool.
        """
        super().__init__(name=name)
        self._input_ugrid_name = ''
        self._output_grid_name = ''
        self._input_grid = None  # Grid selected by the user; set in _validate_input_grid().
        self._input_ugrid = None  # UGrid unwrapped from self._input_grid; set in _validate_input_grid().
        self._grid_holes = []
        self._output_ugrid = None
        self._output_grid = None
        self._output_grid_uuid = ''
        self._interval = 1 * 10**9  # Minimum time between progress messages: 1 second in nanoseconds.

        self._force_ugrid = True
        self._geom_txt = 'UGrid'

    def _validate_input_grid(self, argument):
        """Validate that the grid is specified and 2D.

        As a side effect, caches the selected grid in ``self._input_grid`` and its
        unwrapped UGrid in ``self._input_ugrid``.

        Args:
            argument (GridArgument): The grid argument.

        Returns:
            (dict): Dictionary of errors for arguments, keyed by argument name.
        """
        errors = {}
        key = argument.name
        self._input_grid = self.get_input_grid(argument.text_value)
        if self._input_grid:
            # If we have a grid, perform all the checks
            grid_errors = []
            self._input_ugrid = self._input_grid.ugrid  # Only unwrap this once.
            if self._input_grid.grid_type in (GridType.ugrid_3d, GridType.rectilinear_3d, GridType.quadtree_3d):
                grid_errors.append('Must have all 2D cells.')
            error = '\n'.join(grid_errors)
            if error:
                errors[key] = error
        return errors

    def _find_cell_polygons(self):
        """Trace the input UGrid's cells into polygons.

        Every cell is given the same value (0), so all cells are merged into polygons
        whose inner rings are the grid's holes.

        Returns:
            The result of ``GridCellToPolygonCoverageBuilder.find_polygons()``.
        """
        ug_builder = UGridBuilder()
        ug_builder.set_is_2d()
        ug_builder.set_ugrid(self._input_ugrid)
        grid = ug_builder.build_grid()
        poly_builder = GridCellToPolygonCoverageBuilder(co_grid=grid,
                                                        dataset_values=[0] * self._input_ugrid.cell_count,
                                                        wkt=None,
                                                        coverage_name='temp',
                                                        logger=self.logger)
        return poly_builder.find_polygons()

    @staticmethod
    def _triangulate_hole(hole, grid_locs):
        """Triangulate one hole using only its boundary points.

        Args:
            hole (list[int]): Closed ring of point indices into ``grid_locs``; the last
                index repeats the first.
            grid_locs: Point locations of the input UGrid.

        Returns:
            (list[int]): Cell stream of the new triangles, with point indices into ``grid_locs``.
        """
        ring = hole[:-1]  # Drop the closing point that duplicates the first.
        locs = [grid_locs[p] for p in ring]
        # The mesher numbers its points independently; map them back to input-grid indices by exact (x, y).
        loc_to_idx = {(grid_locs[p][0], grid_locs[p][1]): p for p in ring}
        poly_input = meshing.PolyInput(outside_polygon=locs, generate_interior_points=False)
        hole_ug = xms.mesher.generate_mesh(polygon_inputs=[poly_input])
        pt_map = [loc_to_idx[(loc[0], loc[1])] for loc in hole_ug.locations]

        cell_type = XmUGrid.cell_type_enum.TRIANGLE
        cs = hole_ug.cellstream
        cells = []
        # NOTE(review): assumes the mesher emits only triangles, i.e. 5 entries per cell:
        # [type, 3, p0, p1, p2]. Confirm if mesher options ever change.
        for idx in range(0, len(cs), 5):
            cells.extend([cell_type, 3, pt_map[cs[idx + 2]], pt_map[cs[idx + 3]], pt_map[cs[idx + 4]]])
        return cells

    def _build_output_grid(self, locations, cell_stream):
        """Build the 2D output grid and assign it a fresh UUID.

        Sets ``self._output_grid`` and ``self._output_grid_uuid``.

        Args:
            locations: Point locations of the new UGrid.
            cell_stream (list[int]): Cell stream of the new UGrid.
        """
        new_ug = XmUGrid(locations, cell_stream)
        co_builder = UGridBuilder()
        co_builder.set_is_2d()
        co_builder.set_ugrid(new_ug)
        self._output_grid = co_builder.build_grid()
        self._output_grid_uuid = str(uuid.uuid4())
        self._output_grid.uuid = self._output_grid_uuid

    def _fill_ugrid_holes(self):
        """Fill every hole in the input UGrid with triangles and build the output grid.

        Reads ``self._input_ugrid`` (cached during validation). Sets ``self._output_grid``
        and ``self._output_grid_uuid``.
        """
        out_polys = self._find_cell_polygons()
        if not out_polys:
            # NOTE(review): assumes Tool.fail raises; otherwise out_polys[0] below raises IndexError.
            self.fail('UGrid does not contain any cells. Aborting.')

        # Presumably out_polys[0] holds the polygons for the single cell value (0). Each polygon
        # is a list of rings, and every ring after the first is a hole.
        polygons = out_polys[0]
        num_holes = sum(len(rings) - 1 for rings in polygons)
        self.logger.info(f'Found {num_holes} holes.')

        # Fill holes in the grid by triangulating each hole's boundary.
        grid_locs = self._input_ugrid.locations
        new_cell_stream = list(self._input_ugrid.cellstream)
        last_report = time.monotonic_ns()
        count = 1  # 1-based index of the hole currently being processed.
        for rings in polygons:
            for hole in rings[1:]:
                if time.monotonic_ns() - last_report >= self._interval:
                    self.logger.info(f'Processing polygon {count} of {num_holes}.')
                    last_report = time.monotonic_ns()
                count += 1
                new_cell_stream.extend(self._triangulate_hole(hole, grid_locs))

        # New triangles only reference existing points, so the point list is reused unchanged.
        self._build_output_grid(grid_locs, new_cell_stream)

    def _set_outputs(self):
        """Set outputs from the tool."""
        self.set_output_grid(self._output_grid, self._args[ARG_OUTPUT_GRID], force_ugrid=self._force_ugrid)

    def _get_inputs(self):
        """Get the user inputs.

        Resolves the output grid name and writes it back into the output argument's value.
        """
        self.logger.info('Retrieving input data from SMS...')
        # If user didn't give an output grid name, use the name of the input UGrid.
        user_input = self._args[ARG_OUTPUT_GRID].text_value
        self._input_ugrid_name = self._args[ARG_INPUT_UGRID].text_value
        input_ugrid_base_name = os.path.basename(self._input_ugrid_name)
        self._output_grid_name = os.path.basename(user_input) if user_input else input_ugrid_base_name
        self._args[ARG_OUTPUT_GRID].value = self._output_grid_name

    def initial_arguments(self):
        """Get initial arguments for tool.

        Must override.

        Returns:
            (list): A list of the initial tool arguments, ordered to match ARG_INPUT_UGRID
            and ARG_OUTPUT_GRID.
        """
        arguments = [
            self.grid_argument(name='input_ugrid', description='2D UGrid', io_direction=IoDirection.INPUT),
            self.grid_argument(name='grid', description='New UGrid Name',
                               optional=True, io_direction=IoDirection.OUTPUT),
        ]
        return arguments

    def validate_arguments(self, arguments):
        """Called to determine if arguments are valid.

        Args:
            arguments (list): The tool arguments.

        Returns:
            (dict): Dictionary of errors for arguments.
        """
        grid_arg = arguments[ARG_INPUT_UGRID]
        errors = {}
        if grid_arg.value is not None:  # Case of no selected UGrid already handled in base
            errors = self._validate_input_grid(grid_arg)
        return errors

    def run(self, arguments):
        """Override to run the tool.

        Relies on validate_arguments() having cached the input grid and UGrid.

        Args:
            arguments (list): The tool arguments.
        """
        self.logger.info(f'Filling holes in {self._geom_txt}')
        self._args = arguments
        self._get_inputs()
        self._fill_ugrid_holes()
        self._set_outputs()