"""Builds a GSSHA co_grid."""

__copyright__ = '(C) Copyright Aquaveo 2024'
__license__ = 'All rights reserved'

# 1. Standard Python modules
from logging import Logger
import math

# 2. Third party modules
from geopandas import GeoDataFrame
import numpy as np
import numpy.typing as npt

# 3. Aquaveo modules
from xms.api.dmi import Query
from xms.constraint import Grid, Orientation, RectilinearGridBuilder
from xms.gdal.rasters import RasterInput
from xms.gmi.data.generic_model import Group
from xms.grid.geometry import geometry
from xms.grid.ugrid import UGrid
from xms.tool.algorithms.geometry import geometry as tool_geom
from xms.tool.utilities.coverage_conversion import get_polygon_point_lists

# 4. Local modules
from xms.gssha.components import dmi_util
from xms.gssha.components.bc_coverage_component import BcCoverageComponent
from xms.gssha.data import bc_util
from xms.gssha.data.bc_generic_model import ChannelType
from xms.gssha.data.bc_util import BcData
from xms.gssha.file_io import io_util
from xms.gssha.mapping import map_util
from xms.gssha.mapping.map_util import ArcIx
from xms.gssha.misc.type_aliases import IntArray, Pt3dArray
from xms.gssha.tools import tool_util


def build_co_grid(
    query: Query, bc_cov: GeoDataFrame, cell_size: float, raster: RasterInput | None, vertical_units: str,
    use_streams: bool, logger: Logger
) -> Grid:
    """Creates a square-celled 2D rectilinear grid spanning the BC coverage and returns it.

    The number of cells in x and y is the coverage extent divided by the cell size, rounded up. Cells are then
    switched on/off from the coverage polygons and the stream intersections, and cell elevations are taken from
    the raster (0.0 everywhere when there is no raster), optionally overridden along the streams.

    Args:
        query: Object for communicating with XMS.
        bc_cov: BC coverage.
        cell_size: Length of one edge of a (square) cell.
        raster: The elevation raster, if interpolating cell elevations from a raster; otherwise None.
        vertical_units: Vertical units.
        use_streams: True to use stream elevations where the streams intersect the grid.
        logger: The logger.

    Returns:
        The grid.
    """
    logger.info('Getting coverage extents...')
    lower_left, upper_right = map_util.coverage_extents(bc_cov)

    logger.info('Building the grid...')

    # Enough whole cells in each direction to span the coverage extents.
    cells_in_x = num_cells_from_cell_size(upper_right[0] - lower_left[0], cell_size)
    cells_in_y = num_cells_from_cell_size(upper_right[1] - lower_left[1], cell_size)

    grid_builder = RectilinearGridBuilder()
    grid_builder.is_2d_grid = True
    grid_builder.origin = (lower_left[0], lower_left[1])
    grid_builder.orientation = (Orientation.y_decrease, Orientation.x_increase)
    # One more location than cells in each direction.
    grid_builder.set_square_xy_locations(cells_in_x + 1, cells_in_y + 1, cell_size)
    grid = grid_builder.build_grid()
    ugrid = grid.ugrid

    logger.info('Getting stream data')
    bc_comp = None
    if bc_cov is not None:
        bc_comp = dmi_util.get_bc_coverage_component(bc_cov.attrs['uuid'], query)
    stream_data = bc_util.get_stream_data(query, bc_cov, bc_comp)
    # Intersect the streams against a grid with every cell active, before any cells are switched off.
    all_cells_on = np.ones(ugrid.cell_count, dtype=int)
    arc_ix = map_util.intersect_arcs_with_grid(ugrid, all_cells_on, stream_data)

    logger.info('Disabling cells outside the domain...')
    on_off_cells = _compute_on_off_cells(ugrid, bc_cov, arc_ix)
    grid.model_on_off_cells = on_off_cells

    logger.info('Computing cell elevations...')
    elevations = _cell_elevations_from_raster(ugrid, on_off_cells, raster, vertical_units)
    if use_streams:
        _add_stream_elevations(arc_ix, bc_comp, elevations, stream_data)
    grid.cell_elevations = elevations

    return grid


