"""PackageBuilderBase class."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
import csv
from dataclasses import dataclass, field
from datetime import datetime
import os
from pathlib import Path
from typing import Any

# 2. Third party modules
import pandas as pd

# 3. Aquaveo modules
from xms.constraint import Grid

# 4. Local modules
from xms.mf6.data.base_file_data import BaseFileData
from xms.mf6.data.grid_info import GridInfo
from xms.mf6.data.gwf.chd_data import ChdData
from xms.mf6.data.list_package_data import ListPackageData
from xms.mf6.file_io import io_util
from xms.mf6.file_io.pest import pest_obs_data_generator
from xms.mf6.geom import shapefile_geom
from xms.mf6.gui.map_from_coverage_dialog import MapOpt
from xms.mf6.mapping.grid_intersector import IxInfo
from xms.mf6.mapping.tin_mapper import TinMapper
from xms.mf6.misc import log_util, util
from xms.mf6.misc.settings import Settings

# Constants
# Strings used as keys in dicts and in json files and as aux variable names
CELLGRP = 'CELLGRP'
CELLGRP_GEOM = 'CELLGRP_GEOM'

# Type aliases
OverrideLayers = dict[int, tuple[int, int]]  # old layer -> layer range (start and stop layer)


@dataclass
class CovMapInfo:
    """Mapping info for a coverage.

    Attributes:
        coverage_uuid: UUID of the coverage.
        shapefile_names: Shapefile names for the coverage. The builder looks them up by feature type
            ('points', 'arcs' or 'polygons').
        ix_info: Intersection info for the coverage.
    """
    coverage_uuid: str = ''
    shapefile_names: dict[str, str] = field(default_factory=dict)
    # Use a factory so every CovMapInfo gets its own IxInfo rather than sharing one default instance that was
    # created once when the class was defined.
    ix_info: IxInfo = field(default_factory=IxInfo)


CovMapInfoList = list[CovMapInfo]


@dataclass
class PackageBuilderInputs:
    """Inputs for PackageBuilder.

    Attributes:
        package: The package to build; may be None.
        append_or_replace: MapOpt value; the builder compares it against MapOpt.REPLACE.
        cov_map_info_list: Mapping info, one CovMapInfo per coverage.
        cogrid: The grid; may be None.
        idomain: The idomain array (presumably one value per grid cell -- TODO confirm).
        override_layers: True to override the grid layers using the file at layer_filepath.
        layer_filepath: Filepath of layer file used to override grid layers.
    """
    package: BaseFileData | None = None
    append_or_replace: MapOpt = MapOpt.APPEND
    cov_map_info_list: CovMapInfoList = field(default_factory=list)
    cogrid: Grid | None = None
    idomain: list[int] = field(default_factory=list)
    override_layers: bool = False
    layer_filepath: str = ''  # Filepath of layer file used to override grid layers


class PackageBuilderBase:
    """Builds a package during map from coverage."""
    def __init__(self, inputs: PackageBuilderInputs):
        """Initializes the class.

        Args:
            inputs: Everything needed to build the package.
        """
        self._package = inputs.package

        # Unpack things from inputs
        self._replace = inputs.append_or_replace == MapOpt.REPLACE
        self.cogrid = inputs.cogrid
        self.ugrid = inputs.cogrid.ugrid if inputs.cogrid is not None else None
        self._idomain = inputs.idomain
        self._override_layers = inputs.override_layers
        self._layer_filepath = inputs.layer_filepath

        # List of coverages and variables for the current one
        self._cov_map_info_list = inputs.cov_map_info_list
        self._coverage_uuid = ''  # Current coverage uuid
        self._shapefile_names = ''  # Current coverage shapefile names
        self._ix_info: IxInfo() | None = None  # Current coverage intersection info

        self._transient_data_filename: str = ''
        self._layer_range_exists = False
        self.log = log_util.get_logger()
        self._chd_cells: dict[int, set[int]] = {}  # Cells containing a CHD bc. Stress period numbers -> set of cellids
        self._feature_type = ''  # 'polygons', 'arcs', or 'points'
        self._max_group = 0
        self._groups: dict[str, int] = {}  # Dict of map_id -> group number
        self._group_geometry: dict[str, Any] = {}  # Dict of map_id -> feature geometry
        self._shape = None  # The current shapefile shape
        self._min_max_layers = None  # Cached from the grid
        self.period: int = 0  # Current stress period, 1-based
        self._period_times: list[float] | list[datetime] = []
        self.tin_mappers: dict[str, TinMapper] = {}  # tin filepath -> TinMapper (so, one TinMapper per tin)
        self._cellidx: int = -1  # 0-based cell index
        self._warned_about_time_compatibility: bool = False  # To only show the warning once
        self._warned_about_extrapolation: bool = False  # To only show the warning once
        self.map_import_info = None

        if self._package:
            self.grid_info = self._package.grid_info()  # So we don't have to keep getting it later
        self.trans_data = None  # Contents of the transient data file as a dict
        self.override_layers: OverrideLayers = self._init_override_layers(inputs.override_layers, inputs.layer_filepath)

    def _set_up_for_current_coverage(self, cov_map_info: CovMapInfo) -> None:
        """Set the variables to the current coverage info.

        Args:
            cov_map_info: Mapping info for a coverage.
        """
        self._coverage_uuid = cov_map_info.coverage_uuid
        self._shapefile_names = cov_map_info.shapefile_names
        self._ix_info = cov_map_info.ix_info

    def _init_override_layers(self, override_layers: bool, layer_filepath: str) -> OverrideLayers:
        """If overriding layers, reads the .csv file and sets up the dict of old layer -> to new layer range.

        Args:
            override_layers:
            layer_filepath:

        Returns:
            See description.
        """
        if not override_layers:
            return {}

        try:
            # Get grid layers
            min_layer, max_layer = self.grid_layer_range()
            if min_layer == -1 or max_layer == -1:
                raise RuntimeError('Cannot override grid layers because grid has no layers.')
            num_lay = max_layer - min_layer + 1

            # Check if file exists
            layer_filepath = Path(layer_filepath)
            if not layer_filepath.is_file():
                raise RuntimeError(f'Override grid layers file not found: {layer_filepath}')

            # Read the file
            try:
                df = pd.read_csv(layer_filepath)
            except Exception as e:
                raise RuntimeError(f'Error reading override grid layers file: {layer_filepath}\n{str(e)}')

            # Check that file has correct number of layers
            sublayer_count = df['Sublayers'].sum()
            if sublayer_count != num_lay:
                msg = f'Override grid layers file indicates {sublayer_count} layers, but grid has {num_lay} layers.'
                raise RuntimeError(msg)
        except RuntimeError as e:
            msg = str(e) + ' Grid layers will not be overridden.'
            self.log.warning(msg)
            return {}
        return _make_override_layers(df)  # Create the dict from the dataframe

    def build(self):
        """Maps the coverage to the package."""
        raise NotImplementedError

    def _set_up_period_times(self) -> None:
        """Get the period times from the TDIS package."""
        if self._package.model:  # This can be None when testing
            # Get period times as datetimes, if possible
            df = self._package.model.mfsim.tdis.get_period_times(as_date_times=True)
            self._period_times = df['Time'].to_list()

    def _turn_on_boundnames(self):
        """Turn on the BOUNDNAMES option."""
        if self._package.options_block.defined('BOUNDNAMES'):
            self._package.options_block.set('BOUNDNAMES', True, None)

    def _load_chd_cells(self):
        """Populates self._chd_cells."""
        if not self._package.mfsim:  # This can happen when testing
            return

        for model in self._package.mfsim.models:
            if model.ftype == 'GWF6':
                chds = model.packages_from_ftype('CHD6')
                for chd in chds:
                    update_chd_cells_from_package(self.grid_info, chd, self._chd_cells)

    def _does_layer_range_exist(self, reader) -> bool:
        """Looks at fields in shapefile and returns True if "FROM_LAYER" and "TO_LAYER" fields both exist.

        Args:
            reader: Shapefile reader.

        Returns:
            (bool): See description.
        """
        fields = reader.fields
        from_layer_exists = False
        to_layer_exists = False
        for shp_field in fields:
            if shp_field[0] == 'FROM_LAYER':
                from_layer_exists = True
            elif shp_field[0] == 'TO_LAYER':
                to_layer_exists = True
            if from_layer_exists and to_layer_exists:
                break
        return from_layer_exists and to_layer_exists

    def get_layer_range(self, record, layer_range_exists):
        """Returns the from_layer and to_layer.

        Args:
            record: A record from the shapefile.
            layer_range_exists (bool): True if the "FROM_LAYER" and "TO_LAYER" fields exist in the shapefile.

        Returns:
            (tuple): tuple containing:
                - (int): from_layer
                - (int): to_layer
        """
        if layer_range_exists:
            from_layer = record['FROM_LAYER']
            to_layer = record['TO_LAYER']
        else:
            from_layer, to_layer = self.grid_layer_range()
        return from_layer, to_layer

    def grid_layer_range(self):
        """Returns the min/max layers from the grid.

        Returns:
            (tuple[int, int]): The min and max layers, or (-1, -1) if we can't get them.
        """
        if self._min_max_layers:
            return self._min_max_layers  # Return the cached values

        # Find the min/max layers and cache them
        cell_layers = self.cogrid.cell_layers
        if cell_layers:
            self._min_max_layers = min(cell_layers), max(cell_layers)
        else:
            self._min_max_layers = -1, -1
        return self._min_max_layers

    def _get_initial_stuff(self, feature_type):
        """Returns some things we will need and avoids duplicating code.

        Args:
            feature_type (str): 'points', 'arcs', or 'polygons'.

        Returns:
            (tuple): tuple containing:
                - (str): Shapefile name
                - (str): String used in att table for the attribute Type column (i.e. 'well', 'drain', 'river' etc)
                - (list(recarray)): The intersection recarrays
                - (dict): Dict with shapefile or att table fields.
        """
        shapefile_name = self._shapefile_names.get(feature_type)
        att_type = att_type_from_ftype(self._package.ftype)
        ix_recs = self._ix_info.get_list(feature_type)
        map_info = self._package.map_info(feature_type)
        self.map_import_info = self._package.map_import_info(feature_type)
        return shapefile_name, att_type, ix_recs, map_info

    def _transient_data_filepath_from_shape_filepath(self, shapefile_name):
        """Returns the name of the transient data file (.csv) that may be with the shapefile.

        Args:
            feature_type (str): 'points', 'arcs', or 'polygons'.

        Returns:
            (str): See description.
        """
        return os.path.splitext(shapefile_name)[0] + '.csv'

    def _read_transient_data(self, column_names, shapefile_name):
        """Read the transient data file into a dict.

        It's either this or we make it an sqlite database and do lots of querying.

        Args:
            column_names (list(str)): Names of the columns in the transient data file.
            shapefile_name (str): Filepath to shapefile name. Transient data file should be next to it.

        Returns:
            (dict): Dict of the CSV file.
        """
        self._transient_data_filename = self._transient_data_filepath_from_shape_filepath(shapefile_name)
        if not os.path.isfile(self._transient_data_filename):
            return None
        return read_trans_data_file(self._transient_data_filename, column_names)

    def _add_cellgrp_aux(self):
        """Adds the CELLGRP aux variable if it doesn't already exist."""
        options = self._package.options_block
        if options.defined('AUXILIARY') and (
            self._package.ftype not in {'EVT6', 'RCH6'} or isinstance(self._package, ListPackageData)
        ):  # noqa: W503 (line break)
            aux = options.get('AUXILIARY', [])
            if not aux:
                options.set('AUXILIARY', True, [CELLGRP])
            elif CELLGRP not in aux:
                aux.append(CELLGRP)
                options.set('AUXILIARY', True, aux)

    def _map_id(self, feature_id: int) -> str:
        """Returns the map_id for use with CELLGRP: coverage uuid, feature type, feature ID.

        Mapid is a string consisting of the coverage uuid, the feature type (point, arc, polygon), and the feature id.

        Args:
            feature_id (int): ID of the feature.

        Returns:
            (str): See description.
        """
        # Make the feature type singular. arcs -> arc, points -> point, polygons -> polygon
        feature_type = self._feature_type[:-1] if self._feature_type[-1] == 's' else self._feature_type
        return pest_obs_data_generator.make_map_id(self._coverage_uuid, feature_type, feature_id)

    def group(self, feature_id: int) -> int:
        """Returns the group number for the feature.

        Args:
            feature_id (int): ID of the feature.

        Returns:
            (int): See description.
        """
        map_id = self._map_id(feature_id)
        group = self._groups.get(map_id, None)
        if group is None:
            self._max_group += 1
            self._groups[map_id] = self._max_group
            group = self._max_group
            self._group_geometry[map_id] = shapefile_geom.geom_from_shape(self._shape)
        return group

    def _merge_group_data(self, replace, settings_json, key, data):
        """If we are appending, merges data with data in file and returns the result; if not, just returns data.

        Args:
            settings_json (dict): Data read from settings file.
            key (str): The dictionary key of the data.
            data: The data.

        Returns:
            (dict): The data.
        """
        file_data = settings_json.get(key)
        if file_data and not replace:
            new_data = {**file_data, **data}
        else:
            new_data = data
        return new_data

    def _save_groups_with_package(self):
        """Saves the group map_id and number in the package settings.json file."""
        file_json = Settings.read_settings(self._package.filename)
        new_groups = self._merge_group_data(self._replace, file_json, CELLGRP, self._groups)
        new_group_geometry = self._merge_group_data(self._replace, file_json, CELLGRP_GEOM, self._group_geometry)
        if new_groups:
            file_json[CELLGRP] = new_groups
            file_json[CELLGRP_GEOM] = new_group_geometry
            Settings.write_settings(self._package.filename, file_json)

    def _remove_cellgrp_info(self):
        """Removes cellgrp info from the package settings.json file."""
        cellgrp_info = Settings.get(self._package.filename, CELLGRP)
        if cellgrp_info:
            Settings.set(self._package.filename, CELLGRP, None)
            Settings.set(self._package.filename, CELLGRP_GEOM, None)

    def _read_cellgrp_info(self):
        """Reads the cellgrp info from the package settings.json file."""
        self._groups = Settings.get(self._package.filename, CELLGRP, {})
        self._group_geometry = Settings.get(self._package.filename, CELLGRP_GEOM, {})

    def _find_max_cellgrp(self):
        """Finds and returns the maximum CELLGRP number in use in other packages of the same ftype.

        CELLGRP numbers must be unique for packages of the same ftype because John Doherty's utilities only use the
        flow type text (e.g. RIVER LEAKAGE) and the CELLGRP aux number to identify a flow value.

        Returns:
            (int): Max CELLGRP number already in use by all packages of the same type.
        """
        if not self._package.model:  # For testing
            return 0

        mx = 0
        packages = self._package.model.packages_from_ftype(self._package.ftype)
        for package in packages:
            cellgrp_info = Settings.get(package.filename, CELLGRP)
            if cellgrp_info:
                for cellgrp in cellgrp_info.values():
                    mx = max(mx, cellgrp)
        return mx


