"""UgridMapper class."""

__copyright__ = "(C) Copyright Aquaveo 2022"
__license__ = "All rights reserved"

# 1. Standard Python modules
import collections
import logging
import operator
from statistics import mean

# 2. Third party modules
import numpy as np
from rtree import index

# 3. Aquaveo modules
from xms.constraint import UGrid2dFromUGrid3dCreator
from xms.data_objects.parameters import Arc, FilterLocation, Point, Polygon
from xms.grid.geometry import geometry
from xms.grid.ugrid import UGrid
from xms.snap.snap_exterior_arc import SnapExteriorArc
from xms.snap.snap_interior_arc import SnapInteriorArc

# 4. Local modules
from xms.coverage.grid import map_util
from xms.coverage.polygons import polygon_orienteer

# Type aliases
Locations = list[tuple[float, float]]  # list of (x, y) locations


def is_non_string_sequence(obj):
    """Returns True if the object is a list, tuple, or similar sequence, and False if it is a string or anything else.

    See https://stackoverflow.com/questions/1835018/how-to-check-if-an-object-is-a-list-or-tuple-but-not-string

    Args:
        obj: Some object.

    Returns:
        (bool): True for non-string ``collections.abc.Sequence`` instances (e.g. list, tuple, range); False for
        strings and for non-sequences such as sets, dicts, numbers, None and numpy arrays.
    """
    # Import the submodule explicitly instead of relying on the module-level 'import collections': importing a
    # package does not by itself guarantee that its 'abc' submodule attribute has been loaded.
    from collections.abc import Sequence

    return isinstance(obj, Sequence) and not isinstance(obj, str)


def count_layers(cell_layers: list[int]) -> int:
    """Returns the number of distinct layers in the cogrid.

    Args:
        cell_layers (list[int]): The layer number of each cell.

    Returns:
        (int): The number of unique values in cell_layers, or 1 when cell_layers is empty (callers divide by the
        layer count, so it is never 0).
    """
    distinct_layers = len(set(cell_layers))
    return distinct_layers if distinct_layers else 1


def compute_points_per_sheet(ugrid, layer_count) -> int:
    """Returns the points per sheet (point sheet = points between cell layers).

    A grid with ``layer_count`` cell layers has ``layer_count + 1`` point sheets, and the point count is split among
    them using integer division.

    Args:
        ugrid: Object with a ``point_count`` attribute (normally a UGrid).
        layer_count (int): Number of cell layers.

    Returns:
        (int): ``ugrid.point_count // (layer_count + 1)``.
    """
    sheet_count = layer_count + 1
    return ugrid.point_count // sheet_count


def compute_point_sheets(cogrid) -> list[int]:
    """Computes and returns the point sheet numbers (point sheet = points between cell layers).

    Assumes the ugrid is stacked and has layers. Consecutive blocks of ``points_per_sheet`` point indexes are assigned
    to sheets 1, 2, 3, ... in order.

    Args:
        cogrid: The constrained grid; its ``ugrid`` and ``cell_layers`` are used. If the ugrid's cell ordering is
            unknown it is calculated first.

    Returns:
        (list[int]): 1-based sheet number for each point. It has ``(layers + 1) * points_per_sheet`` entries, which
        equals the point count when the points divide evenly into sheets.
    """
    ugrid = cogrid.ugrid
    ordering_unknown = ugrid.get_cell_ordering() == UGrid.cell_ordering_enum.CELL_ORDER_UNKNOWN
    if ordering_unknown:
        ugrid.calculate_cell_ordering()

    layer_count = count_layers(cogrid.cell_layers)
    sheet_size = compute_points_per_sheet(ugrid, layer_count)
    sheet_numbers = []
    for sheet in range(1, layer_count + 2):
        sheet_numbers.extend([sheet] * sheet_size)
    return sheet_numbers


def point_in_list_2d(location, locations):
    """Returns True if the location is in locations, comparing just x and y (first two items in each).

    Coordinates are compared for exact equality (no tolerance).

    https://stackoverflow.com/questions/14766194/testing-whether-a-numpy-array-contains-a-given-row

    Args:
        location: A point (sequence/array of float with at least x and y).
        locations: A list or array of points (each with at least x and y). May be empty.

    Returns:
        (bool): See description.
    """
    locations = np.asarray(locations)
    if locations.size == 0:
        return False  # Nothing to match against
    locations = np.atleast_2d(locations)  # Allow a single point as well as a list of points
    xy = np.asarray(location)[:2]
    # Slice the first TWO columns (x and y); the previous [:, :1] slice compared x only.
    return bool(np.any(np.all(locations[:, :2] == xy, axis=1)))


