"""MapCellIterator class."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules

# 2. Third party modules

# 3. Aquaveo modules

# 4. Local modules
from xms.mf6.data.grid_info import DisEnum, GridInfo
from xms.mf6.mapping.package_builder_base import PackageBuilderBase

# Type aliases
# A mutable [from_layer, to_layer] list or a (from_layer, to_layer) tuple. Both ends are inclusive layer numbers
# (1-based for DIS/DISV grids; DISU grids use [-1, -1] since they have no layers).
LayerRange = list[int] | tuple[int, int]  # From layer, To layer


class MapCellIterator:
    """Iterates on cell indexes and layers that are being mapped to, handling layer ranges and well screens.

    For DIS and DISV grids, yields one 0-based cell index per layer in the resolved (inclusive) layer range, all in
    the same column of cells as the initial cell. For DISU grids, yields a single cell index one time.
    """
    def __init__(self, cell_idx, builder: PackageBuilderBase, record):
        """Initializes the class.

        DISU not really supported. For DISU, currently it always just returns the initial cell_idx just one time
        (the value returned is cell_idx modulo grid_info.cells_per_layer()).

        Args:
            cell_idx: The cell index from the intersection info (0-based).
            builder: The package builder.
            record: A record as a dict, or an object with an as_dict() method.
        """
        if not isinstance(record, dict):
            record = record.as_dict()
        self._cell_idx_lay1 = cell_idx % builder.grid_info.cells_per_layer()  # cell index in layer 1
        self._builder = builder

        self._start_layer = 0  # First layer (1-based); a value <= 0 means there is nothing to iterate for DIS/DISV
        self._stop_layer = 0  # Will be set to the last layer + 1
        self._counter = 0  # From 0 to (self._stop_layer - self._start_layer)

        self._init_layers(record)

    def __iter__(self):
        """Return self."""
        return self

    def __next__(self) -> int:
        """Return the next cell idx (0-based) and increment the counter, or raise StopIteration."""
        if self._builder.grid_info.dis_enum == DisEnum.DISU and self._counter == 0:
            # Return the initial cell index one time only
            cell_idx = self._cell_idx_lay1
            self._counter += 1
            return cell_idx
        elif self._start_layer > 0 and self._start_layer + self._counter < self._stop_layer:
            cell_idx = self._next_cell_idx()
            self._counter += 1
            return cell_idx
        else:
            # Also reached immediately when the layer range is invalid (e.g. (-1, -1) from a bad well screen)
            raise StopIteration

    def _next_cell_idx(self) -> int:
        """Return the next cell index (0-based) without incrementing the counter."""
        if self._builder.grid_info.dis_enum == DisEnum.DISU:
            cell_idx = self._cell_idx_lay1
        else:
            cell_idx = self._cell_idx_in_layer(self._start_layer + self._counter)
        return cell_idx

    def _cell_idx_in_layer(self, layer: int) -> int:
        """Return the cell index in the layer, based on _cell_idx_lay1.

        Args:
            layer: The 1-based layer.

        Returns:
            See description.
        """
        return self._builder.grid_info.cell_index_from_lay_cell2d(layer - 1, self._cell_idx_lay1, one_based=False)

    @staticmethod
    def _feature_id(record_dict):
        """Return the record's 'ID' value, or 'testing_id' if the record has none (used in log messages)."""
        return record_dict.get('ID', 'testing_id')

    def _init_layers(self, record_dict) -> None:
        """Sets self._start_layer and self._stop_layer (stop is the last layer + 1).

        The layer range comes from, in order of preference: the well screen (if the record says to use it), the
        record's FROM_LAYER/TO_LAYER values (validated and adjusted), or the builder's grid layer range. DISU grids
        always get [-1, -1].

        Args:
            record_dict: A record as a dict.
        """
        grid_range = self._builder.grid_layer_range()
        if self._builder.grid_info.dis_enum == DisEnum.DISU:
            layer_range = [-1, -1]  # DISU grids don't have layers so we just use -1
        else:
            layer_range = list(grid_range)
            if _use_screen(record_dict):
                layer_range = self._layer_range_from_screens(record_dict)
            elif _layer_range_exists(record_dict):
                layer_range = [record_dict['FROM_LAYER'], record_dict['TO_LAYER']]
                # The helpers below modify layer_range in place
                self._ensure_range_not_inverted(layer_range, record_dict)
                self._ensure_range_in_grid(layer_range, grid_range, record_dict)
                self._handle_override_layers(layer_range, record_dict)
        self._start_layer = layer_range[0]
        self._stop_layer = layer_range[1] + 1

    def _ensure_range_not_inverted(self, layer_range: LayerRange, record_dict) -> None:
        """If range is inverted (FROM_LAYER > TO_LAYER), swaps it in place and logs an error.

        Args:
            layer_range: The layer range (must be mutable).
            record_dict: A record as a dict.
        """
        if layer_range[0] > layer_range[1]:
            feature_id = self._feature_id(record_dict)
            self._builder.log.error(f'FROM_LAYER greater than TO_LAYER for feature with ID: {feature_id}.')
            layer_range[0], layer_range[1] = layer_range[1], layer_range[0]

    def _ensure_range_in_grid(self, layer_range: LayerRange, grid_range: LayerRange, record_dict) -> None:
        """Clamps the layer range in place to the grid layer range, logging an error for each end that is outside it.

        Args:
            layer_range: The layer range (must be mutable).
            grid_range: Grid layer range.
            record_dict: A record as a dict.
        """
        feature_id = self._feature_id(record_dict)
        if layer_range[0] < grid_range[0]:
            self._builder.log.error(f'FROM_LAYER less than grid layer range for feature with ID: {feature_id}.')
            layer_range[0] = grid_range[0]
        if layer_range[1] > grid_range[1]:
            self._builder.log.error(f'TO_LAYER greater than grid layer range for feature with ID: {feature_id}.')
            layer_range[1] = grid_range[1]

    def _handle_override_layers(self, layer_range: LayerRange, record_dict) -> None:
        """Possibly adjust layer range in place if overriding grid layers.

        builder.override_layers is indexed by layer number and each entry is indexed with [0] and [1]; presumably
        each entry is the (first, last) grid layer that the override layer covers - TODO confirm against the builder.
        A layer number missing from override_layers is logged as an error and left unchanged.

        Args:
            layer_range: The layer range (must be mutable).
            record_dict: A record as a dict.
        """
        if not self._builder.override_layers:
            return

        feature_id = self._feature_id(record_dict)

        # Set from layer to first layer in its range, and to layer to last layer in its range
        if layer_range[0] not in self._builder.override_layers:
            msg = f'FROM_LAYER {layer_range[0]} not in override layers for feature with ID: {feature_id}.'
            self._builder.log.error(msg)
        else:
            layer_range[0] = self._builder.override_layers[layer_range[0]][0]

        if layer_range[1] not in self._builder.override_layers:
            msg = f'TO_LAYER {layer_range[1]} not in override grid layers for feature with ID: {feature_id}.'
            self._builder.log.error(msg)
        else:
            layer_range[1] = self._builder.override_layers[layer_range[1]][1]

    def _layer_range_from_screens(self, record_dict) -> LayerRange:
        """Return the layer range using the screened interval.

        Args:
            record_dict: A record as a dict containing 'TOP_SCRN' and 'BOT_SCRN'.

        Returns:
            See description. (-1, -1) if the screen is invalid or does not intersect the grid.
        """
        return layer_range_from_well_screen(
            self._cell_idx_lay1, record_dict['TOP_SCRN'], record_dict['BOT_SCRN'], self._builder.grid_info,
            self._builder.cogrid.get_cell_tops(), self._builder.cogrid.get_cell_bottoms(), self._builder.log,
            self._feature_id(record_dict)
        )