def update_chd_cells_from_package(grid_info: GridInfo, chd_package: ChdData, chd_cells: dict[int, set[int]]):
    """Adds the CHD cells in chd_package to chd_cells.

    Args:
        grid_info: Used to convert modflow cellids to cell indexes.
        chd_package: The CHD package.
        chd_cells: Stress period numbers -> set of cell indexes. Modified in place.
    """
    if not chd_package.mfsim:  # This can happen when testing
        return

    for period, cellids in chd_package.get_all_chd_cellids().items():
        period_cells = chd_cells.setdefault(period, set())
        period_cells.update({grid_info.cell_index_from_modflow_cellid(cellid) for cellid in cellids})


def cell_has_chd(period: int, cellidx: int, chd_cells: dict[int, set[int]]) -> bool:
    """Returns True if the cell contains a CHD bc in the stress period.

    Args:
        period (int): The stress period.
        cellidx (int): A cell index.
        chd_cells: Cells containing a CHD bc. Stress period numbers -> set of cell indexes.

    Returns:
        (bool): See description.
    """
    if period not in chd_cells:
        return False
    return cellidx in chd_cells[period]


def cell_active(idomain, cell_idx: int) -> bool:
    """Returns True if the cell is active (IDOMAIN > 0).

    If there is no idomain (None or empty), every cell is considered active.

    Args:
        idomain: The idomain array. May be a list or any sized, indexable sequence (e.g. a NumPy array).
        cell_idx (int): A cell index.

    Returns:
        (bool): See description.
    """
    # Use len() rather than truthiness, which raises ValueError for NumPy arrays with more than one element
    if idomain is None or len(idomain) == 0:
        return True
    return bool(idomain[cell_idx] > 0)