class UgridMapper:
    """Class to handle mapping feature objects to a UGrid.

    Some methods support the grid being 3D or 2D, but not all methods are tested for this.

    Some methods assume 3D UGrid is stacked and has layers.

    Some methods can take a sheet_min, sheet_max. A "sheet" is similar to a layer, but for points - it represents the
    layer of points between cell layers, and probably only makes sense if the grid is stacked. If sheet_min, sheet_max
    is used, the co_grid must have a point_sheets attribute consisting of a list of integers, size of the number of
    points. The point_sheets attribute is normally not part of cogrid - you must add it yourself. See
    compute_point_sheets() in this file.
    """

    # NOTE(review): POINT_IDX and T_VALUE are not used in the code visible here; they look like indexes into
    # (point, t value) pairs - confirm against their users.
    POINT_IDX = 0
    T_VALUE = 1
    ALLMATZONES = -1  # Material value meaning "match all materials" (see _cell_material_matches)
    XY_TOL_FACTOR = 1e8  # From GMS FEMWATER (FWMAP_TOL). If the auto tolerance is not suitable, you can pass your own.

    def __init__(self, co_grid, cell_materials=None, logger=None, xy_tol=None):
        """Initializes the class.

        Args:
            co_grid (UGrid3d): The UGrid (can be 3D or 2D).
            cell_materials (list[int] | None): Material IDs at cells (optional).
            logger (Logger): Python logger.
            xy_tol (float|None): xy tolerance. If None, a tolerance is calculated based on the grid extents.
        """
        self.co_grid = co_grid  # Not private because we access it occasionally
        self._ugrid = co_grid.ugrid  # Because 'self.co_grid.ugrid' should not be used in a loop
        self._cell_materials = cell_materials
        self._logger = logger if logger else logging.getLogger('xms.coverage')

        self._is2d = self.co_grid.check_all_cells_2d()
        self._co_grid_2d = None  # 2D version of self.co_grid (or self.co_grid if it's already 2D)
        self._ugrid_2d = None  # Because 'self._co_grid_2d.ugrid' should not be used in a loop
        self._ugrid_locations = self._ugrid.locations
        self._layer_count = None
        self._points_per_sheet = None
        self._snap_arc_exterior = None
        self._snap_arc_interior = None
        self._xy_tol = _compute_xy_tol(self._ugrid) if xy_tol is None else xy_tol
        self._rtree_2d = None  # A 2d (x,y) rtree used to speed up locating grid points

        # Cached data, so if we want it again we don't have to compute it again
        self._top_faces = []  # (list[tuple[int, int]]) list of tuples of ugrid cell index and face index
        self._top_points = []  # list of all point indices on the top of the UGrid

    # def _make_hashable(self, thing):
    #     """If the thing is a list (or list of lists etc), return the equivalent tuple (or tuple of tuples)."""
    #     if type(thing) == list:
    #         # Use recursion and generator expression to make tuples (https://stackoverflow.com/questions/16940293)
    #         return tuple(self._make_hashable(item) for item in thing)
    #     return thing

    def get_ugrid_points_at_point(
        self,
        point,
        z_min: float | None = None,
        z_max: float | None = None,
        sheet_min: int | None = None,
        sheet_max: int | None = None,
        inclusive: bool = True
    ):
        """Finds the UGrid points at point x,y and between z_min and z_max (or sheet_min, sheet_max), and their weights.

        'point' can be either a xms.data_objects.parameters.Point or an xyz tuple. Only its x and y are used to locate
        grid points (within the xy tolerance). You must pass either z_min, z_max or sheet_min, sheet_max, but not both.

        Args:
            point (xms.data_objects.parameters.Point | tuple[float, float, float]): The point.
            z_min (float | None): Bottom of range (typically a well screen) to search.
            z_max (float | None): Top of range (typically a well screen) to search.
            sheet_min (int | None): Minimum point sheet (interface between cell layers).
            sheet_max (int | None): Maximum point sheet (interface between cell layers).
            inclusive (bool): If True, ugrid points that are right at z_min, z_max, sheet_min, sheet_max are included.

        Returns:
            (tuple[list[int], list[float]]): Found ugrid point indexes and their weights, in no particular order. If no
            points are found, both lists are empty and _log_warning_no_points is called.
        """
        if not point:
            return [], []

        self._check_args_vertical_range(z_min, z_max, sheet_min, sheet_max)

        pt = (point.x, point.y, point.z) if isinstance(point, Point) else point

        # Build the xy rtree on first use. Compare with None explicitly: truth-testing an rtree Index calls its
        # __len__, which counts every entry on each call (and an empty index would be rebuilt every time).
        if self._rtree_2d is None:
            self._rtree_2d = _build_2d_rtree(self._ugrid_locations)

        # Query a box of half-width xy_tol around the point; the vertical filter is applied to each candidate below.
        tol = self._xy_tol
        search_box = (pt[0] - tol, pt[1] - tol, pt[0] + tol, pt[1] + tol)

        point_map = {}  # dict[float, int]: z -> position in ugrid_points (order encountered), used for the weights
        ugrid_points: list[int] = []
        for ugrid_point_idx in self._rtree_2d.intersection(search_box, objects='raw'):
            z = self._ugrid_locations[ugrid_point_idx][2]
            if self._in_vertical_range(z, z_min, z_max, ugrid_point_idx, sheet_min, sheet_max, inclusive):
                # NOTE(review): points sharing the same z overwrite each other's entry in point_map - confirm that
                # _get_point_weights tolerates this.
                point_map[z] = len(ugrid_points)
                ugrid_points.append(ugrid_point_idx)

        if not ugrid_points:
            self._log_warning_no_points(point, pt, z_min, z_max, sheet_min, sheet_max)
            return ugrid_points, []

        weights = self._get_point_weights(point_map, ugrid_points, z_min, z_max)
        return ugrid_points, weights

    def get_outer_faces(self, face_orientation) -> list[tuple[int, int]]:
        """Returns list of faces on the outside of the UGrid as list of (cell index, face index) tuples.

        Args:
            face_orientation (UGrid.face_orientation_enum|None): Optional face orientation (typically top).

        Returns:
            (list[tuple[int, int]]): List of tuples of cell index, face index.
        """
        outer_faces = []
        cell_count = self._ugrid.cell_count
        for cell_idx in range(cell_count):
            face_count = self._ugrid.get_cell_3d_face_count(cell_idx)
            for face_idx in range(face_count):
                if -1 == self._ugrid.get_cell_3d_face_adjacent_cell(cell_idx, face_idx):
                    if face_orientation:
                        orientation = self._ugrid.get_cell_3d_face_orientation(cell_idx, face_idx)
                        if orientation == face_orientation:
                            outer_faces.append((cell_idx, face_idx))
                    else:
                        outer_faces.append((cell_idx, face_idx))
                # else:  It _might_ be faster to store visited faces somehow so that we don't visit them twice
                #        but let's not optimize prematurely.

        return outer_faces

    def get_vertical_faces_intersected_by_arc(
        self,
        arc: Arc | list[tuple[float, float, float]],
        z_min: float | None = None,
        z_max: float | None = None,
        sheet_min: int | None = None,
        sheet_max: int | None = None
    ):
        """Returns list of faces (cell, face index tuples) intersected (snapped) by arc, organized by layer.

        Uses xms.snap (interior snapping) and the 2D UGrid, so the grid is assumed to be "layered". Each consecutive
        pair of snapped 2D points is mapped to a 3D cell face in every layer.

        You can pass either z_min, z_max or sheet_min, sheet_max, or neither, but not both.

        Args:
            arc: The arc.
            z_min: Bottom of range to search.
            z_max: Top of range to search.
            sheet_min: Minimum node sheet (interface between cell layers).
            sheet_max: Maximum node sheet (interface between cell layers).

        Returns:
            (list[list[tuple[int, int]]]): One list per layer of the faces (cell index, face index) in the vertical
            range. Empty if the arc is empty or could not be snapped.
        """
        if not arc:
            return []

        # Validate the range actually requested; sheet_max (not sheet_min) is the upper sheet bound.
        self._check_args_vertical_range(z_min=z_min, z_max=z_max, sheet_min=sheet_min, sheet_max=sheet_max)

        layer_face_lists = []

        # Get snap data (the extents are not needed here)
        snap_output, _, _ = self._get_snap_data(arc, exterior=False, only_points_on_arc=False)
        if snap_output is None or snap_output['location'] is None:
            return layer_face_lists
        pts2d = snap_output['id']  # 2D UGrid point indexes along the arc

        # Make sure we have the layer count
        if self._layer_count is None:
            self._layer_count = count_layers(self.co_grid.cell_layers)
        cells_per_layer = self._ugrid.cell_count // self._layer_count

        # Make sure we have the points per sheet
        if self._points_per_sheet is None:
            self._points_per_sheet = compute_points_per_sheet(self._ugrid, self._layer_count)

        # For each layer, offset the 2D edges into that layer's 3D cells and points
        for layer in range(self._layer_count):
            cell_offset = layer * cells_per_layer
            pt_offset = layer * self._points_per_sheet
            layer_faces = []  # tuples of cell_idx, face_idx

            # Each consecutive pair of snapped points forms a 2D edge
            for i in range(1, len(pts2d)):
                cell_face_tuple = self._cell_3d_face_from_2d_edge((pts2d[i - 1], pts2d[i]), cell_offset, pt_offset)
                if self._face_in_vertical_range(cell_face_tuple, z_min, z_max, sheet_min, sheet_max):
                    layer_faces.append(cell_face_tuple)

            layer_face_lists.append(layer_faces)

        return layer_face_lists

    def _face_in_vertical_range(
        self,
        cell_face_tuple: tuple[int, int],
        z_min: float | None = None,
        z_max: float | None = None,
        sheet_min: int | None = None,
        sheet_max: int | None = None,
    ) -> bool:
        """Returns true if the face is within the vertical range, or no range is specified.

        Args:
            cell_face_tuple: Cell index and face index.
            z_min (float|None): Bottom of range (typically a well screen) to search.
            z_max (float|None): Top of range (typically a well screen) to search.
            sheet_min (int|None): Minimum node sheet (interface between cell layers).
            sheet_max (int|None): Maximum node sheet (interface between cell layers).

        Returns:
            (bool): True if the face is in the vertical range.
        """
        if z_min is not None or sheet_min is not None:
            face_point_idxs = self._ugrid.get_cell_3d_face_points(cell_face_tuple[0], cell_face_tuple[1])
            if z_min is not None:
                # Return true if the face point average z is in the range
                z = mean([self._ugrid_locations[i][2] for i in face_point_idxs])  # average z
                return z_min <= z <= z_max
            else:
                # Return true if all face points are in the sheet range
                for point_idx in face_point_idxs:
                    sheet = self.co_grid.point_sheets[point_idx]
                    if sheet < sheet_min or sheet > sheet_max:
                        return False
                else:
                    return True
        else:
            return True

    def get_ugrid_points_on_arc(
        self,
        arc,
        exterior: bool,
        material: int | None = None,
        z_min: float | None = None,
        z_max: float | None = None,
        sheet_min: int | None = None,
        sheet_max: int | None = None
    ):
        """Returns the UGrid point indexes on the arc and their t values (normalized lengths along the arc).

        Meant for use when arc exactly matches ugrid cell edges. Uses xms.snap and the 2D UGrid. Results are sorted by
        sheet then t value when sheet_min is given, otherwise by t value.

        You can pass either z_min, z_max or sheet_min, sheet_max, or neither, but not both.

        Args:
            arc (xms.data_objects.parameters.Arc | list[tuple[float, float, float]]): The arc.
            exterior (bool): True if arc is on the ugrid border.
            material (int | None): Material ID to restrict mapping to or -1 to ignore materials (match all).
            z_min (float | None): Bottom of range to search.
            z_max (float | None): Top of range to search.
            sheet_min (int | None): Minimum node sheet (interface between cell layers).
            sheet_max (int | None): Maximum node sheet (interface between cell layers).

        Returns:
            (tuple[list[int], list[float]]): point indexes and t values. Both empty if the arc is empty or could not
            be snapped.
        """
        if arc is None or (is_non_string_sequence(arc) and len(arc) == 0):
            return [], []

        # Validate the range actually requested; sheet_max (not sheet_min) is the upper sheet bound.
        self._check_args_vertical_range(z_min=z_min, z_max=z_max, sheet_min=sheet_min, sheet_max=sheet_max)

        # Get snap data
        snap_output, mn, mx = self._get_snap_data(arc, exterior, only_points_on_arc=False)
        if snap_output is None or snap_output['location'] is None or mn is None:
            return [], []

        # Get the points: the snapped 2D points directly, or their 3D counterparts in range
        arc_xyzs = self._get_arc_locations(arc)
        if self._is2d:  # The ugrid passed was 2D
            point_idxs = snap_output['id']
        else:
            point_idxs = self._3d_points_from_2d_points(
                material, z_min, z_max, sheet_min, sheet_max, snap_output['location'], mn, mx
            )

        # Get the t values
        ugrid_locations = self._ugrid_locations  # for short
        t_values = [map_util.distance_along_arc(arc_xyzs, ugrid_locations[i], normalized=True) for i in point_idxs]

        if sheet_min is not None:
            point_idxs, t_values = self._sort_by_sheet(point_idxs, t_values)
        else:
            point_idxs, t_values = self._sort_by_t_values(point_idxs, t_values)
        return point_idxs, t_values

    def get_ugrid_top_points_in_polygon(self, polygon):
        """Returns indexes of top points inside the polygon, excluding points inside any of its holes.

        'polygon' can be either a xms.data_objects.parameters.Polygon or a list[list[tuple[float, float, float]]]
        whose first list is the outer boundary and whose remaining lists are holes. The top points are found and
        cached on first use.

        Args:
            polygon (xms.data_objects.parameters.Polygon | list[list[tuple[float, float, float]]]): The polygon.

        Returns:
            (list[int]): Point indexes, in the iteration order of the cached top points.
        """
        if not polygon:
            return []

        if not self._top_points:
            self.get_top_faces_and_points()

        poly_points = self._get_polygon_point_lists(polygon)
        outer, holes = poly_points[0], poly_points[1:]
        min_xy, max_xy = map_util.extents_2d(outer)

        inside = []
        for point_idx in self._top_points:
            xyz = self._ugrid_locations[point_idx]
            if not map_util.point_in_rectangle(min_xy, max_xy, xyz):
                continue  # Cheap bounding-box rejection
            if not map_util.point_in_polygon_2d(outer, xyz):
                continue
            if any(map_util.point_in_polygon_2d(hole, xyz) for hole in holes):
                continue  # Inside a hole
            inside.append(point_idx)
        return inside

    def get_ugrid_points_in_polygon(
        self,
        polygon,
        strictly_in: bool = False,
        z_min: float | None = None,
        z_max: float | None = None,
        sheet_min: int | None = None,
        sheet_max: int | None = None
    ):
        """Returns indexes of all UGrid points (not just top points) inside the polygon and the vertical range.

        'polygon' can be either a xms.data_objects.parameters.Polygon or a list[list[tuple[float, float, float]]]
        whose first list is the outer boundary and whose remaining lists are holes. Points inside a hole are excluded.

        You can pass either z_min, z_max or sheet_min, sheet_max, or neither, but not both.

        Args:
            polygon (xms.data_objects.parameters.Polygon | list[list[tuple[float, float, float]]]): The polygon.
            strictly_in (bool): True if points on the polygon should NOT be included. Passed to
                map_util.point_in_polygon_2d for both the outer boundary and the holes.
            z_min (float|None): Bottom of range to search.
            z_max (float|None): Top of range to search.
            sheet_min (int|None): Minimum node sheet (interface between cell layers).
            sheet_max (int|None): Maximum node sheet (interface between cell layers).

        Returns:
            (list[int]): Point indexes in increasing order.
        """
        self._check_args_vertical_range(z_min=z_min, z_max=z_max, sheet_min=sheet_min, sheet_max=sheet_max)

        if not polygon:
            return []

        poly_points = self._get_polygon_point_lists(polygon)
        outer, holes = poly_points[0], poly_points[1:]
        min_xy, max_xy = map_util.extents_2d(outer)

        inside = []
        for point_idx in range(self._ugrid.point_count):
            xyz = self._ugrid_locations[point_idx]
            if not map_util.point_in_rectangle(min_xy, max_xy, xyz):
                continue  # Cheap bounding-box rejection
            if not map_util.point_in_polygon_2d(outer, xyz, strictly_in):
                continue
            if any(map_util.point_in_polygon_2d(hole, xyz, strictly_in) for hole in holes):
                continue  # Inside a hole
            in_range = self._in_vertical_range(
                z=xyz[2], z_min=z_min, z_max=z_max, point_idx=point_idx, sheet_min=sheet_min, sheet_max=sheet_max
            )
            if in_range:
                inside.append(point_idx)
        return inside

    def get_ugrid_side_faces_on_arc(
        self,
        arc: Arc | list[tuple[float, float, float]],
        material: int | None = None,
        z_min: float | None = None,
        z_max: float | None = None,
        sheet_min: int | None = None,
        sheet_max: int | None = None
    ):
        """Returns list of outer side faces (cell index, face index) that are on the arc.

        A face is on the arc when its 2D centroid lies on the arc within the xy tolerance. Faces are also filtered by
        material and by the vertical range. You can pass either z_min, z_max or sheet_min, sheet_max, or neither,
        but not both.

        Args:
            arc (xms.data_objects.parameters.Arc | list[tuple[float, float, float]]): The arc.
            material (int | None): Material ID to restrict mapping to or -1 to ignore materials (match all).
            z_min: Bottom of range to search.
            z_max: Top of range to search.
            sheet_min: Minimum node sheet (interface between cell layers).
            sheet_max: Maximum node sheet (interface between cell layers).

        Returns:
            (list[(int, int)]): Unique faces (cell index, face index), sorted by cell then face.
        """
        if not arc:
            return []

        # Validate the range actually requested; sheet_max (not sheet_min) is the upper sheet bound.
        self._check_args_vertical_range(z_min=z_min, z_max=z_max, sheet_min=sheet_min, sheet_max=sheet_max)

        faces: set[tuple[int, int]] = set()  # A set removes duplicates; it is sorted into a list at the end
        arc_xyzs = self._get_arc_locations(arc)
        mn, mx = map_util.extents_2d(arc_xyzs)
        side_faces = self.get_outer_faces(UGrid.face_orientation_enum.ORIENTATION_SIDE)
        for cell_idx, face_idx in side_faces:
            if not self._cell_material_matches(cell_idx, material):
                continue
            face_point_idxs = self._ugrid.get_cell_3d_face_points(cell_idx, face_idx)
            face_xyzs = [self._ugrid_locations[i] for i in face_point_idxs]
            face_centroid_2d = map_util.average_point_2d(face_xyzs)
            # Cheap bounding-box rejection first, then the vertical range, then the exact on-arc test
            if not map_util.point_in_rectangle(mn, mx, face_centroid_2d):
                continue
            if not self._face_in_vertical_range((cell_idx, face_idx), z_min, z_max, sheet_min, sheet_max):
                continue
            if map_util.point_on_polyline_2d(face_centroid_2d, arc_xyzs, self._xy_tol):
                faces.add((cell_idx, face_idx))

        return self._sorted_faces(list(faces))

    def get_ugrid_top_faces_in_polygon(self, polygon):
        """Returns the top faces (cell index, face index) whose cell centroid is inside the polygon and outside its holes.

        'polygon' can be either a xms.data_objects.parameters.Polygon or a list[list[tuple[float, float, float]]]. It
        may include holes where, if it's a list of points, the holes are the lists of points after the first list.
        The top faces are found and cached on first use.

        Args:
            polygon (xms.data_objects.parameters.Polygon | list[list[tuple[float, float, float]]]): The polygon.

        Returns:
            (list[tuple[int, int]]): List of faces (cell index, face index).
        """
        if not polygon:
            return []

        if not self._top_faces:
            self._get_top_faces()

        poly_points = self._get_polygon_point_lists(polygon)
        outer, holes = poly_points[0], poly_points[1:]
        min_xy, max_xy = map_util.extents_2d(outer)

        inside = []
        for face in self._top_faces:
            # The cell centroid should have the same xy as the top face centroid, so it stands in for the face
            _, centroid = self._ugrid.get_cell_centroid(face[0])
            if not map_util.point_in_rectangle(min_xy, max_xy, centroid):
                continue  # Cheap bounding-box rejection
            if not map_util.point_in_polygon_2d(outer, centroid):
                continue
            if any(map_util.point_in_polygon_2d(hole, centroid) for hole in holes):
                continue  # Inside a hole
            inside.append(face)
        return inside

    def get_top_faces_and_points(self):
        """Finds, caches, and returns the UGrid top faces and the points of those faces.

        For a 2D grid every grid point is treated as a top point.

        Returns:
            (tuple): (top faces as stored by _get_top_faces, top point indexes). The point indexes are a list for a 2D
            grid and a set for a 3D grid.
        """
        self._get_top_faces()
        if not self._is2d:
            self._top_points = self._get_unique_face_points(self._top_faces)
        else:
            self._top_points = list(range(self._ugrid.point_count))
        return self._top_faces, self._top_points

    def _ensure_snapping_is_initialized(self, exterior: bool) -> None:
        """Creates the exterior or interior arc snapper on first use.

        The 2D grid used for snapping is also created on first use and shared by both snappers. It is built from the
        top of the 3D grid, or is the grid itself if that is already 2D.

        Args:
            exterior (bool): True if arc is on the ugrid border.
        """
        existing = self._snap_arc_exterior if exterior else self._snap_arc_interior
        if existing:
            return  # Already initialized

        if self._co_grid_2d is None:
            self._co_grid_2d = self.co_grid if self._is2d else self._create_2d_ugrid_from_3d_ugrid(self.co_grid)
            self._ugrid_2d = self._co_grid_2d.ugrid

        if exterior:
            self._snap_arc_exterior = SnapExteriorArc()
            snapper = self._snap_arc_exterior
        else:
            self._snap_arc_interior = SnapInteriorArc()
            snapper = self._snap_arc_interior
        snapper.set_grid(grid=self._co_grid_2d, target_cells=False)

    @staticmethod
    def _create_2d_ugrid_from_3d_ugrid(co_grid):
        """Creates a 2D constrained grid from the top faces of the 3D constrained grid.

        Args:
            co_grid: The 3D constrained grid.

        Returns:
            The 2D cogrid produced by UGrid2dFromUGrid3dCreator with its 'Top' option.
        """
        return UGrid2dFromUGrid3dCreator().create_2d_cogrid(co_grid, 'Top')

    def _cell_material_matches(self, cell_idx, material):
        """Returns True if the material of the cell matches.

        Args:
            cell_idx (int): The cell index.
            material (int): Material ID to restrict mapping to or -1 to ignore materials (match all).

        Returns:
            (bool): See description.
        """
        return (
            self._cell_materials is None or material is None or material == UgridMapper.ALLMATZONES or  # noqa: W504
            round(self._cell_materials[cell_idx]) == material
        )  # noqa: W503 (line break)

    def _matching_material_in_adjacent_cells(self, point_idx: int, material: int | None):
        """Returns True if any of the cells touching the point have the material specified on the arc.

        Args:
            point_idx (int): Index of UGrid point.
            material (int|None): Material ID to restrict mapping to or None to ignore materials (match all).

        Returns:
            (bool): See description.
        """
        if self._cell_materials is None or material is None or material == UgridMapper.ALLMATZONES:
            return True

        rv = False
        adjacent_cells = self._ugrid.get_point_adjacent_cells(point_idx)
        for cell_idx in adjacent_cells:
            if round(self._cell_materials[cell_idx]) == material:
                rv = True
                break
        return rv

    def _snap_output_has_errors(self, snap_output, arc):
        """Checks for errors in the snap output.

        Args:
            snap_output (dict): output from snapper
            arc (xms.data_objects.parameters.Arc | list[tuple[float, float, float]]): The arc.

        Returns:
            (bool): True if there are snap errors
        """
        if 'location' not in snap_output or len(snap_output['location']) < 1:
            arc_id = arc.id if isinstance(arc, Arc) else -1
            self._logger.warning(f'Unable to snap arc id: {arc_id} to ugrid.')
            return True
        return False

    def _sort_by_sheet(self, point_idxs, t_values) -> tuple[list[int], list[float]]:
        """Sorts the points by sheet, then by t_value.

        Args:
            point_idxs (list[int]): List of 3D UGrid point indexes.
            t_values (list[float]): List of t values.

        Returns:
            (tuple[list[int], list[float]]): Point indexes, and t_values.
        """
        # Sort by sheet, then t_value
        sheet_numbers = [self.co_grid.point_sheets[point_idx] for point_idx in point_idxs]
        sorted_tuples = sorted(zip(point_idxs, sheet_numbers, t_values), key=operator.itemgetter(1, 2))
        sorted_points = [sorted_tuple[0] for sorted_tuple in sorted_tuples]
        sorted_t_values = [sorted_tuple[2] for sorted_tuple in sorted_tuples]
        return sorted_points, sorted_t_values

    @staticmethod
    def _sort_by_t_values(point_idxs, t_values) -> tuple[list[int], list[float]]:
        """Sorts the points by t_value.

        Args:
            point_idxs (list[int]): List of 3D UGrid point indexes.
            t_values (list[float]): List of t values.

        Returns:
            (tuple[list[int], list[float]]): Point indexes and t_values.
        """
        sorted_tuples = sorted(zip(point_idxs, t_values), key=lambda pair: pair[1])
        sorted_points = [point_index for point_index, _ in sorted_tuples]
        sorted_t_values = [t_value for _, t_value in sorted_tuples]
        return sorted_points, sorted_t_values

    @staticmethod
    def _sorted_faces(faces):
        """Returns a sorted list of faces.

        Args:
            faces (list[(int, int)]): List of faces (cell index, face index).
        """
        return sorted(faces, key=lambda pair: (pair[0], pair[1]))

    @staticmethod
    def _get_arc_locations(arc):
        """Returns a list of all the arc locations, including nodes and vertices.

        Args:
            arc (xms.data_objects.parameters.Arc | list[tuple[float, float, float]]): The arc.

        Returns:
            list of xyz tuples (a non-Arc argument is returned unchanged)
        """
        if not isinstance(arc, Arc):
            return arc  # Already a list of xyz locations
        all_points = arc.get_points(FilterLocation.PT_LOC_ALL)
        return [(pt.x, pt.y, pt.z) for pt in all_points]

    @staticmethod
    def _get_polygon_point_lists(polygon):
        """Returns the polygon's point lists (outer boundary first, then holes) for either input form.

        Args:
            polygon (Polygon|list[list[tuple(float, float)]]): The polygon.

        Returns:
            (list[list[tuple(float, float)]]): List of polygons. A non-Polygon argument is returned unchanged.
        """
        if not isinstance(polygon, Polygon):
            return polygon  # Already point lists
        return polygon_orienteer.get_polygon_point_lists(polygon)

    def _get_unique_face_points(self, faces) -> set[int]:
        """Returns list of unique UGrid point indexes on the given faces.

        Args:
            faces (list[tuple[int, int]]): list of (cell index, face index) tuples.

        Returns:
            (set[int]): Set of unique points that are on the faces.
        """
        unique_points = set()  # Use a set to eliminate duplicates
        if self._is2d:
            for cell_idx, _ in faces:
                face_points = self._ugrid.get_cell_points(cell_idx)
                unique_points.update(face_points)
        else:
            for cell_idx, face_idx in faces:
                face_points = self._ugrid.get_cell_3d_face_points(cell_idx, face_idx)
                unique_points.update(face_points)
        return unique_points

    @staticmethod
    def _compute_point_weights(z_min, z_max, point_map):
        r"""Compute normalized (0.0 to 1.0) values for each UGrid point based on length of well screen.

        ::
             0  -1  -2  -3  -4  -5  -6  -7  -8  -9  -10        Well screen (z_max and z_min): 0.0 to -10.0
             |---|---|---|---|---|---|---|---|---|---|         Grid points (p):  -1.0, -3.0, -7.0
                                                               Bisectors (b): -2.0, -5.0
             |   *   |   *       |        *          |         Weights: 0.2, 0.3, 0.5
            z_max  p   b   p       b        p        z_min
        """
        weights = [0.0] * len(point_map)
        screen_length = z_max - z_min
        current_bottom = z_min
        keys = sorted(point_map.keys())
        for i, key in enumerate(keys):
            next_key = keys[i + 1] if i + 1 < len(keys) else None
            if next_key:
                current_top = key + next_key
                current_top /= 2
            else:
                current_top = z_max
            length = current_top - current_bottom
            weights[point_map[key]] = length / screen_length
            current_bottom = current_top
        return weights

    def _get_snap_data(self, arc, exterior: bool, only_points_on_arc: bool):
        """Uses xms.snap to return tuple of: 2D UGrid points on arc, extents min, extents max.

        Args:
            arc (xms.data_objects.parameters.Arc | list[tuple[float, float, float]]): The arc.
            exterior (bool): True if arc is on the ugrid border.
            only_points_on_arc (bool): If True, only points that lie on the arc are included in the results.

        Returns:
            tuple(list[(float, float)], (float, float, float), (float, float, float)) | None: See description.
        """
        self._ensure_snapping_is_initialized(exterior)
        if exterior:
            snap_output = self._snap_arc_exterior.get_snapped_points(arc)
        else:
            snap_output = self._snap_arc_interior.get_snapped_points(arc)
        if self._snap_output_has_errors(snap_output, arc):
            return None, None, None

        # Post-process snap points to make sure that all are on the arc (xms.snap can include points beyond the arc)
        if only_points_on_arc:
            arc_xyzs = self._get_arc_locations(arc)
            new_pts = [p for p in snap_output['location'] if map_util.point_on_polyline_2d(p, arc_xyzs, self._xy_tol)]
            mn, mx = map_util.extents_2d(new_pts)
        else:
            mn, mx = map_util.extents_2d(snap_output['location'])
        return snap_output, mn, mx

    def _is_edge_on_face(self, cell_idx: int, edge: tuple[int, int], face_idx: int) -> bool:
        """Returns true if the edge is on the face.

        Args:
            cell_idx (int): The cell index.
            edge (tuple[int, int]): Edge defined by two point indices
            face_idx (int): The face index.

        Returns:
            (bool): See description.
        """
        face_points = self._ugrid.get_cell_3d_face_points(cell_idx, face_idx)
        found = [False, False]
        for point_idx in face_points:
            if not found[0] and point_idx == edge[0]:
                found[0] = True
            elif not found[1] and point_idx == edge[1]:
                found[1] = True
            if found[0] and found[1]:
                break
        return found[0] and found[1]

    def _cell_3d_face_from_2d_edge(self, edge: tuple[int, int], cell_offset: int, pt_offset: int) -> tuple[int, int]:
        """Returns a face (cell, face index tuple) on the top layer of the 3D grid.

        Args:
            edge (tuple[int, int]): Edge on 2D grid defined by two point indices.
            cell_offset: Offset to add to cell index to convert 2D cell index into 3D cell index.
            pt_offset: Offset to add to point index to convert 2D point index into 3D point index.

        Returns:
            (tuple[int, int]): See description.
        """
        adjacent_cells = self._ugrid_2d.get_edge_adjacent_cells(edge)
        cell_idx_2d = adjacent_cells[0]
        if len(adjacent_cells) > 1:
            # From first cell, find index of face between it and the adjacent cell
            cell_idx_3d = cell_idx_2d + cell_offset
            for face_idx in range(self._ugrid.get_cell_3d_face_count(cell_idx_3d)):
                adjacent_cell = self._ugrid.get_cell_3d_face_adjacent_cell(cell_idx_3d, face_idx)
                if adjacent_cell == adjacent_cells[1] + cell_offset:
                    return cell_idx_3d, face_idx
        else:
            # No adjacent cell. Find face with the edge by finding which face includes both edge points
            cell_idx_3d = cell_idx_2d + cell_offset
            edge_3d = (edge[0] + pt_offset, edge[1] + pt_offset)
            for face_idx in range(self._ugrid.get_cell_3d_face_count(cell_idx_3d)):
                if (
                    self._ugrid.get_cell_3d_face_orientation(cell_idx_3d, face_idx) ==  # noqa: W504 (line break)
                    UGrid.face_orientation_enum.ORIENTATION_SIDE
                ):
                    if self._is_edge_on_face(cell_idx_3d, edge_3d, face_idx):
                        return cell_idx_3d, face_idx

    def _cell_below(self, cell_idx) -> int:
        """Returns the index of the cell across the bottom face of cell_idx, or -1 if the cell has no bottom face.

        This never returns None; -1 is the "no cell" value.

        Args:
            cell_idx (int): The 3D cell index.

        Returns:
            (int): What get_cell_3d_face_adjacent_cell() reports for the first bottom-oriented face of the cell
            (presumably also -1 when no cell lies below — confirm against xms.grid), or -1 if the cell has no
            bottom-oriented face.
        """
        bottom = UGrid.face_orientation_enum.ORIENTATION_BOTTOM
        for face_idx in range(self._ugrid.get_cell_3d_face_count(cell_idx)):
            if self._ugrid.get_cell_3d_face_orientation(cell_idx, face_idx) == bottom:
                return self._ugrid.get_cell_3d_face_adjacent_cell(cell_idx, face_idx)
        return -1

    def _get_point_weights(self, point_map, ugrid_points, z_min: float, z_max: float) -> list[float]:
        """Computes and returns the weights.

        Args:
            point_map (dict[float, int]):  z, order encountered (not vertical order) used when finding weights:
            ugrid_points (list[int]): List of UGrid points at the location.
            z_min (float): Bottom of range (typically a well screen) to search.
            z_max (float): Top of range (typically a well screen) to search.

        Returns:
            (list[float]): List of weights, len(ugrid_points).
        """
        if z_max is not None:
            weights = self._compute_point_weights(z_min, z_max, point_map)
        else:  # Using sheets
            weights = [1 / len(ugrid_points)] * len(ugrid_points)
        return weights

    def _in_vertical_range(
        self,
        z: float | None = None,
        z_min: float | None = None,
        z_max: float | None = None,
        point_idx: int | None = None,
        sheet_min: int | None = None,
        sheet_max: int | None = None,
        inclusive: bool = True
    ) -> bool:
        """Returns True if the point is withing the vertical range defined by either z or sheet range.

        Args:
            z (float|None): Z of the location.
            z_min (float|None): Bottom of range (typically a well screen) to search.
            z_max (float|None): Top of range (typically a well screen) to search.
            point_idx (int|None): Index of UGrid point.
            sheet_min (int|None): Minimum node sheet (interface between cell layers).
            sheet_max (int|None): Maximum node sheet (interface between cell layers).
            inclusive (bool): If True, ugrid points that are right at z_min, z_max, sheet_min, sheet_max are included.

        Returns:
            (bool): See description.
        """
        if z_min is not None:
            return z_min <= z <= z_max if inclusive else z_min < z < z_max
        elif sheet_min is not None:
            sheet = self.co_grid.point_sheets[point_idx]
            return sheet_min <= sheet <= sheet_max if inclusive else sheet_min < sheet < sheet_max
        return True

    def _check_args_vertical_range(
        self, z_min: float | None, z_max: float | None, sheet_min: int | None, sheet_max: int | None
    ):
        """Raises an exception if the arguments aren't valid.

        Args:
            z_min (float|None): Bottom of range (typically a well screen) to search.
            z_max (float|None): Top of range (typically a well screen) to search.
            sheet_min (int|None): Minimum node sheet (interface between cell layers).
            sheet_max (int|None): Maximum node sheet (interface between cell layers).
        """
        z_arg_count = int(z_min is not None) + int(z_max is not None)
        sheet_arg_count = int(sheet_min is not None) + int(sheet_max is not None)
        if not ((z_arg_count in {0, 2} and sheet_arg_count in {0, 2}) and (z_arg_count + sheet_arg_count != 4)):
            raise RuntimeError('You must pass z_max and z_min, or sheet_min and sheet_max, but not both.')
        if sheet_arg_count > 0 and (self.co_grid and not hasattr(self.co_grid, 'point_sheets')):
            raise RuntimeError(
                'Use of sheet_min/sheet_max requires the 3D grid to have a point_sheets attribute'
                ' consisting of a list of integers size of number of points.'
            )

    def _log_warning_no_points(
        self, point, point_location, z_min: float, z_max: float, sheet_min: float, sheet_max: float
    ):
        """Logs a warning that no UGrid points were found below the point in the requested vertical range.

        The point is identified by its id when it is a Point, otherwise by its location. The message names the z
        range when z_min is given, otherwise the sheet range.

        Args:
            point (xms.data_objects.parameters.Point | tuple[float, float]): The point.
            point_location (tuple[float]): x,y,z tuple.
            z_min (float): Bottom of range (typically a well screen) to search.
            z_max (float): Top of range (typically a well screen) to search.
            sheet_min (int): Minimum point sheet (interface between cell layers).
            sheet_max (int): Maximum point sheet (interface between cell layers).
        """
        label = point.id if isinstance(point, Point) else str(point_location)
        if z_min is not None:
            message = f'No UGrid points lie below point {label} between {z_max} and {z_min}'
        else:
            message = f'No UGrid points lie below point {label} between sheet {sheet_min} and sheet {sheet_max}'
        self._logger.warning(message)

    def _3d_points_from_2d_points(
        self, material: int | None, z_min: float | None, z_max: float | None, sheet_min: int | None,
        sheet_max: int | None, snap_locations, mn, mx
    ):
        """Returns the UGrid point indices that coincide with the snapped locations and pass the filters.

        Only indices in _get_locations_range(sheet_min, sheet_max) are examined. A point is kept when, checked in
        this order: it is inside the 2D rectangle mn-mx, point_in_list_2d() finds it in snap_locations, an adjacent
        cell matches the material, and it is within the z or point sheet range.

        Args:
            material (int | None): Material ID to restrict mapping to or -1 to ignore materials (match all).
            z_min (float | None): Bottom of range to search.
            z_max (float | None): Top of range to search.
            sheet_min (int | None): Minimum node sheet (interface between cell layers).
            sheet_max (int | None): Maximum node sheet (interface between cell layers).
            snap_locations (list[tuple[float, float, float]]): Snapped locations the points are matched against.
            mn (tuple[float, float, float]): Min extent of snap locations.
            mx (tuple[float, float, float]): Max extent of snap locations.

        Returns:
            (list[int]): Grid point indices, in increasing order.
        """
        def _keep(point_idx: int) -> bool:
            xyz = self._ugrid_locations[point_idx]
            return (
                map_util.point_in_rectangle(mn, mx, xyz) and  # noqa: W504
                point_in_list_2d(xyz, snap_locations) and  # noqa: W504
                self._matching_material_in_adjacent_cells(point_idx, material) and  # noqa: W504
                self._in_vertical_range(
                    z=xyz[2], z_min=z_min, z_max=z_max, point_idx=point_idx, sheet_min=sheet_min, sheet_max=sheet_max
                )
            )

        first, last = self._get_locations_range(sheet_min, sheet_max)
        return [point_idx for point_idx in range(first, last) if _keep(point_idx)]

    def _get_top_faces(self):
        """Finds and stores the UGrid top faces and the points of those faces."""
        if self._is2d:
            self._top_faces = [(cell, 0) for cell in range(self._ugrid.cell_count)]
        else:
            self._top_faces = self.get_outer_faces(UGrid.face_orientation_enum.ORIENTATION_TOP)
        return self._top_faces

    def _get_locations_range(self, sheet_min, sheet_max) -> tuple[int, int]:
        """Returns starting and stopping indexes of the UGrid locations given point sheet range.

        Args:
            sheet_min (int|None): Minimum node sheet (interface between cell layers).
            sheet_max (int|None): Maximum node sheet (interface between cell layers).

        Returns:
            (tuple[int, int]): The
        """
        if sheet_min is not None:
            if self._layer_count is None:
                self._layer_count = count_layers(self.co_grid.cell_layers)
            if self._points_per_sheet is None:
                self._points_per_sheet = compute_points_per_sheet(self._ugrid, self._layer_count)
            start = (sheet_min - 1) * self._points_per_sheet
            stop = sheet_max * self._points_per_sheet
            return start, stop
        return 0, len(self._ugrid_locations)


