"""map_util.py."""

__copyright__ = '(C) Copyright Aquaveo 2024'
__license__ = 'All rights reserved'

# 1. Standard Python modules
from bisect import bisect

# 2. Third party modules
from geopandas import GeoDataFrame
import numpy as np  # noqa: F401 - Used in type hinting

# 3. Aquaveo modules
from xms.constraint import RectilinearGrid2d
from xms.gmi.data.generic_model import Group
from xms.grid.geometry import MultiPolyIntersector
from xms.grid.ugrid import UGrid

# 4. Local modules
from xms.gssha.data.bc_util import BcData
from xms.gssha.misc.type_aliases import IntArray

# Type aliases
Pt3d = tuple[float, float, float]  # (x, y, z) point
CellIdx = int  # Cell index, as used to index the on/off cells array
# (GMI group, index of the segment's end point within the arc, t at segment entry into the cell,
#  t at segment exit from the cell, feature key tuple (feature id, geometry type))
ArcInfo = tuple[Group, int, float, float, tuple]
ArcIx = dict[CellIdx, list[ArcInfo]]  # Arc intersection info, by cells
PointIx = dict[CellIdx, list[Group]]  # Point intersection info, by cells
Extents = tuple[list[float], list[float]]  # (min [x, y, z], max [x, y, z])


def intersect_arcs_with_grid(ugrid: UGrid, on_off_cells: IntArray, stream_data: BcData) -> ArcIx:
    """Find which active grid cells each stream arc passes through.

    Every segment of every 'Arc' feature in ``stream_data.feature_bc`` is traversed through the cell polygons of
    the UGrid. Each crossed cell that is switched on gets one ArcInfo entry.

    Args:
        ugrid: The UGrid whose cells are intersected.
        on_off_cells: Model on/off cells; cells whose entry is falsy are skipped.
        stream_data: Stream data holding the arc features and their GMI groups.

    Returns:
          ArcIx: cell index -> list of (group, segment end point index, t_val1, t_val2, feature key).
    """
    result: ArcIx = {}
    points_of_cells, polys_of_cells = get_cell_polygons(ugrid)
    intersector = MultiPolyIntersector(points_of_cells, polys_of_cells, starting_id=0)
    for feature_key, bc_group in stream_data.feature_bc.items():
        if feature_key[1] != 'Arc':
            continue
        arc = stream_data.feature_from_id(feature_key[0], feature_key[1])
        coords = list(arc.geometry.coords)
        for end_idx in range(1, len(coords)):
            start, end = coords[end_idx - 1], coords[end_idx]
            seg_start = (start[0], start[1], start[2])
            seg_end = (end[0], end[1], end[2])
            crossed_cells, t_params, _ = intersector.traverse_line_segment(seg_start, seg_end)
            # The final entry of crossed_cells is always -1, so it is never recorded.
            for pos in range(len(crossed_cells) - 1):
                cell = crossed_cells[pos]
                if cell < 0 or not on_off_cells[cell]:
                    continue
                result.setdefault(cell, []).append((bc_group, end_idx, t_params[pos], t_params[pos + 1], feature_key))
    return result


def intersect_points_with_grid(co_grid: RectilinearGrid2d, on_off_cells: IntArray, stream_data: BcData) -> PointIx:
    """Intersects the point features with the grid and returns the intersection info.

    Because it's a non-rotated, rectilinear grid, we find the grid cell using locations_x and locations_y (grid
    line coordinates, offset by the grid origin) and bisect, rather than a geometric intersector.

    The (i, j) passed to ``co_grid.get_cell_index_from_ij`` are 1-based: a point lying between grid lines k - 1
    and k gets index k. Rows are counted from the top because of Orientation.y_decrease.

    Args:
        co_grid: The grid.
        on_off_cells: Model on/off cells, indexed by cell index.
        stream_data: stream data.

    Returns:
          PointIx, point intersection info: cell index -> GMI groups of the point features in that cell. Points
          outside the grid, or in cells that are off, are skipped.
    """
    point_ix: PointIx = {}
    # NOTE(review): bisect requires these to be ascending - assumed from the original implementation.
    x_locs = [co_grid.origin[0] + x for x in co_grid.locations_x]
    y_locs = [co_grid.origin[1] + y for y in co_grid.locations_y]
    row_count = len(y_locs) - 1  # n grid lines bound n - 1 cells
    col_count = len(x_locs) - 1
    point_types = {'Point', 'Node', 'Vertex'}
    for feature, group in stream_data.feature_bc.items():
        if feature[1] not in point_types:
            continue
        feature_point = stream_data.feature_from_id(feature[0], feature[1])
        # bisect gives k with y_locs[k - 1] <= y < y_locs[k]; flip it so row 1 is the top row (Orientation.y_decrease)
        i = len(y_locs) - bisect(y_locs, feature_point.geometry.y)
        j = bisect(x_locs, feature_point.geometry.x)
        # i == 0 / j == 0 mean the point is at/above the top line or left of the grid; the upper bounds reject
        # points below or right of it. The old `0 <=` checks were always true and let i == 0 / j == 0 through.
        if not (1 <= i <= row_count and 1 <= j <= col_count):
            continue
        cell_idx = co_grid.get_cell_index_from_ij(i, j)
        if cell_idx >= 0 and on_off_cells[cell_idx]:
            point_ix.setdefault(cell_idx, []).append(group)
    return point_ix


def get_cell_polygons(ugrid: UGrid) -> tuple[list[Pt3d], list[list[int]]]:
    """Describe every cell of the UGrid as a polygon.

    Args:
        ugrid: The UGrid.

    Returns:
        A pair: ``ugrid.locations`` (returned as-is) and, for each cell in cell-index order, the point indices
        returned by ``ugrid.get_cell_points`` for that cell.
    """
    locations = ugrid.locations
    polygons = [ugrid.get_cell_points(idx) for idx in range(ugrid.cell_count)]
    return locations, polygons


def coverage_extents(coverage: GeoDataFrame) -> Extents:
    """Compute the bounding box of a coverage's point-like features.

    Only rows whose 'geometry_types' value is 'Point', 'Node' or 'Vertex' contribute to the extents.

    Args:
        coverage: The coverage, or None.

    Returns:
        (min, max), each an [x, y, z] list. If coverage is None or has no point-like rows, min is all
        float('inf') and max is all float('-inf').
    """
    lows = [float('inf')] * 3
    highs = [float('-inf')] * 3
    if coverage is not None:
        selected = coverage[coverage['geometry_types'].isin(['Point', 'Node', 'Vertex'])]
        for row in selected.itertuples():
            geom = row.geometry
            xyz = (geom.x, geom.y, geom.z)
            for axis in range(3):
                # min/max keep the current bound unless the new value is strictly smaller/larger.
                lows[axis] = min(lows[axis], xyz[axis])
                highs[axis] = max(highs[axis], xyz[axis])
    return lows, highs