def _use_screen(record_dict) -> bool:
    """Return True if the record indicates we should use the screened interval.

    Args:
        record_dict:

    Returns:
        See description.
    """
    return (
        'USE_SCRN' in record_dict and record_dict['USE_SCRN']
    ) and 'TOP_SCRN' in record_dict and 'BOT_SCRN' in record_dict


def _layer_range_exists(record_dict) -> bool:
    """Return True if the record indicates we should use the layer range.

    Args:
        record_dict:

    Returns:
        See description.
    """
    return 'FROM_LAYER' in record_dict and 'TO_LAYER' in record_dict


def layer_range_from_well_screen(
    cell_idx: int,
    top_scrn: float,
    bot_scrn: float,
    grid_info: 'GridInfo',
    cell_tops,
    cell_bottoms,
    logger=None,
    feature_id=''
) -> tuple[int, int]:
    """Return the layer range intersected by screened interval, or (-1, -1) if screen doesn't intersect.

    Only for DIS and DISV.

    (-1, -1) is returned, and an error logged if a logger is given, when the screen has zero length, is inverted, or
    lies entirely above or below the column of cells. A screen starting above the top of the column is clipped to
    layer 1, and one ending below the bottom of the column is clipped to the last layer. When an end of the screen
    lies exactly on the boundary between two layers, it is assigned to the layer the rest of the screen extends
    into (assuming each layer's bottom equals the next layer's top). If layer elevations are not contiguous and both
    screen ends fall in gaps, (0, 0) is returned.

    Args:
        cell_idx: Index of cell somewhere in the column of cells where the well screen is located.
        top_scrn: Top of screened interval.
        bot_scrn: Bottom of screened interval.
        grid_info: The GridInfo object.
        cell_tops: Top elevations of all cells.
        cell_bottoms: Bottom elevations of all cells.
        logger: Logger to write errors/warnings to.
        feature_id: Feature object ID to use when writing errors/warnings.

    Returns:
        (from_layer, to_layer): 1-based, inclusive layer numbers, or (-1, -1) as described above.
    """
    # Check for 0.0 or inverted well screen
    if top_scrn == bot_scrn:
        if logger:
            logger.error(f'Well screen for feature with ID {feature_id} has length 0.0.')
        return -1, -1

    if top_scrn < bot_scrn:
        if logger:
            logger.error(f'Well screen for feature with ID {feature_id} is inverted.')
        return -1, -1

    # Check for a screened interval that is entirely above or below the grid
    cell2d_idx = cell_idx % grid_info.cells_per_layer()  # Index of the column's cell in the top layer
    bottom_cell_idx = grid_info.cell_index_from_lay_cell2d(grid_info.nlay - 1, cell2d_idx, one_based=False)
    column_top = cell_tops[cell2d_idx]
    column_bottom = cell_bottoms[bottom_cell_idx]
    if bot_scrn > column_top:
        if logger:
            logger.error(f'Well screen for feature with ID {feature_id} is above the grid.')
        return -1, -1
    if top_scrn < column_bottom:
        if logger:
            logger.error(f'Well screen for feature with ID {feature_id} is below the grid.')
        return -1, -1

    from_layer, to_layer = 0, 0
    for layer in range(1, grid_info.nlay + 1):
        layer_cell_idx = grid_info.cell_index_from_lay_cell2d(layer - 1, cell2d_idx, one_based=False)
        top = cell_tops[layer_cell_idx]
        bottom = cell_bottoms[layer_cell_idx]
        # A top_scrn on a layer boundary matches both layers; the later (lower) match wins, i.e. the layer the
        # screen extends down into.
        if top >= top_scrn >= bottom:
            from_layer = layer
        # A bot_scrn on a layer boundary also matches both layers; keep the first (upper) match so the layer below,
        # which the screen only touches at a point, is not included.
        if not to_layer and top >= bot_scrn >= bottom:
            to_layer = layer
        if from_layer and to_layer:
            break

    # Fix intervals that start or stop outside the grid
    if not from_layer and to_layer:  # From above grid, to inside grid
        from_layer = 1
    elif from_layer and not to_layer:  # From inside grid, to below grid
        to_layer = grid_info.nlay
    elif not from_layer and not to_layer and top_scrn > column_top and bot_scrn < column_bottom:
        # From above grid, to below grid
        from_layer, to_layer = 1, grid_info.nlay
    return from_layer, to_layer
