"""Module for the MapRoughnessRunner class."""

__copyright__ = "(C) Copyright Aquaveo 2024"
__license__ = "All rights reserved"
__all__ = ['MapRoughnessRunner']

# 1. Standard Python modules
from dataclasses import dataclass
import math
from typing import cast, Iterable, Sequence

# 3. Aquaveo modules
from xms.components.component_builders.coverage_component_builder import CoverageComponentBuilder
from xms.coverage.polygons.polygon_orienteer import get_polygon_point_lists
from xms.data_objects.parameters import Coverage
from xms.datasets.dataset_reader import DatasetReader
from xms.gmi.data.generic_model import GenericModel
from xms.grid.geometry.geometry import point_in_polygon_2d
from xms.guipy.data.target_type import TargetType
from xms.guipy.dialogs.feedback_thread import ExpectedError, FeedbackThread
from xms.interp.interpolate import InterpLinear

# 4. Local modules
from xms.hydroas.components.coverage_component import CoverageComponent
from xms.hydroas.components.roughness_component import get_section
from xms.hydroas.data.coverage_data import CoverageData
from xms.hydroas.dmi.xms_data import XmsData

# 2. Third party modules


@dataclass
class PointData:
    """
    Record describing one UGrid point that will be written to the applied coverage.

    Fields are filled in stage by stage as the mapping runner progresses.
    """
    #: Feature ID of the polygon the point was found in.
    polygon_feature_id: int = -1
    #: Component ID of the polygon the point was found in.
    polygon_component_id: int = -1
    #: Index of the point within the UGrid.
    point_index: int = -1
    #: UUID of the elevation dataset to interpolate the point's elevation from.
    dataset_uuid: str = ''
    #: Roughness value assigned to the point.
    roughness: float = 0.0
    #: Elevation interpolated to the point.
    elevation: float = 0.0


