"""CellPolygonCalculator class."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
from typing import Any

# 2. Third party modules

# 3. Aquaveo modules
from xms.grid.ugrid import UGrid

# 4. Local modules
from xms.mf6.data.grid_info import DisEnum
from xms.mf6.mapping.map_cell_iterator import MapCellIterator
from xms.mf6.mapping.package_builder_base import should_skip_record


class CellPolygonCalculator:
    """Performs calculations with cells and polygons.

    Works from the per-polygon intersection info (``ix_recs``) and the parallel list of shapefile ``records``
    to determine which polygons overlap which grid cells.
    """
    def __init__(self, builder, ix_recs, records, att_type=''):
        """Initializer.

        Args:
            builder: The package builder.
            ix_recs (list): The intersection info, one entry per record (same indexing as ``records``).
            records: The records from the shapefile.
            att_type (str): String used in att table for the attribute Type column (i.e. 'well', 'drain', 'river' etc)
        """
        self._builder = builder
        self._ix_recs = ix_recs
        self._records = records
        self._att_type = att_type

    def compute_cell_polygons(self, cell_count: int, highest_active_cell=True) -> list[int]:
        """Computes and returns the array, size of cell_count, that tells which polygon goes with each cell.

        Args:
            cell_count (int): Number of cells; the length of the returned list.
            highest_active_cell (bool): Passed on to the layer range filtering. If True, only the highest active
                cell is kept for each polygon; if False, every cell in the layer range is kept.

        Returns:
            (list): The polygon index with the largest intersecting area for each cell, or -1 for cells that no
            polygon was applied to with an area greater than zero.
        """
        max_areas = [0.0] * cell_count  # largest intersecting area seen so far for each cell
        cell_polys = [-1] * cell_count  # polygon index with largest intersecting area for each cell

        for i, record in enumerate(self._records):
            if should_skip_record(self._ix_recs[i], record, self._att_type):
                continue  # Shape intersected no cells or the feature att type doesn't match. Skip it.
            self._apply_polygon_areas(i, max_areas, cell_polys, record, highest_active_cell)
        return cell_polys

    def get_cell_polys_and_areas(self) -> tuple[dict[Any, list], dict[Any, list]]:
        """Returns dicts containing the polygons overlapping each cell, and the areas of overlap.

        Only records whose 'Type' attribute equals the att_type given to the initializer are considered. The two
        dicts have the same keys, and for a given cell the lists are parallel (areas[k] belongs to polys[k]).

        Returns:
            (tuple): tuple containing:
                - cell_polys (dict): cell_idx -> list(poly_idx)
                - cell_areas (dict): cell_idx -> list(area)
        """
        cell_polys = {}  # cell_idx -> list(poly_idx)
        cell_areas = {}  # cell_idx -> list(area)
        for poly_idx, record in enumerate(self._records):
            if record['Type'] != self._att_type:
                continue
            ix_rec = self._ix_recs[poly_idx]
            cell_idxs, areas = self._get_cells_and_areas_for_layer_range(
                ix_rec.cellids.tolist(), ix_rec.areas.tolist(), record, highest_active_cell=False
            )

            for cell_idx, area in zip(cell_idxs, areas):
                cell_polys.setdefault(cell_idx, []).append(poly_idx)
                cell_areas.setdefault(cell_idx, []).append(area)
        # Order matches the method name and the documented return value: polygons first, then areas.
        return cell_polys, cell_areas

    def _apply_polygon_areas(self, i, max_areas, cell_polys, record, highest_active_cell):
        """Apply the current polygon's areas to the arrays, which may change max_areas and therefore cell_polys.

        Both lists are modified in place. A polygon only takes over a cell if its area is strictly greater than
        the current maximum, so on ties the polygon that was applied first keeps the cell.

        Args:
            i (int): Index of the current polygon shape.
            max_areas (list(float)): List with the current max intersecting area for each cell.
            cell_polys (list(int)): List, size of cell_count, with the polygon indices for each cell.
            record: A record from the shapefile.
            highest_active_cell (bool): True if we only want to include cells in the top layer (not for DISU).
        """
        ix_rec = self._ix_recs[i]
        cell_idxs, areas = self._get_cells_and_areas_for_layer_range(
            ix_rec.cellids.tolist(), ix_rec.areas.tolist(), record, highest_active_cell
        )

        for cell_idx, area in zip(cell_idxs, areas):
            if area > max_areas[cell_idx]:
                max_areas[cell_idx] = area
                cell_polys[cell_idx] = i

    def _get_cells_and_areas_for_layer_range(self, cell_idxs, areas, record, highest_active_cell):
        """Given the intersected cells and the layer range, returns all the cells and areas in the layer range.

        This function assumes that, for DIS and DISV, the cells are in order from higher to lower layers. This is
        true for flopy.GridIntersect class.

        Results are collected in a dict keyed by cell index, so each cell appears at most once in the output; if
        a cell is assigned more than once, the last assigned area is the one returned.

        Args:
            cell_idxs (list): Intersected cell indices, 0-based.
            areas (list(float)): Intersected areas, parallel to cell_idxs.
            record: A record from the shapefile.
            highest_active_cell (bool): If False, every cell is kept (for DIS/DISV, every cell produced by
                MapCellIterator). If True, only the highest active cells are kept, as described in the code below.

        Returns:
            (tuple): tuple containing:
                - (list(int)): cell_idxs
                - (list(float)): areas
        """
        new_cell_idxs_areas = {}  # cell_idx -> area; insertion order is the output order
        is_disu = self._builder.grid_info.dis_enum == DisEnum.DISU
        # NOTE(review): 'not idomain' below relies on idomain being None or a plain sequence; a multi-element
        # numpy array would raise on the truth test - confirm the type against the builder.
        idomain = self._builder._idomain  # for short

        if is_disu:
            ugrid = self._builder.ugrid
            for cell_idx, area in zip(cell_idxs, areas):
                if not highest_active_cell:
                    new_cell_idxs_areas[cell_idx] = area
                    continue
                if idomain and idomain[cell_idx] <= 0:
                    continue  # The cell itself is not active
                # Look at top faces and if one of them doesn't have an adjacent cell, or it does but
                # that cell is inactive, add the current cell.
                for face in range(ugrid.get_cell_3d_face_count(cell_idx)):
                    orientation = ugrid.get_cell_3d_face_orientation(cell_idx, face)
                    if orientation != UGrid.face_orientation_enum.ORIENTATION_TOP:
                        continue
                    adjcell = ugrid.get_cell_3d_face_adjacent_cell(cell_idx, face)
                    if adjcell == -1 or (idomain and idomain[adjcell] == 0):
                        # If the cell above (adjcell) is a pass through (idomain == -1), we don't add
                        # the current cell because we assume there's some cell above adjcell that is
                        # the highest active. May be a wrong assumption.
                        new_cell_idxs_areas[cell_idx] = area
                        break
        else:
            column_id_set = set()  # Set of columns that we've added
            for cell_idx, area in zip(cell_idxs, areas):
                for cell_idx_layer in MapCellIterator(cell_idx, self._builder, record):
                    if not highest_active_cell:
                        new_cell_idxs_areas[cell_idx_layer] = area
                        continue
                    # NOTE(review): this branch tests and stores the intersected cell_idx, not cell_idx_layer;
                    # the iterator only controls whether/how often the check runs - confirm this is intended.
                    column_id = self._column_id(cell_idx)
                    # The following assumes the cellids are in order from higher to lower layers, which
                    # is true for flopy.GridIntersect
                    if column_id not in column_id_set and (not idomain or idomain[cell_idx] > 0):
                        new_cell_idxs_areas[cell_idx] = area
                        column_id_set.add(column_id)

        # Convert map to parallel lists
        return list(new_cell_idxs_areas), list(new_cell_idxs_areas.values())

    def _column_id(self, cell_idx: int):
        """Return an identifier for the column where the cell_idx is located.

        The identifier is the cell index modulo the number of cells per layer, so cells that are stacked
        vertically in different layers share the same identifier.

        Args:
            cell_idx: Cell index

        Returns:
            See description.
        """
        return cell_idx % self._builder.grid_info.cells_per_layer()