def _cell_elevations_from_raster(
    ugrid: UGrid, on_off_cells: IntArray, raster: RasterInput | None, vertical_units: str
) -> 'npt.NDArray':
    """Returns an elevation for every cell.

    Args:
        ugrid: The grid's UGrid.
        on_off_cells: Model on/off flag per cell, passed through to the raster sampling.
        raster: The elevation raster, or None if there isn't one.
        vertical_units: Vertical units, passed through to the raster sampling.

    Returns:
        Values sampled from the raster at the cell centers, or an array of 0.0 (one per cell) when there is no
        raster.
    """
    if raster is None:
        # Return an ndarray (not a list) so both branches match the annotated return type.
        return np.zeros(ugrid.cell_count, dtype=float)

    cell_centers = tool_util.compute_cell_centers(ugrid)
    return tool_util.raster_values_at_cells(cell_centers, on_off_cells, raster, vertical_units)


def _add_stream_elevations(
    arc_ix: ArcIx, bc_comp: BcCoverageComponent, cell_elevations: 'npt.NDArray', stream_data: BcData
) -> None:
    """Overwrites stream-cell elevations with the stream elevation plus the stream depth.

    For every cell in ``arc_ix``:

    - Each stream segment piece inside the cell contributes the average z of its two intersection points.
    - Each contribution is weighted by the piece's 2D length.
    - The depth of the last arc listed for the cell is added to the weighted average.

    Cells whose intersection list is empty are left unchanged.

    Args:
        arc_ix: Cell index -> list of (group, arc point index, t1, t2, arc) intersection tuples.
        bc_comp: BC coverage component, used to look up cross-section curves.
        cell_elevations: Per-cell elevations, modified in place.
        stream_data: Stream data used to look up the arc geometry.
    """
    arc_points = _collect_arc_points(arc_ix, stream_data)

    for cell_idx, list_nodes in arc_ix.items():
        if not list_nodes:
            continue  # Nothing crosses this cell; keep its existing elevation

        weighted_zs = []  # (average z of piece, 2D length of piece)
        for _group, arc_pt_idx, t_val1, t_val2, arc in list_nodes:
            # Arc segment end points
            arc_pts = arc_points[arc]
            seg_pt1, seg_pt2 = arc_pts[arc_pt_idx - 1], arc_pts[arc_pt_idx]

            # Points where the segment enters and leaves the cell
            ix_pt1 = _point_on_segment(seg_pt1, seg_pt2, t_val1)
            ix_pt2 = _point_on_segment(seg_pt1, seg_pt2, t_val2)

            weight = geometry.distance_2d(ix_pt1, ix_pt2)
            weighted_zs.append(((ix_pt1[2] + ix_pt2[2]) / 2., weight))

        total_weight = sum(weight for _z, weight in weighted_zs)
        if total_weight > 0.0:
            z = sum(cell_z * (weight / total_weight) for cell_z, weight in weighted_zs)
        else:
            # Every piece has zero length (e.g. the arc only touches the cell), so a length weighting would divide
            # by zero. Use a plain average instead.
            z = sum(cell_z for cell_z, _weight in weighted_zs) / len(weighted_zs)

        # Depth comes from the last arc listed for this cell
        last_group = list_nodes[-1][0]
        cell_elevations[cell_idx] = z + _stream_depth(last_group, bc_comp)


def _collect_arc_points(arc_ix: ArcIx, stream_data: BcData) -> dict[tuple, list[tuple]]:
    """Returns arc -> list of (x, y, z) point tuples, for every arc referenced in ``arc_ix``."""
    arc_points: dict[tuple, list[tuple]] = {}
    for list_nodes in arc_ix.values():
        for _group, _arc_segment, _t_val1, _t_val2, arc in list_nodes:
            if arc not in arc_points:
                feature_arc = stream_data.feature_from_id(arc[0], arc[1])
                # Convert to tuples to use geometry.distance_2d() later
                arc_points[arc] = [(pt[0], pt[1], pt[2]) for pt in feature_arc.geometry.coords]
    return arc_points


def _point_on_segment(pt1: tuple, pt2: tuple, t_val: float) -> tuple:
    """Returns the 3D point at parameter ``t_val`` along the segment from ``pt1`` to ``pt2``."""
    return tuple(a + t_val * (b - a) for a, b in zip(pt1, pt2))