class MapRoughnessRunner(FeedbackThread):
    """
    Roughness coverage mapping thread.

    Finds the UGrid points inside each polygon of the roughness coverage and interpolates an elevation to each one.
    It then builds a new "(applied)" coverage with one point per UGrid point, carrying that elevation and the
    polygon's roughness. The new coverage is added and linked.
    """
    def __init__(self, data: XmsData, model: GenericModel, roughness_uuid: str):
        """
        Construct the worker.

        Args:
            data: Object used to read the roughness coverage, UGrid, datasets, and geometries, and to add and link
                the resulting coverage.
            model: Generic model. A copy of its point parameters defines the attributes of the created points.
            roughness_uuid: UUID that is unlinked before mapping starts (presumably the roughness coverage being
                mapped -- confirm against the caller).
        """
        super().__init__(create_query=False)
        self.display_text |= {
            'title': 'Applying roughness coverage',
            'working_prompt': 'Applying roughness coverage. Please wait...',
        }
        self.data = data
        self.model = model
        self.roughness_uuid = roughness_uuid
        self._points: list[PointData] = []
        # Dataset UUID -> dataset, so each dataset is looked up only once.
        self._datasets: dict[str, DatasetReader] = {}
        # (geometry UUID, data location) -> interpolator, shared by all datasets on that geometry/location.
        self._interpolators: dict[tuple[str, str], InterpLinear] = {}
        self._locations: Sequence[tuple[float, float, float]] = []
        self._point_indexes: list[int] = []
        self._point_component_ids: list[int] = []

    def _run(self):
        """Run the thread."""
        # NOTE(review): This unlink happens before validation, so it takes effect even if mapping then fails.
        self.data.unlink(self.roughness_uuid)

        self._ensure_valid()
        self._locations = self.data.ugrid.ugrid.locations
        self._get_points()
        self._get_component_ids()
        self._get_roughness_and_elevation()
        self._set_up_interpolators()
        self._interpolate_elevations()
        self._build_component()

    def _ensure_valid(self):
        """
        Check for errors that will prevent mapping.

        Raises ExpectedError if anything is wrong, so the caller can just call it and continue.
        """
        if not self.data.ugrid:
            raise ExpectedError('Mapping a roughness coverage requires a linked UGrid.')

        coverage, component = self.data.roughness_coverage

        if not coverage.polygons:
            raise ExpectedError('Coverage has no polygons. Create and assign polygons to map the coverage.')
        # Tolerate a coverage with no entry in the ID map at all, the same way a missing polygon entry is tolerated.
        assigned_polygons = component.comp_to_xms.get(component.cov_uuid, {}).get(TargetType.polygon, {})
        if not assigned_polygons:
            raise ExpectedError(
                'Coverage has no assigned polygons. '
                'Assign feature attributes to at least one polygon to map the coverage.'
            )

    def _get_points(self):
        """
        Get all the points in the UGrid that should be put into the coverage.

        Points will have their index into the UGrid, and the feature ID of the polygon containing them.
        Requires `self._locations` to have been set (done by `_run`).
        """
        roughness_coverage, _ = self.data.roughness_coverage
        for feature_id, polygon in polygons(roughness_coverage):
            # NOTE(review): Points on an edge shared by two polygons are inside or on both, so they are added once
            # per polygon. Confirm that duplicate nodes are acceptable downstream.
            for point_id in _find_points_in_polygon(self._locations, polygon):
                self._points.append(PointData(polygon_feature_id=feature_id, point_index=point_id))

    def _get_component_ids(self):
        """
        Assign polygon component IDs to each point.

        The polygon component ID can be used to look up the attributes of the polygon that a point came from.
        """
        _, roughness_component = self.data.roughness_coverage
        for point_data in self._points:
            component_id = roughness_component.get_comp_id(TargetType.polygon, point_data.polygon_feature_id)
            point_data.polygon_component_id = component_id

    def _get_roughness_and_elevation(self):
        """Assign the roughness value and elevation dataset UUID to each point."""
        _, roughness_component = self.data.roughness_coverage
        data = roughness_component.data
        section = get_section()
        for point in self._points:
            # Load the containing polygon's attributes into the shared section, then read them back out.
            values = data.feature_values(TargetType.polygon, point.polygon_component_id)
            section.restore_values(values)

            parameter = section.group('bridge').parameter('roughness')
            point.roughness = parameter.value

            parameter = section.group('bridge').parameter('elevation')
            point.dataset_uuid = parameter.value

    def _get_dataset(self, dataset_uuid: str):
        """
        Get a dataset by UUID, caching it in `self._datasets`.

        Args:
            dataset_uuid: UUID of the dataset to get.

        Returns:
            The dataset. Returns None if `self.data.datasets` raised KeyError for the UUID. Otherwise returns
            whatever the lookup returned, which may itself be falsy.
        """
        dataset = self._datasets.get(dataset_uuid)
        if dataset is not None:
            return dataset
        try:
            dataset = self.data.datasets[dataset_uuid]
        except KeyError:
            # An undefined or deleted dataset must be reported as an ExpectedError by the caller, not crash.
            return None
        if dataset:
            self._datasets[dataset_uuid] = dataset
        return dataset

    def _set_up_interpolators(self):
        """
        Set up interpolators for all the geometries.

        Raises ExpectedError if any point's elevation dataset is missing or undefined.
        """
        for point in self._points:
            dataset = self._get_dataset(point.dataset_uuid)
            if not dataset:
                raise ExpectedError(f'Missing/undefined dataset for polygon with feature ID {point.polygon_feature_id}')
            self._set_up_interpolator(dataset.geom_uuid, dataset.location)

    def _set_up_interpolator(self, geom_uuid: str, location: str):
        """
        Set up an interpolator for a geometry.

        Does nothing if an interpolator for this geometry and location already exists.

        Args:
            geom_uuid: UUID of the geometry to set up the interpolator for.
            location: Where the data applies to. 'points' or 'cells'.
        """
        if (geom_uuid, location) in self._interpolators:
            return
        geometry = self.data.geometries[geom_uuid]
        if location == 'cells':
            # Cell data is interpolated from the cell centroids.
            pts = [geometry.ugrid.get_cell_centroid(cell)[1] for cell in range(geometry.ugrid.cell_count)]
        else:
            pts = geometry.ugrid.locations
        self._interpolators[(geom_uuid, location)] = InterpLinear(points=pts)

    def _interpolate_elevations(self):
        """
        Use interpolators to find the elevations of each point in self._points.

        Elevations that interpolate to NaN are replaced with -999.0.
        """
        # Interpolator key -> UUID of the dataset whose values are currently loaded in it. Several datasets can
        # share one interpolator, so scalars are reloaded only when the dataset changes. This avoids re-reading
        # the dataset values once per point.
        loaded: dict[tuple[str, str], str] = {}
        for point in self._points:
            dataset = self._get_dataset(point.dataset_uuid)
            key = (dataset.geom_uuid, dataset.location)
            interpolator = self._interpolators[key]
            if loaded.get(key) != point.dataset_uuid:
                interpolator.scalars = dataset.values[0]  # Index 0 only (presumably the first time step).
                loaded[key] = point.dataset_uuid
            location = tuple(self._locations[point.point_index])
            point.elevation = interpolator.interpolate_to_point(location)
            if math.isnan(point.elevation):
                point.elevation = -999.0

    def _build_component(self):
        """Build the applied coverage and component from the collected points, then add and link it."""
        # Work on a copy so the model's own point parameters are not modified.
        section = self.model.copy().point_parameters
        bridge_group = section.group('1')
        bridge_group.is_active = True

        point_values = []

        for point in self._points:
            bridge_group.parameter('1').value = point.elevation
            # The roughness is only stored if the model defines a second parameter.
            if '2' in bridge_group.parameter_names:
                bridge_group.parameter('2').value = point.roughness
            self._point_indexes.append(point.point_index)
            point_values.append(section.extract_values())

        point_groups = ['1'] * len(point_values)

        source_coverage, _ = self.data.roughness_coverage
        name = f'{source_coverage.name} (applied)'

        builder = CoverageComponentBuilder(CoverageComponent, name, self.data.ugrid)
        data = cast(CoverageData, builder.data)
        data.generic_model = self.model
        self._point_component_ids = data.add_features(TargetType.point, point_values, point_groups)

        builder.add_nodes(self._point_indexes, self._point_component_ids)
        coverage, component, keywords = builder.build()

        self.data.add_coverage(coverage, component, keywords=keywords)
        self.data.link(coverage.uuid)


