"""CellAdderWel class."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
import sys

# 2. Third party modules

# 3. Aquaveo modules

# 4. Local modules
from xms.mf6.components.arrays_to_datasets import values_from_array
from xms.mf6.mapping.cell_adder import CellAdder
from xms.mf6.mapping.map_cell_iterator import MapCellIterator
from xms.mf6.misc import log_util

# Constants
NO_INTERSECT = 'NO_INTERSECT'
HK_ZERO = 'HK_ZERO'
ZERO_LENGTH_SCREEN = 'ZERO_LENGTH_SCREEN'


class CellAdderWel(CellAdder):
    """Adds a cell BC for the WEL package."""
    def __init__(self, builder, package, tops, bottoms):
        """Initializes the class.

        Args:
            builder: The package builder.
            package (WelData): The WEL package
            tops (list[float]): List of cell top elevations
            bottoms (list[float]): List of cell bottom elevations
        """
        super().__init__(builder)
        self._log = log_util.get_logger()
        self._tops = tops
        self._bottoms = bottoms
        self._total_t = 0.0
        self._tvals: dict[int, float] = {}  # cell index -> percent of total overlap
        self._k = self._find_npf_k(package)

    def _find_npf_k(self, package) -> list[float] | None:
        """Load the NPF K values for mapping with a well screen.

        Args:
            package (WelData): The WEL package

        Returns:
            (list[float]): List of K values
        """
        npf = package.model.packages_from_ftype('NPF6') if package.model else []
        if npf:
            griddata = npf[0].block('GRIDDATA')
            return values_from_array(griddata.array('K'), self._builder.grid_info)
        return None

    def _compute_q(self, record) -> float | None:
        """Get Q for a cell either from the specified flow or computed from the well screen.

        Args:
            record (dict): The record

        Returns:
            (float | None): See description
        """
        use_scrn = record.get('USE_SCRN', False)
        partition = record.get('PART_WEL', False)
        q = record.get('Q', 0.0)
        if (use_scrn or partition) and self._k is not None:
            t_val = self._tvals.get(self._cell_idx1)
            if t_val is None:
                return None  # This cell is excluded from the model
            q = t_val / self._total_t * q if t_val is not None and self._total_t else 0.0
        return q

    def _default_screen(self, cell_idx: int, record) -> tuple[float, float]:
        """Get the default screen for a well from top layer cell top to bottom layer cell bottom.

        Args:
            cell_idx: Cell index.
            record: A record from the shapefile.

        Returns:
            (tuple[float, float]): See description
        """
        top_scrn = ~sys.maxsize
        bot_scrn = sys.maxsize
        map_cell_iterator = MapCellIterator(cell_idx, self._builder, record)
        for cell_idx in map_cell_iterator:
            top_scrn = max(top_scrn, self._tops[cell_idx])
            bot_scrn = min(bot_scrn, self._bottoms[cell_idx])
        return top_scrn, bot_scrn

    def _column_values_from_dicts(self, row_or_record, feature_id: int) -> list:
        """Returns the list of column values (a row) given row_or_record and map_import_info.

        Args:
            row_or_record: Can be either a record, if reading the shapefile, or a row, if reading the CSV file.
            feature_id (int): ID of current feature object.

        Returns:
            (list): See description.
        """
        column_values = []
        for key in self._builder.map_import_info.keys():
            if key == 'Q':
                cell_q = self._compute_q(row_or_record)
                if cell_q is not None:
                    column_values.append(cell_q)
                else:
                    return []  # This cell is excluded from the model
            else:
                column_values.append(self._column_value(feature_id, key, row_or_record, None))
        return column_values

    def compute_cell_intervals(self, cell_idx: int, record) -> dict[int, tuple[float, float]]:
        """Compute top and bottom interval elevation for each cell included in layer range or screened interval.

        Top can be screen top or cell top. Likewise, bottom can be screen bottom or cell bottom.

        Args:
            cell_idx: Cell index.
            record: Shapefile record for the current arc.
        """
        if not isinstance(record, dict):
            record = record.as_dict()
        intervals: dict[int, tuple[float, float]] = {}
        use_screen = 'USE_SCRN' in record and record['USE_SCRN']
        top_scrn, bot_scrn = None, None
        if not use_screen:  # Compute a default screen from top layer cell top to bottom layer cell bottom
            top_scrn, bot_scrn = self._default_screen(cell_idx, record)
        elif use_screen:  # Get the specified screen top and bottom
            top_scrn = record['TOP_SCRN']
            bot_scrn = record['BOT_SCRN']
        if top_scrn is not None and bot_scrn is not None and top_scrn <= bot_scrn:
            self.log_warning(ZERO_LENGTH_SCREEN, record["ID"])
            return {}

        map_cell_iterator = MapCellIterator(cell_idx, self._builder, record)
        for cell_idx in map_cell_iterator:
            top = self._tops[cell_idx]
            bottom = self._bottoms[cell_idx]
            if top_scrn < bottom or bot_scrn > top:
                continue  # This layer is outside the range of the screen
            intervals[cell_idx] = min(top_scrn, top), max(bot_scrn, bottom)
        return intervals

    def compute_well_screen_overlaps(self, cell_idx: int, record: dict):
        """Computes the well screen overlaps for the layers of a cell.

        Args:
            cell_idx: Cell index.
            record: Shapefile record for the current arc.
        """
        intervals = self.compute_cell_intervals(cell_idx, record)
        if not intervals:
            self.log_warning(NO_INTERSECT, record["ID"])
            return

        self._tvals = {}
        self._total_t = 0.0
        partition = record.get('PART_WEL', False)
        if self._k is None or (not record['USE_SCRN'] and not partition):
            return

        have_overlap = have_k = False
        for cell_idx, (top, bottom) in intervals.items():
            overlap = (top - bottom)
            if overlap > 0.0:
                have_overlap = True
            cell_k = self._k[cell_idx]
            if cell_k > 0.0:
                have_k = True
            self._tvals[cell_idx] = overlap * cell_k
        self._total_t = sum(self._tvals.values())
        if not have_k:
            self.log_warning(HK_ZERO, record["ID"])
        if not have_overlap:
            self._tvals = {}
            self._total_t = 0.0
            self.log_warning(NO_INTERSECT, record["ID"])

    def log_warning(self, warning_id: str, feature_id: int) -> None:
        """Logs a warning.

        Args:
            warning_id: The warning we want logged.
            feature_id: ID of the feature object.
        """
        if warning_id == NO_INTERSECT:
            msg = (
                f'Well point with feature ID {feature_id} has been excluded from the MODFLOW model because the '
                'well does not intersect any cells in the MODFLOW grid.'
            )
        elif warning_id == HK_ZERO:
            msg = (
                f'The flow rate for well point with feature ID {feature_id} has been set to zero because the '
                'hydraulic conductivities are all zero in the intersected cells in the MODFLOW grid.'
            )
        elif warning_id == ZERO_LENGTH_SCREEN:
            msg = (
                f'Well point with feature ID {feature_id} has been excluded from the MODFLOW model because the well '
                'screen has a length less than or equal to zero.'
            )
        else:
            raise ValueError('warning not handled')
        self._log.warning(msg)