def _compute_xy_tol(ugrid):
    """Returns an xy tolerance scaled to the grid's extents (patterned after gmComputeXyTol()).

    The tolerance is the 2D diagonal length of the grid extents divided by UgridMapper.XY_TOL_FACTOR. It is never
    smaller than 1 / UgridMapper.XY_TOL_FACTOR.
    """
    lower, upper = ugrid.extents
    diagonal = geometry.distance_2d(upper, lower)
    floor_tol = 1.0 / UgridMapper.XY_TOL_FACTOR
    return max(diagonal * floor_tol, floor_tol)


def _rtree_2d_insert_generator(grid_locations: Locations):
    """Yields (id, bounds, obj) entries for stream-loading the grid locations into a 2D rtree.

    Each location becomes a degenerate box (x, y, x, y). Its list index is used as both the id and the stored
    object. Any z value is ignored. Stream loading is the faster way to populate an rtree, per the rtree docs:

    https://rtree.readthedocs.io/en/latest/performance.html#use-stream-loading
    """
    for item_id, location in enumerate(grid_locations):
        x, y = location[0], location[1]
        yield item_id, (x, y, x, y), item_id


def _build_2d_rtree(grid_locations: Locations):
    """Returns a 2D rtree index of the grid locations, bulk-loaded from a generator (stream loading) for speed.

    https://rtree.readthedocs.io/en/latest/performance.html#use-stream-loading
    """
    props = index.Property()
    props.dimension = 2
    entries = _rtree_2d_insert_generator(grid_locations)
    return index.Index(entries, properties=props)