def _find_points_in_polygon(locations: Sequence[tuple[float, float, float]], polygon: list[tuple[float, float, float]]):
    """
    Find all the points in a list of locations that are inside or on a polygon.

    This is a 2D operation. The third dimension is ignored.

    Args:
        locations: Sequence of tuples of (x, y, z) locations from a UGrid.
        polygon: Sequence of tuples of (x, y, z) locations defining a polygon.

    Returns:
        Elements of `locations` that are also within the bounds of the polygon.
    """
    points_in_polygon = set()

    for index, point in enumerate(locations):
        if point_in_polygon_2d(polygon, point) >= 0:
            points_in_polygon.add(index)

    return points_in_polygon


def polygons(coverage: Coverage) -> Iterable[tuple[int, list[tuple[float, float, float]]]]:
    """
    Yield each polygon's feature ID together with a list of its vertices.

    The vertices are the first point list from `get_polygon_point_lists` (presumably the outer boundary -- confirm).
    The closing duplicate vertex at the end is removed.

    Args:
        coverage: Coverage whose polygons should be iterated.

    Yields:
        Tuples of (polygon feature ID, list of (x, y, z) vertices).
    """
    for polygon in coverage.polygons:
        vertices = get_polygon_point_lists(polygon)[0]
        # The helper closes the ring by repeating the first vertex at the end; the consumer doesn't want that.
        del vertices[-1]
        yield polygon.id, vertices
