"""OcData class."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
from datetime import datetime
from enum import IntEnum
from typing import Optional

# 2. Third party modules
import pandas as pd
from typing_extensions import override

# 3. Aquaveo modules

# 4. Local modules
from xms.mf6.data import tdis_data
from xms.mf6.data.list_package_data import ListPackageData
from xms.mf6.data.options_block import OptionsBlock
from xms.mf6.data.tdis_data import TdisData
from xms.mf6.gui import gui_util
from xms.mf6.gui.options_defs import CheckboxField, CheckboxPrintFormat

# Constants
# Column names of the OC PERIOD-block list (see OcData.package_column_info)
PRINT_SAVE = 'PRINT_SAVE'  # 'PRINT' (to the listing file) or 'SAVE' (to an output file)
RTYPE = 'RTYPE'  # 'BUDGET' or the model's dependent-variable keyword from oc_first_word() (e.g. 'HEAD')
OCSETTING = 'OCSETTING'  # 'ALL', 'FIRST', 'LAST', 'FREQUENCY', or 'STEPS'
# FREQUENCY_STEPS: an integer if OCSETTING == FREQUENCY; 1-based step numbers if OCSETTING == STEPS
FREQUENCY_STEPS = 'FREQUENCY_STEPS'  # A number if OCSETTING == FREQUENCY, list of numbers if OCSETTING == STEPS


class OcPresetOutputEnum(IntEnum):
    """Enumeration of OC package preset outputs.

    Selects which time steps output_times() reports: every time step, the last time step of each stress period,
    or whatever the user specified in the OC PERIOD blocks.
    """
    OC_EVERY_TIME_STEP = 0  # Print and save budget and head at every time step
    OC_LAST_TIME_STEPS = 1  # Print and save budget and head at the last time step of every stress period
    OC_USER_SPECIFIED = 2  # Use whatever the user specified in the OC PERIOD blocks
    OC_END = 3  # NOTE(review): presumably an end/count marker rather than a real preset - confirm


def oc_first_word(model_ftype):
    """Returns the OC keyword for the model's dependent variable, based on the model type.

    Args:
        model_ftype (str): Ftype of the model ('GWF6', 'GWT6', 'GWE6' or 'PRT6').

    Returns:
        (str): 'HEAD' for GWF6, 'CONCENTRATION' for GWT6, 'TEMPERATURE' for GWE6 and 'TRACK' for PRT6.

    Raises:
        KeyError: If model_ftype is not one of the supported model types.
    """
    words = {'GWF6': 'HEAD', 'GWT6': 'CONCENTRATION', 'GWE6': 'TEMPERATURE', 'PRT6': 'TRACK'}
    try:
        return words[model_ftype]
    except KeyError:
        # Same exception type as before, but say which ftype was not recognized.
        raise KeyError(f'Unsupported model ftype for the OC package: {model_ftype!r}') from None


def fileout_extension(name):
    """Returns the file extension to use for an OC FILEOUT option.

    The keywords are checked, case-insensitively, in this order and the first one found in name wins:
    'budget' -> '.cbc', 'head' -> '.hds', 'concentration' -> '.ucn', 'temperature' -> '.ucn'.

    Args:
        name (str): Name of option or widget containing 'budget', 'head', 'concentration' or 'temperature'.

    Returns:
        (str): The file extension including the '.', or '' if no keyword is found.
    """
    lowered = name.lower()
    # Order matters: e.g. 'BUDGETCSV FILEOUT' must match 'budget' first.
    keyword_extensions = (
        ('budget', '.cbc'),
        ('head', '.hds'),
        ('concentration', '.ucn'),
        ('temperature', '.ucn'),
    )
    for keyword, extension in keyword_extensions:
        if keyword in lowered:
            return extension
    return ''


class OcData(ListPackageData):
    """Data class to hold the info from the OC (Output Control) package file."""
    def __init__(self, **kwargs):
        """Initializes the class.

        Args:
            **kwargs: Arbitrary keyword arguments.

        Keyword Args:
            ftype (str): The file type used in the GWF name file (e.g. 'WEL6')
            mfsim (MfsimData): The simulation.
            model (GwfData or GwtData): The GWF/GWT model. Will be None for TDIS, IMS, Exchanges (things below mfsim)
            grid_info (GridInfo): Information about the grid. Only used when testing individual packages. Otherwise,
             it comes from model and dis
        """
        super().__init__(**kwargs)
        self.ftype = 'OC6'  # Always OC6, overriding whatever the base class set from kwargs
        self.preset_output: OcPresetOutputEnum = OcPresetOutputEnum.OC_EVERY_TIME_STEP
        self.block_with_cellids = ''  # Presumably empty because OC rows contain no cell ids

    def _dependent_variable_word(self) -> str:
        """Returns the OC keyword for the model's dependent variable ('HEAD' if there is no model).

        Returns:
            (str): See description.
        """
        return oc_first_word(self.model.ftype) if self.model else 'HEAD'

    def get_column_tool_tips(self, block: str) -> dict[int, str]:
        """Returns a dict with column index and tool tip.

        Args:
            block (str): Name of the block (unused; the OC columns are the same for every block).

        Returns:
            (dict[int, str]): Column index -> tool tip.
        """
        names, _types, _defaults = self.package_column_info('')
        tips = {
            PRINT_SAVE: 'Info will be printed or saved this stress period',
            RTYPE: 'Type of information to be printed or saved',
            OCSETTING: 'The steps for which the data will be printed or saved',
            FREQUENCY_STEPS: ('If OCSETTING is FREQUENCY, a single integer. If OCSETTING is STEPS,'
                              ' a list of integers.'),
        }
        return {names.index(column): tip for column, tip in tips.items()}

    def get_column_info(self, block: str, use_aux: bool = True):
        """Returns column names, types, and defaults.

        The OC columns are fixed (PRINT_SAVE, RTYPE, OCSETTING, FREQUENCY_STEPS), so block and use_aux are
        ignored.

        Args:
            block (str): Name of the list block (unused).
            use_aux (bool): True to include AUXILIARY variables (unused; OC has no AUX columns).

        Returns:
            (tuple): tuple containing:
                - column_names (list): Column names.
                - types (dict of str -> type): Column names -> column types.
                - default (dict of str -> value): Column names -> default values.
        """
        return self.package_column_info()

    def package_column_info(self, block=''):
        """Returns the column info just for the columns unique to this package.

        All columns are type object; defaults are PRINT / BUDGET / ALL / ''.

        Args:
            block (str): Name of the block (unused).

        Returns:
            (tuple): tuple containing:
                - column_names (list): Column names.
                - types (dict of str -> type): Column names -> column types.
                - default (dict of str -> value): Column names -> default values.
        """
        columns = {
            PRINT_SAVE: (object, 'PRINT'),
            RTYPE: (object, 'BUDGET'),
            OCSETTING: (object, 'ALL'),
            FREQUENCY_STEPS: (object, '')
        }
        return gui_util.column_info_tuple_from_dict(columns)

    # @overrides
    def get_column_delegate_info(self, block):
        """Returns the combo-box choices for the columns that have a fixed set of allowed strings.

        Args:
            block (str): Name of the block. Choices are only provided for '' or 'PERIODS'.

        Returns:
            (list[tuple[int, list[str]]] | None): (column index, allowed strings) for the PRINT_SAVE, RTYPE and
            OCSETTING columns, or None for any other block.
        """
        if block and block != 'PERIODS':
            return None
        names, _, _ = self.get_column_info(block)
        return [
            (names.index(PRINT_SAVE), ['PRINT', 'SAVE']),
            (names.index(RTYPE), ['BUDGET', self._dependent_variable_word()]),
            (names.index(OCSETTING), ['ALL', 'FIRST', 'LAST', 'FREQUENCY', 'STEPS']),
        ]

    def stress_id_columns(self):
        """Returns the column name where the id exists that can be used to help identify this stress across periods.

        Typically is 'CELLIDX' which is added by GMS but is 'RNO' for SFR. OC has no such column.

        Returns:
            (list): An empty list.
        """
        return []

    def dialog_title(self):
        """Returns the title to show in the dialog.

        Returns:
            (str): The dialog title.
        """
        return 'Output Control (OC) Dialog'

    @override
    def _setup_options(self) -> OptionsBlock:
        """Returns the definition of all the available options.

        The dependent-variable options use the model's keyword (HEAD, CONCENTRATION, ...).

        Returns:
            See description.
        """
        first_word = self._dependent_variable_word()
        return OptionsBlock(
            [
                CheckboxField('BUDGET FILEOUT', brief='Save budget to file', type_='str'),
                CheckboxField('BUDGETCSV FILEOUT', brief='Save budget to CSV file', type_='str'),
                CheckboxField(f'{first_word} FILEOUT', brief=f'Save {first_word.lower()} to file', type_='str'),
                CheckboxPrintFormat(f'{first_word} PRINT_FORMAT', brief='Format for printing to the listing file'),
            ]
        )

    def output_times(self, tdis: TdisData):
        """Returns a dict of the periods and the time steps where output is saved.

        Delegates to the module-level output_times() using this package's preset and period data.

        Args:
            tdis: TDIS package.

        Returns:
            (dict): period -> {0-based time step index -> date/time}.
        """
        # Uses another function for ease of testing.
        return output_times(
            tdis.period_df, tdis.get_start_date_time(), tdis.get_time_units(), self.get_period_df(-1),
            self.preset_output
        )

    def _get_output_filename(self, first_word):
        """Returns the file name of the '<first_word> FILEOUT' option, or '' if it is not set.

        Args:
            first_word (str): 'BUDGET' or the dependent-variable keyword (e.g. 'HEAD').

        Returns:
            (str): See description.
        """
        opts = self.options_block.options
        return opts.get(f'{first_word} FILEOUT', '')

    def get_dv_filename(self):
        """Returns the dependent-variable (head, concentration, ...) file name, or '' if not set."""
        # Use the same no-model fallback ('HEAD') as the rest of the class instead of crashing on None.
        return self._get_output_filename(self._dependent_variable_word())

    def get_budget_file_name(self):
        """Returns the budget file name, or '' if not set."""
        return self._get_output_filename('BUDGET')


def _output_times_for_period(df: pd.DataFrame, times_dict, period: int) -> dict:
    """Return the time steps in a period at which the budget is saved, according to the OC rows in df.

    Only rows that SAVE the BUDGET are considered; the time steps selected by all such rows are combined.
    Keywords are compared case-insensitively (MODFLOW 6 input keywords are case-insensitive).

    Args:
        df: OC rows (columns PRINT_SAVE, RTYPE, OCSETTING, FREQUENCY_STEPS) that apply to the period.
        times_dict: Dict of period -> list of the date/time of each time step in that period.
        period: Key into times_dict of the period being evaluated.

    Returns:
        (dict): 0-based time step index -> date/time, ordered by index. Empty if nothing is saved.

    Raises:
        ValueError: If OCSETTING is not recognized, or FREQUENCY is less than 1.
    """
    step_times = times_dict[period]
    num_steps = len(step_times)
    indices = set()
    for _, row in df.iterrows():
        if str(row[PRINT_SAVE]).strip().upper() != 'SAVE' or str(row[RTYPE]).strip().upper() != 'BUDGET':
            continue
        ocsetting = str(row[OCSETTING]).strip().upper()
        if ocsetting == 'ALL':
            indices.update(range(num_steps))
        elif ocsetting == 'FIRST':
            indices.update(range(min(1, num_steps)))
        elif ocsetting == 'LAST':
            if num_steps:
                indices.add(num_steps - 1)
        elif ocsetting == 'FREQUENCY':
            frequency = int(row[FREQUENCY_STEPS])
            if frequency < 1:
                raise ValueError(f'OC FREQUENCY must be at least 1, got {frequency} in period {period}')
            # 1-based steps frequency, 2*frequency, ... are 0-based indices frequency-1, 2*frequency-1, ...
            indices.update(range(frequency - 1, num_steps, frequency))
        elif ocsetting == 'STEPS':
            steps = row[FREQUENCY_STEPS]
            # Normally a whitespace-separated string, but also accept a list/tuple or a single number.
            if isinstance(steps, str):
                steps = steps.split()
            elif not isinstance(steps, (list, tuple)):
                steps = [steps]
            # Steps are 1-based. Ignore steps outside 1..num_steps (step 0 would otherwise alias the last step
            # under key -1, and steps past the end would raise IndexError).
            indices.update(int(step) - 1 for step in steps if 1 <= int(step) <= num_steps)
        else:
            raise ValueError(f'Unrecognized OCSETTING {row[OCSETTING]!r} in period {period}')
    return {index: step_times[index] for index in sorted(indices)}


def output_times(
    period_df: pd.DataFrame, start_date_time: datetime, time_units: str, oc_df: Optional[pd.DataFrame],
    preset: OcPresetOutputEnum
):
    """Given TDIS and OC info, returns a dict of the periods and list of time steps where output is specified.

    - No OC package (oc_df is None) or OC_LAST_TIME_STEPS: the last time step of every period.
    - OC_EVERY_TIME_STEP: every time step of every period.
    - Otherwise: the time steps where the OC PERIOD rows SAVE the BUDGET. A period without rows reuses the rows
      of the most recent earlier period that had some; periods before the first such period get no output.

    Args:
        period_df: The TDIS dataframe with at least columns 'PERLEN', 'NSTP', 'TSMULT'.
        start_date_time: Starting date/time. If falsy, 01/01/1950 is used.
        time_units: 'YEARS', 'DAYS', 'HOURS', 'MINUTES', or 'SECONDS'
        oc_df: OC dataframe ('PRINT_SAVE', 'RTYPE', 'OCSETTING', 'FREQUENCY_STEPS', 'PERIOD').
        preset: Preset: OC_EVERY_TIME_STEP, OC_LAST_TIME_STEPS, OC_USER_SPECIFIED

    Returns:
        (dict): period -> {0-based time step index -> date/time}. Periods with no output are omitted in the
        user-specified case.
    """
    if not start_date_time:
        start_date_time = datetime(1950, 1, 1)  # PEST assumes 01/01/1950 is 0.
    times_dict = tdis_data.get_all_timestep_date_times(period_df, start_date_time, time_units)

    if oc_df is None or preset == OcPresetOutputEnum.OC_LAST_TIME_STEPS:
        # If no OC package, default is "at the end of every stress period (mf6io.pdf)"
        return {period: {len(times) - 1: times[-1]} for period, times in times_dict.items() if times}
    if preset == OcPresetOutputEnum.OC_EVERY_TIME_STEP:
        return {period: dict(enumerate(times)) for period, times in times_dict.items()}

    # "The information specified in the PERIOD block will continue to apply for all subsequent stress periods,
    # unless the program encounters another PERIOD block"
    outputs = {}
    prev_df = None  # Rows of the most recent period that had OC rows
    for period in times_dict:
        period_rows = oc_df[oc_df['PERIOD'] == period]  # Rows for just this period (read-only, no copy needed)
        if len(period_rows) > 0:
            prev_df = period_rows
        elif prev_df is None:
            continue  # No PERIOD block encountered yet, so nothing is output for this period
        times = _output_times_for_period(prev_df, times_dict, period)
        if times:
            outputs[period] = times
    return outputs