def _stream_depth(group: Group, bc_comp: BcCoverageComponent) -> float:
    """Returns the stream depth for the arc group.

    For trapezoidal channels this is the bankfull depth; otherwise it is derived from the cross-section.
    """
    if group.parameter('channel_type').value == ChannelType.TRAPEZOIDAL:
        return group.parameter('bankfull_depth').value
    return _get_cross_section_depth(group, bc_comp)


def _get_cross_section_depth(group: Group, bc_comp: BcCoverageComponent) -> float:
    """Computes the channel depth from the group's cross-section curve.

    The depth is the lower of the two end y values minus the lowest interior y value.

    Args:
        group: The arc's parameter group; its 'cross_section' parameter holds the curve id.
        bc_comp: BC coverage component that stores the curve.

    Returns:
        The depth, or 0.0 if there is no curve or it has fewer than three points (no interior points).
    """
    xy_id = group.parameter('cross_section').value
    _x_values, y_values = bc_comp.data.get_curve(xy_id, False)
    # Need both end points plus at least one interior point. With only two points the interior slice is empty,
    # and min() would raise ValueError.
    if y_values is None or len(y_values) < 3:
        return 0.0
    min_ends = min(y_values[0], y_values[-1])
    min_middle = min(y_values[1:-1])
    return min_ends - min_middle


def num_cells_from_cell_size(distance: float, cell_size: float, rel_tol: float = 1e-9) -> int:
    """Returns the number of cells in the x or y direction needed to span ``distance``.

    Partial cells are rounded up. A ratio within ``rel_tol`` of a whole number is treated as that whole number,
    so floating-point noise (e.g. 4.9 / 0.7 == 7.000000000000001) does not add a spurious extra cell.

    Args:
        distance: Extent to cover.
        cell_size: Length of one edge of a cell.
        rel_tol: Relative tolerance used to snap the ratio to the nearest integer.

    Returns:
        The number of cells.
    """
    ratio = distance / cell_size
    nearest = round(ratio)
    if math.isclose(ratio, nearest, rel_tol=rel_tol):
        return int(nearest)
    return int(math.ceil(ratio))


def _compute_on_off_cells(ugrid: UGrid, bc_cov: GeoDataFrame, arc_ix: ArcIx) -> IntArray:
    """Builds the model on/off array: 1 for active cells, 0 for inactive ones.

    A cell is active if its center lies inside a coverage polygon, or if it is flagged by
    io_util.get_arcs_mask for the stream intersections in ``arc_ix``.
    """
    centers = tool_util.compute_cell_centers(ugrid)
    inside_polygons = _get_points_in_polygon_mask(centers, bc_cov)
    on_streams = io_util.get_arcs_mask(ugrid, arc_ix)
    return np.logical_or(inside_polygons, on_streams).astype(np.int64)


def _get_points_in_polygon_mask(cell_centers: Pt3dArray, bc_cov: GeoDataFrame) -> IntArray:
    """Returns a 0/1 mask over the cell centers.

    A center gets 1 if it is inside any polygon's outer ring and not inside any of that polygon's holes.

    Args:
        cell_centers: The cell center points.
        bc_cov: BC coverage; rows whose 'geometry_types' is 'Polygon' are used.

    Returns:
        An int64 array with one 0/1 entry per cell center.
    """
    mask = np.zeros(len(cell_centers), dtype=bool)
    polygons = bc_cov[bc_cov['geometry_types'] == 'Polygon']
    for polygon in polygons.itertuples():
        poly_point_lists = get_polygon_point_lists(polygon)
        in_poly = tool_geom.run_parallel_points_in_polygon(cell_centers, np.asarray(poly_point_lists[0]))

        for inner_poly in poly_point_lists[1:]:
            in_hole = tool_geom.run_parallel_points_in_polygon(cell_centers, np.asarray(inner_poly))
            # Use and-not rather than XOR. With XOR, a point inside two overlapping holes (or inside a hole that
            # extends past the outer ring) would be switched back on.
            in_poly = np.logical_and(in_poly, np.logical_not(in_hole))

        mask = np.logical_or(mask, in_poly)
    return mask.astype(np.int64)