def cell_active_and_no_chd(idomain, period: int, cell_idx: int, chd_cells: dict[int, set[int]]) -> bool:
    """Returns True if the cell is active (IDOMAIN > 0) and there is no CHD BC in the cell.

    Args:
        idomain: The idomain array
        period (int): The stress period.
        cell_idx (int): A cell index.
        chd_cells: Cells containing a CHD bc. Stress period numbers -> set of cellids

    Returns:
        (bool): See description.
    """
    is_active = cell_active(idomain, cell_idx)
    # The CHD lookup is only done for active cells
    return is_active and not cell_has_chd(period, cell_idx, chd_cells)


def add_list_package_period_data(data: ListPackageData, period: int, period_rows: list) -> None:
    """Adds the data for the stress period to the package.

    Args:
        data: The package data class.
        period: The stress period.
        period_rows: List of rows for a stress period.
    """
    # Register the period with an empty filename first, then append the rows.
    # NOTE(review): presumably append_period_data() fills in or relies on this entry -- confirm before reordering.
    data.period_files[period] = ''
    data.append_period_data(period, period_rows)


def att_type_from_ftype(ftype):
    """Given an ftype, returns the corresponding string used in the Type column.

    Args:
        ftype (str): The file type used in the GWF name file (e.g. 'WEL6')

    Returns:
        (str): See description. None if the ftype has no corresponding string.
    """
    type_strings = {
        'CHD6': 'spec. head (CHD)',
        'DRN6': 'drain',
        'GHB6': 'gen. head',
        'HFB6': 'barrier',
        'MAW6': 'well (MAW)',
        'RIV6': 'river',
        'SFR6': 'stream (SFR2)',
        'WEL6': 'well',
    }
    return type_strings.get(ftype)


