"""Class for generating data sets from EWN polygons."""
__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
import logging
import os
import uuid

# 2. Third party modules
import numpy as np

# 3. Aquaveo modules
from xms.constraint import GridType
from xms.datasets.dataset_writer import DatasetWriter
from xms.grid.geometry import geometry as xmg
from xms.grid.geometry.tri_search import TriSearch
import xms.mesher
from xms.mesher import meshing

# 4. Local modules


class EwnDatasetGenerator:
    """Class for generating data sets from EWN polygons.

    Every target location (grid point, or cell centroid when ``dset_location == 'cells'``) that lies inside a
    polygon's outer loop and outside all of its holes is assigned that polygon's Manning's n value. Polygons are
    applied in ``polygon_data`` order and each match overwrites the previous value, so where polygons overlap the
    one appearing later in ``polygon_data`` wins.
    """
    # Results of xmg.point_in_polygon_2d that count as "inside or on" a hole (presumably 1 = inside, 0 = on the
    # boundary -- confirm against xms.grid docs).
    INSIDE_OR_ON = [0, 1]
    # Value returned by TriSearch.triangle_containing_point when no triangle contains the point.
    NO_TRIANGLE_IDX = -1

    def __init__(self, polygon_data, cogrid, ugrid_uuid, initial_values, dset_location='points'):
        """Initializes the helper class.

        Args:
            polygon_data (:obj:`list[dict]`): Holds the outside polygon points and polygon attributes. Each dict is
                read for the keys 'polygon_outside_pts_ccw', 'polygon_inside_pts_cw' and 'polygon_atts'.
            cogrid (:obj:`xms.grid.ugrid.UGrid`): 2d Unstructured grid
            ugrid_uuid (:obj:`str`): UUID of the target geometry
            initial_values (:obj:`list[float]`): The initial roughness data set values. This sequence is updated
                in place by :meth:`generate_roughness_dataset`.
            dset_location (:obj:`str`): Location of the dataset ('points' or 'cells')
        """
        self._logger = logging.getLogger('xms.ewn')
        self._polygon_data = polygon_data
        self._cogrid = cogrid
        self._ugrid_uuid = ugrid_uuid
        self._initial_values = initial_values
        self._dset_location = dset_location
        self._extents = []  # (min_x, max_x, min_y, max_y) per polygon, parallel to self._polygon_data
        self._tri_searches = []  # TriSearch per polygon, parallel to self._polygon_data
        self.output_dset = None
        self.out_filenames = None  # for testing

    def _get_out_filename(self):
        """Either get a hard-coded output filename (testing) or a randomly generated one in XMS temp.

        Returns:
            :obj:`str`: The last entry of ``self.out_filenames`` (removed from the list) when it is non-empty,
            otherwise a ``<uuid4>.h5`` path in the XMS temp directory (or the current directory if unset).
        """
        if self.out_filenames:
            return self.out_filenames.pop()
        else:  # pragma: no cover
            temp_dir = os.environ.get('XMS_PYTHON_APP_TEMP_DIRECTORY', os.getcwd())
            return os.path.join(temp_dir, f'{uuid.uuid4()}.h5')

    def _preprocess_polygons(self):
        """Create bounding boxes and TriSearch objects for the EWN polygons.

        Rebuilds ``self._extents`` and ``self._tri_searches`` from scratch so repeated calls do not accumulate
        stale entries.
        """
        self._logger.info('Computing extents of input EWN feature polygons...')
        self._extents = []
        self._tri_searches = []
        for poly in self._polygon_data:  # Loop through all polygons of all input coverages in priority order.
            # Compute the bounding box extents of the polygon.
            outside_loop = np.array(poly['polygon_outside_pts_ccw'])
            min_x = np.amin(outside_loop[:, 0])
            max_x = np.amax(outside_loop[:, 0])
            min_y = np.amin(outside_loop[:, 1])
            max_y = np.amax(outside_loop[:, 1])
            self._extents.append((min_x, max_x, min_y, max_y))

            # Triangulate the polygon so containment can be tested with a TriSearch. The last point is dropped,
            # presumably because the loop is closed (last point repeats the first) -- confirm with the caller.
            meshing_inputs = [meshing.PolyInput(outside_polygon=outside_loop[:-1])]
            ug = xms.mesher.generate_mesh(polygon_inputs=meshing_inputs)
            cs = ug.cellstream
            # Keep only the point indices of each 5-entry cell record (skip entries 0 and 1 of every record).
            # NOTE(review): assumes the generated mesh contains only triangles, i.e. records of the form
            # [cell_type, point_count, p0, p1, p2].
            tris = [idx for i, idx in enumerate(cs) if i % 5 != 0 and (i - 1) % 5 != 0]
            pts = [(p[0], p[1], p[2]) for p in ug.locations]
            self._tri_searches.append(TriSearch(points=pts, triangles=tris))

    def _create_dataset(self, dset_name):
        """Create the output data_objects Dataset to send back to XMS.

        Writes ``self._initial_values`` as a single time step (t=0.0) and stores the writer in
        ``self.output_dset``.

        Args:
            dset_name (:obj:`str`): Name to give the output roughness dataset
        """
        self._logger.info('Writing roughness to XMDF file...')
        writer = DatasetWriter(
            h5_filename=self._get_out_filename(),
            name=dset_name,
            geom_uuid=self._ugrid_uuid,
            location=self._dset_location
        )
        writer.write_xmdf_dataset(times=[0.0], data=[np.array(self._initial_values)])
        self.output_dset = writer

    def _validate_target(self):
        """Reject grid type / dataset location combinations that cannot hold the output dataset.

        Raises:
            RuntimeError: If a point dataset is requested on a rectilinear grid, or a non-point dataset is
                requested on a 2D UGrid.
        """
        grid_type = self._cogrid.grid_type
        if self._dset_location == 'points':
            if grid_type in [GridType.rectilinear_2d, GridType.rectilinear_3d]:
                msg = (
                    'The selected grid/mesh type does not support point based datasets. '
                    'Please change target to be cells.'
                )
                self._logger.error(msg)
                raise RuntimeError(msg)
        elif grid_type in [GridType.ugrid_2d]:
            msg = (
                'The selected grid/mesh type does not support cell based datasets.'
                ' Please change target to be points.'
            )
            self._logger.error(msg)
            raise RuntimeError(msg)

    def _target_locations(self):
        """Return the locations roughness is assigned to: cell centroids for 'cells', grid points otherwise."""
        ug = self._cogrid.ugrid
        if self._dset_location == 'cells':
            # Element [1] of get_cell_centroid's result is used as the centroid point.
            return np.array([ug.get_cell_centroid(i)[1] for i in range(ug.cell_count)])
        return ug.locations

    def _apply_polygon_values(self, locations):
        """Write each containing polygon's Manning's n into ``self._initial_values`` (later polygons win).

        Args:
            locations: Sequence of target locations, indexed in parallel with ``self._initial_values``.
        """
        coords = np.asarray(locations, dtype=float)
        if coords.size == 0:
            return
        xs = coords[:, 0]
        ys = coords[:, 1]
        for poly_idx, poly in enumerate(self._polygon_data):
            # Vectorized trivial rejection: only locations inside the polygon's bounding box get the exact test.
            min_x, max_x, min_y, max_y = self._extents[poly_idx]
            outside_bbox = (xs < min_x) | (xs > max_x) | (ys < min_y) | (ys > max_y)
            candidates = np.flatnonzero(~outside_bbox)
            if candidates.size == 0:
                continue

            tri_search = self._tri_searches[poly_idx]
            holes = poly['polygon_inside_pts_cw']
            roughness = poly['polygon_atts'].manning_n.item()
            for node_idx in candidates.tolist():
                loc = locations[node_idx]
                # Check if target location is inside the outer polygon.
                if tri_search.triangle_containing_point(loc) == self.NO_TRIANGLE_IDX:
                    continue
                # A location inside (or on) any hole does not take this polygon's roughness. Holes are passed
                # without their last point, presumably because each hole loop is closed -- confirm.
                if any(xmg.point_in_polygon_2d(hole[:-1], loc) in self.INSIDE_OR_ON for hole in holes):
                    continue
                self._initial_values[node_idx] = roughness

    def generate_roughness_dataset(self, cov_atts, dset_name):
        """Generate a roughness dataset from EWN polygons.

        Args:
            cov_atts (:obj:`dict`): Coverage UUID to its component data. Currently unused; kept for interface
                compatibility.
            dset_name (:obj:`str`): Name to give the output roughness dataset

        Raises:
            RuntimeError: If the target grid type does not support the requested dataset location.
        """
        # Fail fast, before meshing polygons or computing centroids, if the target cannot hold the dataset.
        self._validate_target()
        self._preprocess_polygons()
        self._logger.info('Applying EWN feature roughness values to target geometry. Please wait...')
        self._apply_polygon_values(self._target_locations())
        self._create_dataset(dset_name)
