#! python3
from typing import Optional

from ._xmssnap.snap import _SnapPoint
from .snap_base import _SnapBase, SnapBase
from xms.constraint import read_grid_from_file
from xms.data_objects.parameters import UGrid, Point


class SnapPoint(_SnapPoint, SnapBase):
    """This class snaps one or more points to the closest location of a geometry."""
    def __init__(self, grid: Optional[UGrid] = None, target_cells: bool = False):
        """
        Constructor.

        Args:
            grid: The grid geometry to snap points to. If None, self.set_grid must be called to assign a grid
                before snapping points will succeed.
            target_cells: Whether to snap to cell centers in the grid. If False, snapping will target nodes in the
                grid. This parameter is ignored if no grid is passed to the constructor.
        """
        _SnapPoint.__init__(self)
        SnapBase.__init__(self)
        if grid is not None:
            self.set_grid(grid, target_cells)

    @staticmethod
    def _as_xyz(point):
        """Return an (x, y, z) tuple for a data_objects Point; return any other value unchanged."""
        if isinstance(point, Point):
            return point.x, point.y, point.z
        return point

    def set_grid(self, grid, target_cells: bool = False):
        """Sets the geometry that will be snapped to.

        Args:
            grid: The grid that will be targeted. Either an xms.data_objects.parameters.UGrid, or an already-loaded
                grid object exposing an ``_instance`` attribute (such as the result of
                xms.constraint.read_grid_from_file).
            target_cells (bool): True if the snap targets cell centers, point locations if False. Defaults to
                False, matching the constructor.
        """
        if isinstance(grid, UGrid):  # data_objects UGrid
            file = grid.cogrid_file
            if file:  # New CoGrid impl: load the grid from its CoGrid file.
                co_grid = read_grid_from_file(file)
            else:  # Old C++ impl for H5 file format
                co_grid = super().get_co_grid(grid)
        else:
            # Not a data_objects UGrid; presumably already a loaded grid wrapper that exposes `_instance`.
            co_grid = grid
        _SnapBase.set_grid(self, co_grid._instance, target_cells)

    def get_snapped_point(self, point: Point | tuple[float, float, float]):
        """Gets a snapped location and id of the geometry.

        Raises AssertionError if self.set_grid has not been called yet or was last called with an empty grid.

        Args:
            point: The point location, either a data_objects Point or an (x, y, z) tuple.

        Returns:
            A dictionary with keys 'id' and 'location'. 'id' holds a list of integers representing
            point or cell ids depending on the target set. 'location' holds a list that is parallel
            to the one in 'id'. The 'location' list is made of tuples of 3 doubles representing
            the snapped locations. In this case, the list for 'id' and 'location' should be 1 in
            length.
        """
        point = self._as_xyz(point)
        if result := super().get_snapped_point(point):
            return result
        raise AssertionError('Unable to snap point. Grid may be uninitialized or empty.')

    def get_snapped_points(self, points: list[Point] | list[tuple[float, float, float]]):
        """Gets snapped locations and ids of the geometry.

        Raises AssertionError if self.set_grid has not been called yet or was last called with an empty grid.

        Args:
            points: The point locations. Each element may be a data_objects Point or an (x, y, z) tuple, and the
                two forms may be mixed. Any iterable (including a generator) is accepted.

        Returns:
            A dictionary with keys 'id' and 'location'. 'id' holds a list of integers representing
            point or cell ids depending on the target set. 'location' holds a list that is parallel
            to the one in 'id'. The 'location' list is made of tuples of 3 doubles representing
            the snapped locations. If `points` is empty, both values are empty tuples.
        """
        # Convert element by element (rather than inspecting only the first element) so that mixed sequences of
        # Points and tuples work. Materializing a list also lets one-shot iterables be checked for emptiness below.
        points = [self._as_xyz(point) for point in points]
        if result := super().get_snapped_points(points):
            return result
        if not points:
            # super().get_snapped_points returns an empty dict in three cases: there was no UGrid, the UGrid was
            # empty, or no points were provided. The first two are almost certainly errors, but the third is okay.
            # Allowing it lets code pull the points out of a coverage, snap them all, and iterate over the snapped
            # points without an extra check for whether the coverage actually had any points in the first place.
            return {'id': tuple(), 'location': tuple()}
        raise AssertionError('Unable to snap points. Grid may be uninitialized or empty.')