def read_trans_data_file(filepath: Path | str, column_names) -> dict:
    """Read the transient data file and return the info as a dict.

    The file is a CSV file with a header row that includes 'Period' and 'ID' columns. The result is nested:
    period -> feature id -> column name -> value. Values that util.is_number() accepts are stored as floats;
    everything else is stored as read.

    Args:
        filepath: Trans data filepath.
        column_names: Not used; the columns are taken from the file's header row.

    Returns:
        See description.
    """
    trans_data: dict = {}
    # newline='' is what the csv module asks for, so that it can handle line endings itself
    with open(filepath, 'r', newline='') as trans_file:
        dict_reader = csv.DictReader(trans_file)
        # Work out the value columns once (fieldnames is None for an empty file)
        value_fields = [name for name in (dict_reader.fieldnames or []) if name not in {'Period', 'ID'}]
        for row in dict_reader:
            # Read the period and feature id
            period = int(row['Period'])
            feature_id = int(row['ID'])
            feature_values = trans_data.setdefault(period, {}).setdefault(feature_id, {})

            # Add the rest of the row
            for field_name in value_fields:
                value = row[field_name]
                feature_values[field_name] = float(value) if util.is_number(value) else value
    return trans_data


def should_skip_record(ix_rec, record, att_type: str) -> bool:
    """Return True if we should skip this record.

    A record is skipped if it intersected no cells, or if att_type is given and the record's 'Type' differs
    from it.

    Args:
        ix_rec: Intersection info for this record.
        record: The record.
        att_type: 'well', 'well (MAW)' etc. If empty or None, the record's 'Type' is not checked.

    Returns:
        See description.
    """
    if len(ix_rec.cellids) == 0:
        return True
    # bool() so that an empty or None att_type gives False instead of being returned as is
    return bool(att_type and record['Type'] != att_type)


