"""Module for CoverageBuilder class."""

__copyright__ = "(C) Copyright Aquaveo 2024"
__license__ = "All rights reserved"
__all__ = ['CoverageComponentBuilder', 'UNASSIGNED_TYPE', 'MULTIPLE_TYPES']

# 1. Standard Python modules
from functools import cached_property
from typing import Any, cast, Optional, Sequence, Type, TypeAlias, TypeVar

# 2. Third party modules

# 3. Aquaveo modules
from xms.constraint import UGrid2d
from xms.coverage.coverage_builder import Arc, CoverageBuilder, Polygon
from xms.coverage.grid.grid_cell_to_polygon_coverage_builder import GridCellToPolygonCoverageBuilder
from xms.data_objects.parameters import Coverage, Projection
from xms.guipy.data.target_type import TargetType

# 4. Local modules
from xms.components.bases.visible_coverage_component_base import VisibleCoverageComponentBase
from xms.components.bases.visible_coverage_component_base_data import VisibleCoverageComponentBaseData
from xms.components.component_builders.main_file_maker import make_main_file
from xms.components.display.display_options_helper import DisplayOptionsHelper, MULTIPLE_TYPES, UNASSIGNED_TYPE

# Any concrete component class the builder can construct; it must be a VisibleCoverageComponentBase subclass.
ComponentType = TypeVar('ComponentType', bound=VisibleCoverageComponentBase)
# Result of CoverageComponentBuilder.build(using_xms_data=True): (coverage, component).
CoverageAndComponent: TypeAlias = tuple[Coverage, ComponentType]
# Result of CoverageComponentBuilder.build(using_xms_data=False): (coverage, component, keywords).
CoverageComponentAndKeywords: TypeAlias = tuple[Coverage, ComponentType, Any]


