"""GridIntersector class."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
from dataclasses import dataclass, field

# 2. Third party modules
from flopy.discretization.vertexgrid import VertexGrid
from flopy.utils.gridintersect import GridIntersect
import numpy as np
import shapefile
from shapely.geometry import LineString, Point, Polygon

# 3. Aquaveo modules
from xms.constraint import Grid
from xms.constraint.modflow import get_stacked_grid_2d_topology
from xms.coverage.grid.ugrid_mapper import UgridMapper
from xms.grid.ugrid import UGrid

# 4. Local modules
from xms.mf6.components import dis_builder
from xms.mf6.data.dis_data_base import DisDataBase
from xms.mf6.geom import geom, shapefile_geom
from xms.mf6.misc import log_util, util


@dataclass
class IxInfo:
    """Intersection info produced by GridIntersector.intersect."""
    points: list[np.recarray] = field(default_factory=list)  # One flopy intersection result per point
    arcs: list[np.recarray] = field(default_factory=list)  # One flopy intersection result per arc
    arc_t_vals: list[list[list[float]]] = field(default_factory=list)  # Per arc: per intersected cell t_values
    arc_hfb_cell_faces: list = field(default_factory=list)  # Per arc (only filled when HFB cells are requested)
    polys: list[np.recarray] = field(default_factory=list)  # One flopy intersection result per polygon
    poly_t_vals: list[list[float]] = field(default_factory=list)  # [num polys][num intersected cells]

    # Maps the public feature type names to the attribute holding their intersection list. Not annotated, so it is
    # a plain class attribute rather than a dataclass field.
    _FEATURE_ATTRS = {'points': 'points', 'arcs': 'arcs', 'polygons': 'polys'}

    def _attr_for(self, feature_type: str) -> str:
        """Return the attribute name that stores the intersections for the feature type.

        Args:
            feature_type: 'points', 'arcs', or 'polygons'.

        Returns:
            See description.

        Raises:
            ValueError: If the feature type is not one of the supported names.
        """
        try:
            return self._FEATURE_ATTRS[feature_type]
        except (KeyError, TypeError):  # TypeError: unhashable input, still just an unsupported type
            raise ValueError('feature type not supported') from None

    def get_list(self, feature_type: str) -> list:
        """Return the intersection list corresponding to the feature type.

        Args:
            feature_type: 'points', 'arcs', or 'polygons'.

        Returns:
            See description.

        Raises:
            ValueError: If the feature type is not supported.
        """
        return getattr(self, self._attr_for(feature_type))

    def set_list(self, feature_type: str, ix_list: list) -> None:
        """Set the intersection list corresponding to the feature type.

        Args:
            feature_type: 'points', 'arcs', or 'polygons'.
            ix_list: List of intersections.

        Raises:
            ValueError: If the feature type is not supported.
        """
        setattr(self, self._attr_for(feature_type), ix_list)


class GridIntersector:
    """Intersects the features of point, arc, and polygon shapefiles with a MODFLOW grid."""
    def __init__(self, cogrid: Grid, hfb: bool = False):
        """Initializes the class.

        Args:
            cogrid (xms.constraint.Grid | None): The constrained grid. If falsy, no UGrid is cached.
            hfb (bool): If true, we find cells on both sides of the arc.
        """
        self._cogrid = cogrid
        self._ugrid = cogrid.ugrid if cogrid else None  # UGrid of the cogrid, so we only get it once
        self._hfb = hfb

        self._ugrid_mapper = None  # Created on first use by _add_hfb_cells
        self._cell_areas: list[float] = []

    def intersect(self, dis: DisDataBase, shapefile_names) -> IxInfo:
        """Intersects the shapefiles with the MODFLOW grid.

        Args:
            dis: The DIS, DISU, or DISV package.
            shapefile_names (dict): Shapefile paths keyed by feature type ('points', 'arcs', and/or 'polygons').
                Feature types that are not present are skipped.

        Returns:
            ix_info: The intersection info.
        """
        logger = log_util.get_logger()
        logger.info('Intersecting shapefile with grid')
        _fgrid, intersector = create_flopy_grid_and_intersector(dis, self._cogrid, self._ugrid)
        ix_info = IxInfo()
        if 'points' in shapefile_names:
            points = self._shapely_points_from_shapefile(shapefile_names['points'], dis.ftype)
            for point in points:
                ix_info.points.append(intersector.intersect(point))

        if 'arcs' in shapefile_names:
            linestrings = self._shapely_linestrings_from_shapefile(shapefile_names['arcs'])
            for linestring in linestrings:
                # We set sort_by_cellid to True for more predictable results, even though a few lines later
                # we sort by t_values. If sort_by_cellid is True, only top layer cells are returned. Otherwise
                # you may get some in the top layer, some in other layers.
                rec = intersector.intersect(linestring, sort_by_cellid=True)
                polyline = self._polyline_from_linestring(linestring)
                t_values = self._compute_arc_t_values(polyline, rec)
                sorted_t_values = self._sort_by_t_values(rec, t_values)  # Also reorders the rows of rec in place
                ix_info.arcs.append(rec)
                ix_info.arc_t_vals.append(sorted_t_values)  # Parallel to the (now sorted) rows of rec
                if self._hfb:
                    self._add_hfb_cells(polyline, ix_info)

        if 'polygons' in shapefile_names:
            polygons = self._shapely_polygons_from_shapefile(shapefile_names['polygons'])
            for polygon in polygons:
                ix_info.polys.append(intersector.intersect(polygon))
        return ix_info

    def _add_hfb_cells(self, polyline, ix_info: IxInfo):
        """Appends the vertical cell faces intersected by the arc to ix_info.arc_hfb_cell_faces.

        Only the first entry returned by the mapper is kept (assumed to be the top layer, consistent with the rest
        of the intersection info). An empty list is appended when no faces are found, so the list stays parallel to
        ix_info.arcs.

        Args:
            polyline (list(tuple(float))): The x,y points defining the linestring.
            ix_info: The intersection info.
        """
        if not self._ugrid_mapper:
            logger = log_util.get_logger()
            self._ugrid_mapper = UgridMapper(self._cogrid, cell_materials=None, logger=logger)
        hfb_cell_faces = self._ugrid_mapper.get_vertical_faces_intersected_by_arc(polyline)
        if hfb_cell_faces:
            ix_info.arc_hfb_cell_faces.append(hfb_cell_faces[0])  # Just top layer (consistent with everything else)
        else:
            ix_info.arc_hfb_cell_faces.append([])

    def _sort_by_t_values(self, rec, t_values):
        """Sorts the arc intersection recarray (cellids, lengths, vertices...) in place by t_values.

        Rows are ordered by the first t_value of each cell (i.e. where the arc first enters the cell). The sort is
        stable, so cells with equal first t_values keep their original relative order.

        Args:
            rec: numpy recarray from flopy. Its 'cellids', 'vertices', 'lengths' and 'ixshapes' columns are
                reordered in place.
            t_values: List of t_value lists that is parallel to the rows of rec.

        Return:
            The t_value lists reordered with the same permutation, so they stay parallel to the sorted rec.
        """
        first_t_values = [item[0] for item in t_values]  # First t_value for each cell
        # Stable argsort: row indices in order of increasing first t_value.
        order = sorted(range(len(t_values)), key=first_t_values.__getitem__)

        for column in ('cellids', 'vertices', 'lengths', 'ixshapes'):
            rec[column] = np.array(rec[column])[order]

        # Apply the very same permutation (not its inverse) so t_values[i] still describes row i of rec.
        return [t_values[i] for i in order]

    def _shapely_points_from_shapefile(self, filename, dis_ftype):
        """Returns shapely Points (x, y only) built from the first vertex of each shape in the shapefile.

        Args:
            filename (str): Filepath of shapefile.
            dis_ftype (str): The DIS* package ftype. Currently unused because z is never needed.

        Returns:
            (list(Point)): List of shapely Point.
        """
        points = []
        with shapefile.Reader(filename) as reader:
            for shape in reader.iterShapes():
                # z is deliberately dropped: for DIS, flopy would use it to find the grid layer, and it is never
                # needed for any of the grid types.
                point = Point(shape.points[0][0], shape.points[0][1])
                points.append(point)
        return points

    def _shapely_linestrings_from_shapefile(self, filename):
        """Returns shapely LineStrings built from the polylines in the shapefile.

        Args:
            filename (str): Filepath of shapefile.

        Returns:
            (list(LineString)): One shapely LineString per shape.
        """
        linestrings = []
        with shapefile.Reader(filename) as reader:
            for shape in reader.iterShapes():
                points = shapefile_geom.xyz_tuples_from_shape(shape)
                linestring = LineString(points)
                linestrings.append(linestring)
        return linestrings

    def _shapely_polygons_from_shapefile(self, filename):
        """Returns shapely Polygons built from the polygons in the shapefile.

        The first loop of each shape is used as the shell; any further loops are used as holes.

        Args:
            filename (str): Filepath of shapefile.

        Returns:
            (list(Polygon)): List of shapely Polygon.
        """
        polygons = []
        with shapefile.Reader(filename) as reader:
            for shape in reader.iterShapes():
                loops = shapefile_geom.geom_from_shape(shape)
                if len(shape.parts) == 1:
                    polygon = Polygon(shell=loops[0])
                else:
                    polygon = Polygon(shell=loops[0], holes=loops[1:])
                polygons.append(polygon)
        return polygons

    def _compute_arc_t_values(self, polyline, rec):
        """Returns t values associated with intersected vertex points.

        Args:
            polyline (list(tuple(float))): The x,y points defining the linestring.
            rec (recarray): Recarray with intersection results.

        Returns:
            (list(list)): Size of rec num rows (i.e. num cells). Inner lists are size of num vertices per cell.
        """
        tol = 1e-6  # May need to be smarter about this
        stationing = geom.polyline_stationing(polyline)  # Distance of each point on polyline (e.g. [0., 5., 8.])
        rec_t_values = []  # List of t_value lists, size of rec num rows
        for cell_verts in rec.vertices:  # Iterate through vertices column
            cell_t_values = []  # List of t_values associated with each intersected vertex for this cell
            self._recurse_get_t_values(cell_verts, polyline, stationing, tol, cell_t_values)
            rec_t_values.append(cell_t_values)
        return rec_t_values

    def _polyline_from_linestring(self, linestring):
        """Returns a list of tuples defining a polyline given a Shapely linestring.

        Args:
            linestring (shapely.geometry.LineString): The linestring from the shapefile.

        Returns:
            polyline (list(tuple(float))): The x,y points defining the linestring.
        """
        x, y = linestring.xy
        polyline = list(zip(x, y))  # List of (x, y) tuples
        return polyline

    def _recurse_get_t_values(self, sequence, polyline, stationing, tol, cell_t_values):
        """Recursive function to get the t_values for all the vertices for a cell.

        Recursive because we don't know ahead of time the depth of the nested tuples. When a line reenters a cell,
        the nested tuples go deeper. Thus we recurse to get down to the actual vertices. The vertices may arrive
        either as numpy arrays or as plain tuples.

        Args:
            sequence: Either an xyz tuple, or tuple of xyz tuples (or a tuple of tuples of xyz tuples etc).
            polyline (list(tuple(float))): The x,y points defining the linestring.
            stationing (list(float)): list of distances along linestring of each point, starting at 0.0 for the first.
            tol (float): Tolerance used when finding a vertex on the line segment.
            cell_t_values (list(float)): List of the t_values for the current cell. Appended to.
        """
        if sequence is None:
            return

        is_vert = False
        for element in sequence:
            if util.is_non_string_sequence(element):
                self._recurse_get_t_values(element, polyline, stationing, tol, cell_t_values)
            else:
                # A scalar element means this sequence is itself a single vertex (coordinates).
                is_vert = True
                break

        if is_vert:
            self._get_t_value(sequence, polyline, stationing, tol, cell_t_values)

    def _get_t_value(self, vertex, polyline, stationing, tol, cell_t_values):
        """Appends the t_value at the vertex along the polyline.

        If the vertex cannot be located on any segment, an error is logged and 0.0 is appended so the t_value list
        stays aligned with the vertices.

        Args:
            vertex: An xyz tuple
            polyline (list(tuple(float))): The x,y points defining the linestring.
            stationing (list(float)): list of distances along linestring of each point, starting at 0.0 for the first.
            tol (float): Tolerance used when finding a vertex on the line segment.
            cell_t_values (list(float)): List of the t_values for the current cell. Appended to.
        """
        segment = self._segment_from_vertex(vertex, tol, polyline)
        if segment > -1:
            t_value = self._compute_t_value(stationing[segment], stationing[-1], polyline[segment], vertex)
            cell_t_values.append(t_value)
        else:
            logger = log_util.get_logger()
            logger.error('Could not find intersection point on linestring. Potential tolerance problem.')
            cell_t_values.append(0.0)  # Put something in so we don't crash later

    def _compute_t_value(self, previous_length, total_length, segment_endpoint, vert):
        """Returns the normalized distance between 0.0 and 1.0 from start of linestring to the intersected vertex.

        Args:
            previous_length (float): Length of all segments previous to the current.
            total_length (float): Total length of all segments.
            segment_endpoint (x, y): End of the previous segment.
            vert (x, y, (z)?): Vertex intersection point.

        Returns:
            (float): See description. 0.0 for a zero-length linestring.
        """
        if total_length <= 0.0:
            return 0.0  # Degenerate (zero-length) linestring: avoid ZeroDivisionError
        d = previous_length + geom.distance_2d(segment_endpoint, vert)
        return d / total_length

    def _segment_from_vertex(self, vert, tol, polyline):
        """Returns the index (0-based) of the first segment that the vertex is on, or -1 if it is on none.

        This could be made more efficient if we remembered the last segment found and started from there instead of
        from the beginning every time.

        Args:
            vert (x, y, (z)?): Vertex intersection point.
            tol (float): Tolerance used when finding a vertex on the line segment.
            polyline (list(tuple(float))): The x,y points defining the linestring.

        Returns:
            (int): See description.
        """
        return next(
            (
                i for i in range(len(polyline) - 1)
                if geom.on_line_and_between_endpoints_2d(polyline[i], polyline[i + 1], vert, tol)
            ),
            -1,
        )


def create_flopy_grid_and_intersector(dis: DisDataBase, cogrid: Grid, ugrid: UGrid) -> tuple[VertexGrid, GridIntersect]:
    """Builds a flopy VertexGrid for the model grid together with a flopy GridIntersect over it.

    flopy's VertexGrid-based intersection is used for every DIS* type, DIS6 included. The structured alternative did
    not handle real-world projections well enough and produced different cell IDs, and a single code path keeps the
    results consistent.

    See https://github.com/modflowpy/flopy/blob/develop/examples/Notebooks/flopy3_grid_intersection_demo.ipynb

    Args:
        dis: The DIS* package.
        cogrid: The constrained grid.
        ugrid: The ugrid of the cogrid.

    Returns:
        (tuple): tuple containing:
            - (VertexGrid): The flopy grid.
            - (flopy.utils.gridintersect.GridIntersect): The flopy grid intersector.
    """
    # Handed to flopy even though flopy currently appears not to use it.
    idomain = dis.nparray_from_array('IDOMAIN')

    # DIS6 and DISV6 take their 2D topology from the stacked grid; other package types use the DISU topology.
    uses_stacked_topology = dis.ftype in {'DIS6', 'DISV6'}
    get_topology = get_stacked_grid_2d_topology if uses_stacked_topology else dis_builder.get_disu_topology
    locations, cells_2d = get_topology(ugrid)

    return _create_vertex_grid_and_intersector(cells_2d, idomain, locations, cogrid, ugrid)


def _create_vertex_grid_and_intersector(cells_2d, idomain, locations, cogrid: Grid,
                                        ugrid: UGrid) -> tuple[VertexGrid, GridIntersect]:
    """Builds a flopy VertexGrid from 0-based grid data and wraps it in a flopy GridIntersect.

    Args:
        cells_2d (list[tuple[int]]): Cell 2D stream, 0-based
        idomain: IDOMAIN
        locations (list[list[float]]): XY vertices
        cogrid: The constrained grid.
        ugrid: The ugrid of the cogrid.

    Returns:
        (tuple[VertexGrid, GridIntersect]): The flopy grid and the intersector built on it.
    """
    # Both lists are built with 0-based indices, matching the 0-based cells_2d stream.
    vertex_list = dis_builder.vertices_list_from_locations(locations, one_based=False)
    centers = dis_builder.get_cell_centers2d(cogrid, ugrid)
    grid = VertexGrid(
        vertices=vertex_list,
        cell2d=dis_builder.cell2d_list_from_grid_data(cells_2d, centers, one_based=False),
        idomain=idomain,
    )
    return grid, GridIntersect(grid)
