"""ExtractSubgridTool class."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
import uuid

# 2. Third party modules
import numpy as np

# 3. Aquaveo modules
from xms.constraint import GridType
from xms.constraint import UGrid2d
from xms.constraint.ugrid_activity import CellToPointActivityCalculator
from xms.constraint.ugrid_builder import UGridBuilder
from xms.grid.ugrid import UGrid
from xms.tool_core import IoDirection, Tool

# 4. Local modules
from xms.tool.algorithms.geometry.geometry import run_parallel_points_in_polygon
from xms.tool.utilities.coverage_conversion import parallel_polygon_perimeters

# Location categories for datasets copied from the input grid. A dataset whose
# value count matches neither the grid's point count nor its cell count is
# left as UNDEFINED.
DATASET_TYPE_UNDEFINED, DATASET_TYPE_CELLS, DATASET_TYPE_POINTS = -1, 0, 1


class ExtractSubgridTool(Tool):
    """Tool to extract a subgrid from a 2D grid.

    Cells whose centroids fall inside the polygons of a boundary coverage (and
    not inside any hole of those polygons) are copied, together with the points
    they reference, into a new grid with compact, order-preserving IDs. The
    input grid's point and cell datasets are then copied onto the new grid,
    except a scalar point dataset that just repeats the grid's point elevations.
    """
    ARG_INPUT_GRID = 0
    ARG_INPUT_COVERAGE = 1
    ARG_OUTPUT_GRID = 2

    def __init__(self, name='Extract Subgrid'):
        """Initializes the class.

        Args:
            name (str): The tool name.
        """
        super().__init__(name=name)
        self._args = []
        self._sel_cov = None
        self._orig_grid = None
        self._orig_ug = None
        self._orig_sel_elem_ids = set()
        self._dict_elems = dict()  # map orig elem IDs to new elem Ids
        self._dict_nodes = dict()  # map orig node IDs to new node IDs
        self._extracted_grid = None
        self._extracted_grid_uuid = None
        self._force_ugrid = True
        self._geom_txt = 'UGrid'

    def initial_arguments(self):
        """Get initial arguments for tool.

        Must override.

        Returns:
            (list): A list of the initial tool arguments.
        """
        arguments = [
            self.grid_argument(name='input_grid', description='Grid', io_direction=IoDirection.INPUT),
            self.coverage_argument(name='input_coverage', description='Subgrid boundary coverage'),
            self.grid_argument(name='grid', description='Grid name', optional=True,
                               io_direction=IoDirection.OUTPUT),
        ]
        return arguments

    def _get_selected_elements(self):
        """Collect the input grid cells whose centroids lie inside the coverage polygons.

        A cell is selected when its centroid tests as inside (result == 1) a polygon's
        outer boundary and tests as 0 against every hole of that polygon. Selected cell
        indices are added to ``self._orig_sel_elem_ids``.

        Raises:
            RuntimeError: If no boundary coverage was supplied, or no cells are selected.
        """
        if self._sel_cov is None:
            raise RuntimeError('No subgrid boundary coverage was specified. Aborting.')

        all_polys = parallel_polygon_perimeters(self._sel_cov)
        # get_cell_centroid returns a pair; its second item is the centroid location (only x, y are tested).
        centroids = [self._orig_ug.get_cell_centroid(cell_idx)[1] for cell_idx in range(self._orig_ug.cell_count)]
        centroid_locs = np.asarray([(loc[0], loc[1]) for loc in centroids])

        def _test_centroids(polygon):
            # Materialize first: the helper's return type is not visible here, so avoid handing numpy an iterator.
            return np.asarray(list(run_parallel_points_in_polygon(centroid_locs, polygon)))

        for poly in all_polys.values():
            # poly[0] is the outer boundary; any remaining entries are holes.
            selected = _test_centroids(poly[0]) == 1
            for hole in poly[1:]:
                # Any nonzero result against a hole excludes the cell.
                selected &= _test_centroids(hole) == 0
            # tolist() keeps plain Python ints as the cell IDs.
            self._orig_sel_elem_ids.update(np.flatnonzero(selected).tolist())

        if not self._orig_sel_elem_ids:
            raise RuntimeError(f'No {self._geom_txt} cells/elements are inside of coverage polygons. Aborting.')

    def _renumber_extracted(self):
        """Map the selected cells and the points they use to new, compact IDs.

        Relative order is preserved: the i-th smallest original ID maps to new ID i,
        for points (``self._dict_nodes``) and cells (``self._dict_elems``) alike.
        """
        orig_nodes = set()
        for cell_idx in self._orig_sel_elem_ids:
            orig_nodes.update(self._orig_ug.get_cell_points(cell_idx))

        self._dict_nodes = {orig_id: new_id for new_id, orig_id in enumerate(sorted(orig_nodes))}
        self._dict_elems = {orig_id: new_id for new_id, orig_id in enumerate(sorted(self._orig_sel_elem_ids))}

    def _create_extracted_grid(self):
        """Build the extracted grid from the renumbered points/cells and register it as the tool output."""
        # Point locations in new-ID order (new IDs were assigned in sorted original-ID order).
        locations = list(self._orig_ug.locations)
        extracted_locations = [locations[orig_node] for orig_node in sorted(self._dict_nodes)]

        # Copy each cell's type and point count, remapping its point IDs to the new numbering.
        extracted_cellstream = []
        for orig_cell in sorted(self._dict_elems):
            cell_type, num_points, *cell_points = self._orig_ug.get_cell_cellstream(orig_cell)[1]
            extracted_cellstream.extend((cell_type, num_points))
            extracted_cellstream.extend(self._dict_nodes[point] for point in cell_points)

        geom_label = 'UGrid' if self._force_ugrid else 'Mesh'
        self.logger.info(f'Building Extracted {geom_label}...')
        ug = UGrid(extracted_locations, extracted_cellstream)
        if self._force_ugrid:
            co_builder = UGridBuilder()
            co_builder.set_is_2d()
            # Keep the input grid's unconstrained type; reuse the grid already loaded in run().
            if self._orig_grid.grid_type == GridType.unconstrained:
                co_builder.set_unconstrained()
            co_builder.set_ugrid(ug)
            self._extracted_grid = co_builder.build_grid()
        else:
            self._extracted_grid = UGrid2d(ug)

        self._extracted_grid_uuid = str(uuid.uuid4())
        self._extracted_grid.uuid = self._extracted_grid_uuid

        self.set_output_grid(self._extracted_grid, self._args[self.ARG_OUTPUT_GRID], force_ugrid=self._force_ugrid)

    def _create_extracted_datasets(self):
        """Copy the input grid's datasets onto the extracted grid.

        Point datasets are sized by point count and cell datasets by cell count. Three kinds
        of dataset are not copied:

        - A dataset sized to neither count. It is skipped and a warning is logged.
        - A scalar point dataset whose first time step equals the grid's point elevations.
        - A dataset whose name starts with "<input grid name>/" is copied, but with that
          leading prefix removed from its name.
        """
        grid_prefix = f'{self._args[self.ARG_INPUT_GRID].text_value}/'
        # Original IDs in new-ID order, computed once rather than per dataset and time step.
        node_idxs = np.asarray(sorted(self._dict_nodes), dtype=int)
        elem_idxs = np.asarray(sorted(self._dict_elems), dtype=int)

        orig_datasets = self.get_grid_datasets(self._args[self.ARG_INPUT_GRID].value)
        for dset_name in orig_datasets:
            reader = self.get_input_dataset(dset_name)

            if reader.num_values == self._orig_ug.point_count:
                dataset_type = DATASET_TYPE_POINTS
                reader.activity_calculator = CellToPointActivityCalculator(self._orig_ug)
                extract_idxs = node_idxs
            elif reader.num_values == self._orig_ug.cell_count:
                dataset_type = DATASET_TYPE_CELLS
                extract_idxs = elem_idxs
            else:
                # Previously such datasets were written as empty cell datasets.
                self.logger.warning(
                    f'Skipping "{dset_name}": its size matches neither the {self._geom_txt} point nor cell count.'
                )
                continue

            # Elevations are per point, so only scalar point datasets with data can duplicate them.
            if dataset_type == DATASET_TYPE_POINTS and reader.num_components == 1 and len(reader.times) > 0:
                values = reader.timestep_with_activity(0)[0]
                elevs = np.fromiter(self._orig_grid.point_elevations, dtype=values.dtype)
                if np.array_equal(values, elevs):
                    # don't make a copy of the elevation dataset
                    continue

            # Strip only the leading "<grid>/" prefix; str.replace would also remove matches deeper in the path.
            new_dataset_name = dset_name[len(grid_prefix):] if dset_name.startswith(grid_prefix) else dset_name
            self.logger.info(f'Copying "{dset_name}" to sub{self._geom_txt}...')

            writer = self.get_output_dataset_writer(
                name=new_dataset_name,
                geom_uuid=self._extracted_grid_uuid,
                num_components=reader.num_components,
                ref_time=reader.ref_time,
                time_units=reader.time_units,
                null_value=reader.null_value,
                location='points' if dataset_type == DATASET_TYPE_POINTS else 'cells'
            )

            for tsidx, ts_time in enumerate(reader.times):
                orig_data, orig_activity = reader.timestep_with_activity(tsidx, nan_activity=True)
                # Fancy indexing on axis 0 keeps all components of each selected point/cell.
                extracted_data = np.asarray(orig_data)[extract_idxs]
                extracted_activity = None if orig_activity is None else np.asarray(orig_activity)[extract_idxs]
                writer.append_timestep(ts_time, extracted_data, extracted_activity)

            writer.appending_finished()
            self.set_output_dataset(writer)

    def run(self, arguments):
        """Run the tool: select cells, renumber them, build the subgrid, and copy datasets.

        Args:
            arguments (list): The tool arguments.
        """
        self._args = arguments
        if self._args[self.ARG_INPUT_COVERAGE].text_value:
            self._sel_cov = self.get_input_coverage(self._args[self.ARG_INPUT_COVERAGE].value)

        self._orig_grid = self.get_input_grid(self._args[self.ARG_INPUT_GRID].text_value)
        self._orig_ug = self._orig_grid.ugrid

        self._get_selected_elements()
        self._renumber_extracted()
        self._create_extracted_grid()
        self._create_extracted_datasets()
