"""General purpose geometry functions."""

__copyright__ = '(C) Copyright Aquaveo 2024'
__license__ = 'All rights reserved'

# 1. Standard Python modules

# 2. Third party modules
import numba
import numpy as np

# 3. Aquaveo modules

# 4. Local modules


@numba.jit(nopython=True)
def _is_inside_sm(poly, pt) -> int:  # pragma no cover - numba
    """Classifies a point against a closed polygon with a ray-casting (crossing-number) test.

    A horizontal ray is cast from the point towards -x and the polygon edges it crosses are
    counted; an odd count means the point is inside. Only the first two coordinates (x, y) of
    the point and of each polygon vertex are read.

    Polygon can be either clockwise or counterclockwise. Last point must equal first point:
    only the edges poly[i] -> poly[i + 1] are visited, so an unclosed ring is missing its
    closing edge.

    See this for more info:
    https://stackoverflow.com/questions/36399381/

    Args:
        poly: numpy array of points defining the polygon (indexed as poly[i][0], poly[i][1])
        pt: numpy array of x, y coords of the point

    Returns:
        (int): 0 if the point is outside the polygon, 1 if it is inside, 2 if it lies on the
        polygon boundary (on an edge or a vertex). Boundary detection uses exact floating-point
        equality, so a point within rounding error of an edge may be reported as 0 or 1 instead.
    """
    length = len(poly) - 1  # number of edges (consecutive vertex pairs) in the closed ring
    dy2 = pt[1] - poly[0][1]  # vertical offset of pt from the first vertex; becomes `dy` in the loop
    intersections = 0
    ii = 0  # index of the current edge's start vertex
    jj = 1  # index of the current edge's end vertex

    # short variable names for long if statement below
    p = pt
    pp = poly
    while ii < length:
        # Offsets of pt from the edge endpoints: dy for poly[ii] (reused from the previous
        # iteration), dy2 for poly[jj]. A negative offset means the vertex has a larger y than pt.
        dy = dy2
        dy2 = pt[1] - poly[jj][1]

        # consider only lines which are not completely above/below/right from the pt:
        # dy * dy2 <= 0 means the edge spans (or touches) pt's y, and at least one endpoint
        # must have x <= pt's x for the leftward ray to reach it.
        if dy * dy2 <= 0.0 and (pt[0] >= poly[ii][0] or pt[0] >= poly[jj][0]):
            # At least one endpoint is strictly above pt. Because dy * dy2 <= 0, the other is at
            # or below pt's y, so dy - dy2 != 0. Endpoints lying exactly on pt's y are thereby
            # treated as "below" (half-open rule), so a vertex on the ray is not double-counted.
            if dy < 0 or dy2 < 0:  # non-horizontal line
                # x-coordinate where the edge crosses the horizontal line y = pt[1]
                val = dy * (poly[jj][0] - poly[ii][0]) / (dy - dy2) + poly[ii][0]

                if pt[0] > val:  # if line is left from the pt - the ray moving towards left, will intersect it
                    intersections += 1
                elif pt[0] == val:  # pt on the line
                    return 2

            # Remaining edges only touch pt's y from below or are horizontal at pt's y.
            # pt on upper peak (dy2=dx2=0) or horizontal line (dy=dy2=0 and dx*dx2<=0)
            elif dy2 == 0 and (p[0] == pp[jj][0] or (dy == 0 and (p[0] - pp[ii][0]) * (p[0] - pp[jj][0]) <= 0)):
                return 2

        ii = jj
        jj += 1

    # Odd number of crossings -> inside (1), even -> outside (0).
    return intersections & 1


@numba.njit(parallel=True)
def _is_inside_sm_parallel(pts, poly, mask) -> np.ndarray:  # pragma no cover - numba
    """Classifies many points against one polygon, distributing the points over a parallel loop.

    Polygon can be either clockwise or counterclockwise. Last point must equal first point.

    Args:
        pts: numpy array of points to test; only the first two coordinates of each point are used.
        poly: numpy array of points defining the polygon.
        mask: array with one entry per point. Where mask[i] is falsy, no calculation is done for
            that point and it is always reported as outside (0). Useful for handling model on/off
            cells (inactive regions).

    Returns:
        (numpy.ndarray): int64 array with one entry per point, as returned by `_is_inside_sm`:
        0 where the point is outside (or masked off), 1 where it is inside, 2 where it lies on
        the polygon boundary.
    """
    n_points = len(pts)
    out = np.empty(n_points, dtype=numba.int64)
    for i in numba.prange(n_points):
        if mask[i]:
            # Casting the prange index with types.int64() gets rid of
            # "NumbaTypeSafetyWarning: unsafe cast from uint64 to int64. Precision may be lost."
            out[i] = _is_inside_sm(poly, pts[numba.types.int64(i)])
        else:
            out[i] = 0
    return out


def run_parallel_points_in_polygon(points: np.ndarray, polygon: np.ndarray, mask: np.ndarray | None = None) -> np.array:
    """Determines if a point is inside a polygon by calling a parallelized function.

    Polygon can be either clockwise or counterclockwise. Last point must equal first point.

    Args:
        points: points to test
        polygon: points defining the polygon
        mask: Optional, size of points. If mask is 'off', no calculations for point are done, and it is
         always not in the polygon. Useful for handling model on/off cells (inactive regions).

    Returns:
        numpy.array with 1 where point is inside, 0 where point is outside the polygon.
    """
    if not isinstance(points, np.ndarray) or not isinstance(polygon, np.ndarray):
        raise ValueError('run_parallel_points_in_polygon() requires np.ndarray type arguments.')

    if mask is not None and not isinstance(mask, np.ndarray):
        raise ValueError('run_parallel_points_in_polygon() requires np.ndarray type arguments.')

    pts_shape = np.shape(points)
    if len(pts_shape) < 2 or pts_shape[1] < 2:
        raise ValueError('run_parallel_points_in_polygon(): points must be at least 2D.')

    poly_shape = np.shape(polygon)
    if len(poly_shape) < 2 or poly_shape[1] < 2:
        raise ValueError('run_parallel_points_in_polygon(): polygon points must be at least 2D.')

    if poly_shape[0] < 3:
        raise ValueError('run_parallel_points_in_polygon(): polygons must have at least 3 points.')

    if mask is not None and len(mask) != len(points):
        raise ValueError('run_parallel_points_in_polygon(): mask must be the same size as points.')

    if mask is None:
        mask = np.ones(len(points))  # Create a mask with all 'on' if no mask was provided

    return _is_inside_sm_parallel(points, polygon, mask)