def save_periods(period_rows, package):
    """Writes the rows of each period to a temporary file and saves the result in the package.period_files dict.

    Args:
        period_rows: Dict of period -> rows for that period.
        package: The package; its period_files dict is updated.
    """
    for period_number in period_rows:
        temp_file = io_util.write_lines_to_temp_file(period_rows[period_number])
        package.period_files[period_number] = temp_file


def _make_override_layers(df: pd.DataFrame) -> OverrideLayers:
    """Return a dict of old layer -> new layer range tuple given a dataframe with a 'Sublayers' column.

    Example:

    Layer,Sublayers
    1,3
    2,1
    3,3

    ...results in {1: (1,3), 2: (4, 4), 3: (5,7)}

    The old layer number is the dataframe index + 1 (the row position for a freshly read file); the 'Layer'
    column itself is not consulted.

    Args:
        df: The dataframe.

    Returns:
        See description. Keys and values are plain Python ints.

    Raises:
        ValueError: If a Sublayers value cannot be converted to an int (e.g. a blank cell).
    """
    override_layers: OverrideLayers = {}
    next_layer = 1
    # Iterate the column directly: iterrows() converts each row to a single dtype, which turns the counts
    # into floats whenever the file has any non-integer column.
    for index, sublayers in zip(df.index, df['Sublayers']):
        count = int(sublayers)
        override_layers[int(index) + 1] = (next_layer, next_layer + count - 1)
        next_layer += count
    return override_layers