class CoverageComponentBuilder:
    """
    Class for building a coverage and its component.

    Features can be defined relative to a UGrid (by node or cell index) or directly by location. This should be used by
    passing the type (not an instance) of the component you want to build. The component's data manager will be exposed
    on the builder's `data` attribute so that you can add feature data to it. When you're done adding features, call
    `build()` to get the things you can put in `Query`.

    This will retrieve the component's feature list when you call the `build()` method. If generating that list
    requires data from the data manager, it should be inserted before calling `build()`, or else the component's display
    .json files will be wrong.

    All features must be added before `build()` is called. `add_polygons()` builds a whole coverage on its own, so it
    must be called before any other feature is added, and no other features may be added after it.
    """
    def __init__(
        self,
        component_type: Type[ComponentType],
        name: str,
        ugrid: Optional[UGrid2d] = None,
        projection: Optional[Projection] = None
    ):
        """
        Initialize the builder.

        Args:
            component_type: Type of the component to build. This should be a class (not an instance) that inherits from
                VisibleCoverageComponentBase.
            name: Name to assign the coverage. This will appear in the project tree.
            ugrid: Reference geometry for building features from. Required by add_node(), add_nodes(),
                add_node_string(), add_cell_string(), and add_polygons().
            projection: Projection to assign the newly built coverage.
        """
        self._builder = CoverageBuilder(name=name, projection=projection)

        # Parallel lists: the feature at index i of a *_feature_ids list gets the component ID at index i of the
        # matching *_component_ids list.
        self._point_feature_ids: list[int] = []
        self._point_component_ids: list[int] = []
        self._arc_feature_ids: list[int] = []
        self._arc_component_ids: list[int] = []
        self._polygon_feature_ids: list[int] = []
        self._polygon_component_ids: list[int] = []

        # Set by add_polygons() or _build_coverage(). Once set, self._builder will never be built.
        self._coverage: Optional[Coverage] = None
        self._name: str = name
        self._projection = projection

        # Node indexes already passed to add_node(), used to reject duplicates.
        self._used_point_indexes: set[int] = set()

        # NOTE(review): not referenced anywhere in this class.
        self._material_coverage: Optional[Coverage] = None
        main_file = make_main_file(component_type)
        self._component: VisibleCoverageComponentBase = component_type(main_file)

        self._ugrid = ugrid
        self._display_labels = True

    @property
    def data(self) -> VisibleCoverageComponentBaseData:
        """The component's data manager."""
        return self._component.data

    def add_nodes(self, node_indexes: list[int], component_ids: list[int]):
        """
        Add points to the coverage based on node indexes in the UGrid.

        Args:
            node_indexes: Indexes of the node to add. These are their indexes in the UGrid's point list.
            component_ids: Component IDs to assign the points. Must be the same length as node_indexes.
        """
        for node_index, component_id in zip(node_indexes, component_ids, strict=True):
            self.add_node(node_index, component_id)

    def add_node(self, node_index: int, component_id: int) -> int:
        """
        Add a point to the coverage based on a node index in the UGrid.

        Args:
            node_index: Index of the node to add. This is its index in the UGrid's point list.
            component_id: Component ID to assign the point.

        Returns:
            The feature ID assigned to the point.

        Raises:
            AssertionError: If the node was already added, no UGrid was provided, or features can no longer be added.
        """
        self._ensure_accepting_features('add_node')
        self._require_ugrid('add_node')
        node_index, component_id = _coerce(node_index, component_id)

        if node_index in self._used_point_indexes:
            raise AssertionError(f'Point at index {node_index} added twice')

        x, y, z = self._locations[node_index]
        feature_id = self._builder.add_point((x, y, z))

        # Record the node only once the point exists, so a failed lookup doesn't leave it marked as used.
        self._used_point_indexes.add(node_index)
        self._point_feature_ids.append(feature_id)
        self._point_component_ids.append(component_id)

        return feature_id

    def add_point(self, x: float, y: float, z: float, component_id: int) -> int:
        """
        Add a point to the coverage based on a location.

        Args:
            x: The x coordinate of the point.
            y: The y coordinate of the point.
            z: The z coordinate of the point.
            component_id: Component ID to assign the point.

        Returns:
            The feature ID assigned to the point.

        Raises:
            AssertionError: If features can no longer be added.
        """
        self._ensure_accepting_features('add_point')
        component_id, = _coerce(component_id)

        feature_id = self._builder.add_point((x, y, z))

        # Append both IDs together so the parallel lists can't get out of step if add_point() raises.
        self._point_feature_ids.append(feature_id)
        self._point_component_ids.append(component_id)
        return feature_id

    def add_node_string(self, node_indexes: list[int], component_id: int) -> int:
        """
        Add an arc to the coverage based on node indexes in the UGrid.

        Args:
            node_indexes: Indexes of nodes defining the arc.
            component_id: Component ID to assign the arc.

        Returns:
            The feature ID assigned to the arc.

        Raises:
            AssertionError: If no UGrid was provided or features can no longer be added.
        """
        self._ensure_accepting_features('add_node_string')
        self._require_ugrid('add_node_string')
        component_id, = _coerce(component_id)

        arc = [self._locations[index] for index in node_indexes]
        feature_id = self._builder.add_arc(arc)

        self._arc_feature_ids.append(feature_id)
        self._arc_component_ids.append(component_id)
        return feature_id

    def add_point_string(self, points_xyz: Arc, component_id: int) -> int:
        """
        Add an arc to the coverage based on x, y, z point coordinates.

        Args:
            points_xyz: List of (x, y, z) coordinates defining the arc.
            component_id: Component ID to assign the arc.

        Returns:
            The feature ID assigned to the arc.

        Raises:
            AssertionError: If features can no longer be added.
        """
        self._ensure_accepting_features('add_point_string')
        component_id, = _coerce(component_id)

        feature_id = self._builder.add_arc(points_xyz)

        self._arc_feature_ids.append(feature_id)
        self._arc_component_ids.append(component_id)
        return feature_id

    def add_cell_string(self, cell_indexes: list[int], component_id: int) -> int:
        """
        Add an arc to the coverage based on cell indexes in the UGrid.

        Effectively the same as add_node_string, except it targets cell centers.

        Args:
            cell_indexes: Indexes of cells defining the arc.
            component_id: Component ID to assign the arc.

        Returns:
            The feature ID assigned to the arc.

        Raises:
            AssertionError: If no UGrid was provided or features can no longer be added.
        """
        self._ensure_accepting_features('add_cell_string')
        self._require_ugrid('add_cell_string')
        component_id, = _coerce(component_id)

        arc = [self._cell_locations[index] for index in cell_indexes]
        feature_id = self._builder.add_arc(arc)

        self._arc_feature_ids.append(feature_id)
        self._arc_component_ids.append(component_id)
        return feature_id

    def add_polygon(self, polygon: Polygon, component_id: int) -> int:
        """
        Add a polygon to the coverage based on locations.

        Args:
            polygon: The polygon to add.
            component_id: Component ID to assign the polygon.

        Returns:
            The feature ID assigned to the polygon.

        Raises:
            AssertionError: If features can no longer be added.
        """
        self._ensure_accepting_features('add_polygon')
        component_id, = _coerce(component_id)

        feature_id = self._builder.add_polygon(polygon)

        self._polygon_feature_ids.append(feature_id)
        self._polygon_component_ids.append(component_id)
        return feature_id

    def add_polygons(
        self,
        dataset_values: Sequence[int],
        dataset_value_to_component_id: dict[int, int],
        null_value: Optional[int] = None
    ):
        """
        Add polygons to the coverage.

        This builds the whole coverage at once, so it must be called before any other feature is added, and no other
        features may be added afterwards.

        Args:
            dataset_values: Values identifying which cells in the UGrid belong together. Each cell in the UGrid will be
                assigned the value from the corresponding index in dataset_values, then adjacent cells with the same
                value will be merged into a single polygon. Which cells are left out of every polygon is controlled by
                null_value. Disjoint groups of cells with the same value will result in disjoint polygons with the same
                component ID (but unique feature IDs).
            dataset_value_to_component_id: Dictionary mapping values in dataset_values to the component ID that should
                be assigned to any polygon with that value. This can have extra keys that are not present in
                dataset_values; they will be ignored.
            null_value: Value in dataset_values to treat as not part of any polygon. If None, all cells will be part of
                a polygon. Otherwise, any cell with this dataset value will be excluded.

        Raises:
            AssertionError: If any features were already added, the coverage was already built, or no UGrid was
                provided.
        """
        if self._has_features():
            # This works by creating a whole new coverage with the polygons in it. It's possible in theory to build the
            # polygons and insert them into the existing coverage, but GridCellToPolygonCoverageBuilder's API doesn't
            # support that and I'm not sure the best way to implement it. An earlier attempt in xmsgmi tried pulling the
            # polygons out of the coverage the builder builds, but doing that lost the holes.
            #
            # Chances are we'll never actually hit this in practice, but if we do, someone will probably appreciate not
            # having to spend an afternoon figuring out where their features disappeared to.
            raise AssertionError('add_polygons() called while features exist.')
        self._require_ugrid('add_polygons')

        dataset_values = cast(list[int], dataset_values)  # Sequence and list are close enough, but typing complains.

        polygon_merger = GridCellToPolygonCoverageBuilder(
            co_grid=self._ugrid,
            dataset_values=dataset_values,
            projection=self._projection,
            coverage_name=self._name,
            null_value=null_value,
        )
        self._coverage = polygon_merger.create_polygons_and_build_coverage()

        # Invert the merger's dataset value -> [feature IDs] mapping so each built polygon can find its value.
        polygon_ids = polygon_merger.dataset_polygon_ids
        feature_id_to_dataset_value = {
            feature_id: dataset_value
            for dataset_value in polygon_ids
            for feature_id in polygon_ids[dataset_value]
        }

        for polygon in self._coverage.polygons:
            dataset_value = feature_id_to_dataset_value[polygon.id]
            # Coerce so values pulled from numpy/xarray arrays reach XMS as plain ints (see _coerce()).
            feature_id, component_id = _coerce(polygon.id, dataset_value_to_component_id[dataset_value])
            self._polygon_feature_ids.append(feature_id)
            self._polygon_component_ids.append(component_id)

    def turn_off_labels(self):
        """Turn off labels for features."""
        self._display_labels = False

    @cached_property
    def _locations(self) -> Optional[Sequence[Sequence[float]]]:
        """Locations of the UGrid's points, indexed by node index, or None if there is no UGrid."""
        return self._ugrid.ugrid.locations if self._ugrid is not None else None

    @cached_property
    def _cell_locations(self) -> Optional[Sequence[Sequence[float]]]:
        """Centroids of the UGrid's cells, indexed by cell index, or None if there is no UGrid."""
        if self._ugrid is None:
            return None

        ugrid = self._ugrid.ugrid
        # Only element [1] of get_cell_centroid()'s result (the location) is used here.
        return [ugrid.get_cell_centroid(i)[1] for i in range(ugrid.cell_count)]

    def _has_features(self) -> bool:
        """Whether any feature has been added, or a coverage has already been built."""
        have_ids = self._point_feature_ids or self._arc_feature_ids or self._polygon_feature_ids
        return bool(have_ids) or self._coverage is not None

    def _ensure_accepting_features(self, method_name: str):
        """
        Raise if individual features can no longer be added.

        Once add_polygons() or build() has run, self._coverage is set and _build_coverage() will not build
        self._builder again. Anything added to the builder afterwards would be silently dropped from the coverage
        while its IDs were still sent to XMS.

        Args:
            method_name: Name of the calling method, used in the error message.
        """
        if self._coverage is not None:
            raise AssertionError(f'{method_name}() called after add_polygons() or build().')

    def _require_ugrid(self, method_name: str):
        """
        Raise if a UGrid-relative method is called on a builder that was not given a UGrid.

        Args:
            method_name: Name of the calling method, used in the error message.
        """
        if self._ugrid is None:
            raise AssertionError(f'{method_name}() requires a UGrid, but the builder was not given one.')

    def _build_coverage(self) -> Coverage:
        """
        Mark the coverage as complete.

        This is an older API. Prefer finish() instead.

        Returns:
            The completed coverage.
        """
        # Adding polygons for a material coverage builds the coverage, so only build if that hasn't happened.
        if self._coverage is None:
            self._coverage = self._builder.build()
        return self._coverage

    def _build_component(self) -> list[dict[str, Any]]:
        """
        Finalize setting up the component.

        Must be called after _build_coverage(), since it reads the built coverage's UUID.

        Returns:
            A one-element list holding the keywords dict ('component_coverage_ids' and 'display_options').
        """
        self._component.cov_uuid = self._coverage.uuid
        self._write_ids()
        ids = self._component.get_component_coverage_ids()
        with DisplayOptionsHelper(self._component.main_file) as helper:
            helper.initialize_display_options(
                self._component.uuid, self._component.module_name, self._component.features,
                self._component.default_features
            )
            if not self._display_labels:
                helper.turn_off_labels()
            messages = helper.get_update_messages(self._component.cov_uuid)

        self._component.refresh_display_id_files()
        self.data.commit()

        keywords = {
            'component_coverage_ids': ids,
            'display_options': [list(message) for message in messages],
        }

        return [keywords]

    def build(self, using_xms_data: bool = False) -> CoverageComponentAndKeywords | CoverageAndComponent:
        """
        Finish building the coverage.

        The returned keywords should be passed to Query.add_coverage. They are an implementation detail. It is the
        builder's responsibility to ensure they are valid, so user code should generally just treat them as an opaque
        object and not worry about what they actually are. Neglecting to send the keywords may result in the component's
        point and component IDs being uninitialized.

        The keywords are mainly intended for models using custom `XmsData` classes or plain `Query`. If you're using the
        `XmsData` in `xms.components`, then you don't need the keywords and can pass `using_xms_data=True` to omit them.

        No features may be added after this is called.

        Args:
            using_xms_data: Whether the result is destined for `xms.components.dmi.XmsData`, in which case the keywords
                are omitted from the result.

        Returns:
            If `using_xms_data is False`, a tuple of (coverage, component, keywords). These should all be passed to
            `Query.add_coverage()`, either directly or indirectly via a custom `XmsData`.

            If `using_xms_data is True`, a tuple of (coverage, component). These should be passed to
            `xms.components.dmi.XmsData.add_coverage()`.
        """
        coverage = self._build_coverage()
        keywords = self._build_component()
        if using_xms_data:
            return coverage, self._component
        else:
            return coverage, self._component, keywords

    def _write_ids(self):
        """Write out the ID files."""
        # Processed in point, arc, polygon order; groups without features are skipped.
        feature_groups = (
            (TargetType.point, self._point_feature_ids, self._point_component_ids),
            (TargetType.arc, self._arc_feature_ids, self._arc_component_ids),
            (TargetType.polygon, self._polygon_feature_ids, self._polygon_component_ids),
        )
        with DisplayOptionsHelper(self._component.main_file) as helper:
            # If the component is built from an import script, then XMS will ignore the keywords we send and run the
            # component's create event to get the ID mapping. To enable this, we must use the helper to request updates
            # so the component knows what mapping to send back to XMS. If the component is built from an ActionRequest
            # callback, however, then XMS will ignore the create event instead and insist on using the keywords.
            #
            # It would be nice if XMS would do the same thing in both cases, but it doesn't, so we'll just initialize
            # the component both ways and let XMS pick whichever one it feels like today.
            for target_type, feature_ids, component_ids in feature_groups:
                if not feature_ids:
                    continue
                helper.request_feature_update(target_type, component_ids, feature_ids)
                for feature_id, component_id in zip(feature_ids, component_ids, strict=True):
                    self._component.update_component_id(target_type, feature_id, component_id)


def _coerce(*values) -> tuple[int, ...]:
    """
    Coerce a value to an integer.

    It's easy to trip up and pass feature and component IDs with bad types, especially if you're pulling them from
    xarray datasets. xarray just loves wrapping primitive types in its own special types. XMS silently fails to
    associate the point and component IDs if we send non-ints to it though, and it takes forever to figure that out.
    It's even worse if you don't make a lucky guess that side-steps the need to get data_objects built with debug
    symbols. We could die here and make the user-code fix the problem, but it'll probably happen enough that it's best
    to just fix it in one place here.

    This function is trivial enough it could be inlined, but using it makes it easier to centralize its documentation.
    """
    return tuple(int(value) for value in values)
