"""CellAdder class."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
import os
from pathlib import Path

# 2. Third party modules

# 3. Aquaveo modules
from xms.core import time
from xms.core.filesystem import filesystem as fs
from xms.coverage.xy import xy_util
from xms.coverage.xy.xy_series import XySeries

# 4. Local modules
from xms.mf6.mapping.package_builder_base import CELLGRP
from xms.mf6.mapping.tin_mapper import TinMapper
from xms.mf6.misc import util
from xms.mf6.misc.util import XM_NODATA


class CellAdder:
    """Adds a cell BC."""
    def __init__(self, builder):
        """Initializes the class.

        Args:
            builder: The package builder.
        """
        self._builder = builder
        self._cell_idx1 = None
        self._cell_idx2 = None
        self._record = None
        self._period_rows = None

    def set_record(self, cell_idx1: int, cell_idx2: int, record, period_rows: list | None):
        """Set the record for a cell.

        Args:
            cell_idx1: Cell index
            cell_idx2: Cell index
            record: Shapefile record for the current arc.
            period_rows: List of rows for a stress period.
        """
        self._cell_idx1 = cell_idx1
        self._cell_idx2 = cell_idx2
        self._record = record
        self._period_rows = period_rows

    def add_bc_to_cell(self) -> None:
        """Add a boundary condition to a cell."""
        column_values = self._get_column_values()
        if column_values:
            cellid1 = self._builder.grid_info.modflow_cellid_from_cell_index(self._cell_idx1)
            cellid2 = None
            if self._cell_idx2 != XM_NODATA:
                cellid2 = self._builder.grid_info.modflow_cellid_from_cell_index(self._cell_idx2)
            append_period_row(cellid1, cellid2, column_values, self._period_rows)

    def _get_column_values(self) -> list:
        """Returns the column values from either the transient data or the shapefile.

        Returns:
            (list): Column values
        """
        # Get the column values from either the transient data or the shapefile
        feature_id = self._record['ID']
        if self._builder.trans_data:
            period = self._builder.period  # for short
            if period in self._builder.trans_data and feature_id in self._builder.trans_data[period]:
                column_dict = self._builder.trans_data[period][feature_id]
                column_values = self._column_values_from_dicts(column_dict, feature_id)
            else:
                column_values = None
        else:
            column_values = self._column_values_from_dicts(self._record, feature_id)
        return column_values

    def _column_values_from_dicts(self, row_or_record, feature_id: int):
        """Returns the list of column values (a row) given row_or_record and map_import_info.

        Args:
            row_or_record: Can be either a record, if reading the shapefile, or a row, if reading the CSV file.
            feature_id (int): ID of current feature object.

        Returns:
            (list): See description.
        """
        column_values = []
        for key in self._builder.map_import_info.keys():
            column_values.append(self._column_value(feature_id, key, row_or_record, None))
        return column_values

    def _column_value(self, feature_id: int, key: str, row_or_record, mult: float) -> float | int | str | XySeries:
        """Return value from row_or_record using key.

        Args:
            feature_id: ID of feature object.
            key: Dict key for row_or_record.
            row_or_record: Record from the shapefile, or row from the transient data file.
            mult: Multiplier. With arcs, length. For polys, intersected area.

        Returns:
            The value.
        """
        value = row_or_record.get(key, '')
        if key == CELLGRP:
            value = self._builder.group(feature_id)
        elif key == 'Name':  # boundname
            pass
        elif not value:
            value = 0.0  # Default for all aux
        elif util.is_number(value):
            value = float(value)
            if mult is not None:
                value *= mult
        elif Path(self._resolve_trans_data_filepath(value)).is_file():
            value = self._interpolate_from_tin(value)
            if mult is not None:
                value *= mult
        return value

    def _resolve_trans_data_filepath(self, value: str) -> str:
        """Return .

        Args:
            value: A value read from the transient data file.

        Returns:
            See description.
        """
        trans_data_dir = os.path.dirname(self._builder._transient_data_filename)
        return fs.resolve_relative_path(trans_data_dir, value)

    def _interpolate_from_tin(self, tin_filepath: str) -> float:
        """Interpolate from the tin and return the value for the current stress period.

        Args:
            tin_filepath: Path to .xmc file containing the tin as a UGrid.

        Returns:
            See description.
        """
        tin_filepath = self._resolve_trans_data_filepath(tin_filepath)

        # Get TinMapper and interpolate to current cell
        if tin_filepath not in self._builder.tin_mappers:
            tin_mapper = TinMapper(tin_filepath, self._builder.cogrid, self._builder.ugrid, self._builder._package)
            self._builder.tin_mappers[tin_filepath] = tin_mapper
        mapper = self._builder.tin_mappers[tin_filepath]
        result = mapper.interpolate(self._cell_idx1)

        # result is either a constant, or an XySeries
        if isinstance(result, XySeries):
            if not self._times_compatible_or_warn(result):  # Both TIN dataset and TDIS must use dates/times, or not
                return XM_NODATA

            # Get start and end stress period times
            start_time = self._builder._period_times[self._builder.period - 1]
            end_time = self._builder._period_times[self._builder.period]

            # Convert to floats to do xy series averaging
            if not util.is_number(start_time):
                x_floats = [time.datetime_to_julian(x) for x in result.x]
                result = XySeries(x_floats, result.y)
                start_time = time.datetime_to_julian(start_time)
                end_time = time.datetime_to_julian(end_time)

            # Get average value for stress period
            constant, flag = xy_util.average_y_from_x_range(result, start_time, end_time)
            if flag and not self._builder._warned_about_extrapolation:
                self._builder.log.warning(
                    'TIN dataset does not cover the range of times specified by the MODFLOW stress periods. Values will'
                    ' be extrapolated from the information that is available.'
                )
                self._builder._warned_about_extrapolation = True
        else:
            constant = result
        return constant

    def _times_compatible_or_warn(self, xy_series: XySeries) -> bool:
        """Return False if dataset and tdis don't both use floats, or both use datetimes, and log warning.

        Args:
            xy_series: The xy series from the TIN.

        Returns:
            See description.
        """
        if util.is_number(xy_series.x[0]) and not util.is_number(self._builder._period_times[0]):
            if not self._builder._warned_about_time_compatibility:
                self._builder.log.warning('TDIS package uses dates / times but TIN dataset does not.')
                self._builder._warned_about_time_compatibility = True
            return False
        elif not util.is_number(xy_series.x[0]) and util.is_number(self._builder._period_times[0]):
            if not self._builder._warned_about_time_compatibility:
                self._builder.log.warning('TIN dataset uses dates/times but TDIS package does not. ')
                self._builder._warned_about_time_compatibility = True
            return False
        return True


def append_period_row(cellid1, cellid2, column_values, period_rows):
    """Appends the row consisting of cellid followed by column_values to the period.

    Args:
        cellid1 (list): MODFLOW cellid (layer, row, column, if dis, etc.)
        cellid2 (list): MODFLOW cellid (layer, row, column, if dis, etc.)
        column_values (list): List of values for the cell for the MODFLOW variables for the package.
        period_rows (list): List of rows for a stress period.
    """
    id1_list = [ijk for ijk in cellid1] if not isinstance(cellid1, int) else [cellid1]
    id2_list = []
    if cellid2 is not None:
        id2_list = [ijk for ijk in cellid2] if not isinstance(cellid2, int) else [cellid2]
    row = [*id1_list, *id2_list, *column_values]
    period_rows.append(row)
