"""TdisData class."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
from datetime import datetime
from enum import IntEnum
from itertools import accumulate

# 2. Third party modules
import pandas as pd
from typing_extensions import override

# 3. Aquaveo modules

# 4. Local modules
from xms.mf6.data import time_util
from xms.mf6.data.base_file_data import BaseFileData
from xms.mf6.data.options_block import OptionsBlock
from xms.mf6.gui import units_util
from xms.mf6.gui.options_defs import CheckboxButton, CheckboxComboBox, CheckboxFieldButton
from xms.mf6.gui.units_util import MF6_TIME_UNITS


class TdisColumn(IntEnum):
    """Zero-based positional indices of the columns in the TDIS period table."""
    # Unpacking assigns 0..4 in declaration order: PERLEN=0, NSTP=1, TSMULT=2, END_DATE=3, STEADY_STATE=4.
    PERLEN, NSTP, TSMULT, END_DATE, STEADY_STATE = range(5)


class TdisData(BaseFileData):
    """Data class to hold the info from a tdis package file."""
    def __init__(self, **kwargs):
        """Initializes the class.

        Args:
            **kwargs: Arbitrary keyword arguments.

        Keyword Args:
            ftype (str): The file type used in the GWF name file (e.g. 'WEL6')
            mfsim (MfsimData): The simulation.
            model (GwfData or GwtData): The GWF/GWT model. Will be None for TDIS, IMS, Exchanges (things below mfsim)
            grid_info (GridInfo): Information about the grid. Only used when testing individual packages. Otherwise,
             it comes from model and dis
        """
        super().__init__(**kwargs)
        self.ftype = 'TDIS6'
        # Misc
        self.period_df = None  # Pandas DataFrame
        self.column_names = ['PERLEN', 'NSTP', 'TSMULT', 'ENDDATE']
        self.column_defaults = {'PERLEN': 1.0, 'NSTP': 1, 'TSMULT': 1.0, 'ENDDATE': ''}
        self.num_columns_in_file = 3  # Only the first 3 columns appear in the file
        self._nper = 0  # Number of stress periods

    @override
    def nper(self) -> int:
        """Return the number of stress periods.

        Returns:
            See description.
        """
        return self._nper

    def get_time_units(self) -> str:
        """Returns the TIME_UNITS string ('UNKNOWN', 'SECONDS', 'MINUTES', 'HOURS', 'DAYS', 'YEARS').

        Returns:
            (str): See description.
        """
        return self.options_block.get('TIME_UNITS')

    def get_start_date_time(self) -> datetime | float:
        """Returns START_DATE_TIME as a datetime if in use, 0.0 if not in use.

        Returns:
            (datetime or float): See description.
        """
        start_date_time_str = self.options_block.get('START_DATE_TIME')
        if start_date_time_str:
            return time_util.datetime_from_arbitrary_string(start_date_time_str)
        else:
            return 0.0

    def get_units(self, array_name: str) -> str:
        """Returns the units string for the array.

        Args:
            array_name (str): The name of an array.

        Returns:
            (str): The units string like 'L' or 'L^3/T'.
        """
        match array_name:
            case 'PERIODDATA':
                units_spec = units_util.UNITS_TIME
            case _:
                units_spec = ''  # This is an error
        return units_spec

    def set_start_date_time(self, start_date_time: str) -> None:
        """Sets START_DATE_TIME."""
        # self.options_block.set('START_DATE_TIME', True, '2010-05-18T01:00:00')
        self.options_block.set('START_DATE_TIME', True, start_date_time)

    def set_time_units(self, time_units: str) -> None:
        """Sets TIME_UNITS."""
        self.options_block.set('TIME_UNITS', True, time_units)

    def get_end_date_time(self):
        """Returns the end time (not the last period starting time) as a float or a python datetime.

        Returns:
            (float or datetime): See description.
        """
        df = self.get_period_times(as_date_times=True)
        end_time = df['Time'].to_list()[-1]
        if isinstance(end_time, pd.Timestamp):
            return end_time.to_pydatetime()
        return end_time

    def get_period_count(self):
        """Returns the number of stress periods.

        Returns:
            (int): Number of periods.
        """
        if self.period_df is None or self.period_df.empty:
            return None  # pragma no cover - This should never happen
        return self.period_df.shape[0]

    def can_do_date_times(self) -> bool:
        """Return True if period times can be date/times."""
        return _can_do_date_times(self.options_block.get('START_DATE_TIME'), self.get_time_units())

    def get_period_times(self, as_date_times=True) -> pd.DataFrame | None:
        """Returns a Pandas dataframe of period numbers and times (relative or date/times) based on period lengths.

        Size will be one more than the number of periods to include the last period end time.

        Args:
            as_date_times (bool): If true and a starting date/time is used, times will be dates/times.

        Returns:
            times (Pandas.Series): See description. Columns are Index, PERIOD, Time.
        """
        if self.period_df is None or self.period_df.empty:
            return None

        if not as_date_times or not self.can_do_date_times():
            return self._get_relative_times()
        else:
            return self._get_date_times()

    def _get_relative_times(self) -> pd.DataFrame:
        """Returns a DataFrame with relative period start times (and 1 end time) starting at 0.0.

        Size is 1 plus the number of periods. The last row is the end time of the last period.

        Returns:
            See description.
        """
        times = [0.0]

        # Calculate new end date/times
        for row in range(self.period_df.shape[0]):
            perlen = self.period_df.iloc[row, TdisColumn.PERLEN]
            times.append(times[-1] + perlen)

        df = pd.DataFrame(data={'Time': times})
        df.index += 1
        df.insert(0, 'PERIOD', df.index)
        return df

    def _get_date_times(self) -> pd.DataFrame:
        """Returns a DataFrame with period start dates (and 1 end date) as date/times.

        Size is 1 plus the number of periods. The last row is the end time of the last period.

        Returns:
            See description.
        """
        units = self.get_time_units()
        date_time = time_util.datetime_from_arbitrary_string(self.options_block.get('START_DATE_TIME'))
        times = self.get_date_times(units, date_time, self.period_df.iloc[:, TdisColumn.PERLEN].tolist())
        df = pd.DataFrame(data={'Time': times})
        df.index += 1
        df.insert(0, 'PERIOD', df.index)
        return df

    @staticmethod
    def get_date_times(time_units: str, date_time: datetime, durations: list[float]):
        """Returns a list with period start dates (and 1 end date) as date/times.

        Size is 1 plus the number of periods. The last row is the end time of the last period.

        Args:
            time_units: The time units string.
            date_time: The starting date.
            durations: durations.

        Returns:
            (list[datetime]): See description.
        """
        times = [date_time]

        # Calculate new end date/times
        prev_end_date = date_time
        for duration in durations:
            end_date = time_util.compute_end_date_py(prev_end_date, duration, time_units)
            prev_end_date = end_date
            times.append(end_date)
        return times

    # @overrides
    def get_column_info(self, block, use_aux=True):
        """Returns column names, and defaults.

        Args:
            block (str): Name of the list block.
            use_aux (bool): True to include AUXILIARY variables.

        Returns:
            (tuple): tuple containing:
                - column_names (list): Column names.
                - default (dict of str -> value): Column names -> default values.
        """
        return self.column_names, self.column_defaults

    def get_column_tool_tips(self, block: str) -> dict[int, str]:
        """Returns a dict with column index and tool tip.

        Args:
            block (str): Name of the block.
        """
        names, _defaults = self.get_column_info('')
        units_str = units_util.string_from_units(self, '[T]')
        return {
            names.index('PERLEN'): f'Length of a stress period {units_str}',
            names.index('NSTP'): 'Number of time steps in a stress period',
            names.index('TSMULT'): 'Multiplier for the length of successive time steps',
            names.index('ENDDATE'): 'DATE/TIME at the end of the stress period',
        }

    def dialog_title(self):
        """Returns the title to show in the dialog.

        Returns:
            (str): The dialog title.
        """
        return 'Temporal Discretization (TDIS) Package'

    # @overrides
    def _setup_options(self):
        """Returns the definition of all the available options.

        Returns:
            (OptionsBlock): See description.
        """
        return OptionsBlock(
            [
                CheckboxComboBox(
                    'TIME_UNITS',
                    brief='Time unit label used in output files',
                    items=['UNKNOWN', 'SECONDS', 'MINUTES', 'HOURS', 'DAYS', 'YEARS'],
                    value='unknown',
                    check_box_method='on_chk_time_units',
                    combo_box_method='on_cbx_time_units'
                ),
                CheckboxFieldButton(
                    'START_DATE_TIME',
                    brief='Starting date and time. Included in list file',
                    type_='str',
                    value=None,
                    button_text='Date/Time...',
                    check_box_method='on_chk_start_date_time',
                    field_method='on_edt_start_date_time',
                    button_method='on_btn_start_date_time'
                ),
                CheckboxButton(
                    'ATS6 FILEIN',
                    brief='Adaptive time step input file',
                    button_text='Edit...',
                    check_box_method='on_chk_ats6_filein',
                    button_method='on_btn_ats6_filein'
                ),
            ]
        )


def timestep_lengths(perlen: float, nstp: int, tsmult: float):
    """Returns a list of the lengths of the timesteps for one period given the period length and multiplier.

    See mf6io.pdf documentation of tsmult in TDIS package. With tsmult == 1 all steps are equal;
    otherwise the lengths form a geometric sequence whose sum is perlen.

    Args:
        perlen(float): Length of the period.
        nstp (int): Number of timesteps. Integral floats (e.g. 3.0, as produced when pandas upcasts an
            int column) are accepted.
        tsmult (float): Multiplier (e.g. 1.2)

    Returns:
        (list): See description. Length is nstp.

    Raises:
        ValueError: If nstp is not a positive integer.
    """
    steps = int(nstp)  # int(nan) also raises ValueError
    if steps != nstp or steps < 1:
        raise ValueError(f'nstp must be a positive integer, got {nstp!r}')

    if tsmult == 1.0:
        return [perlen / steps] * steps

    # First step length from the geometric series sum: perlen = t * (tsmult**n - 1) / (tsmult - 1).
    lengths = [perlen * (tsmult - 1) / (tsmult**steps - 1)]
    for _ in range(steps - 1):
        lengths.append(lengths[-1] * tsmult)
    return lengths


def get_all_timestep_relative_times(period_df):
    """Maps each 1-based period number to the relative end times of that period's timesteps.

    The final timestep of every period is pinned to the cumulative sum of PERLEN values, so rounding
    error from summing the individual step lengths does not carry into later periods.

    Args:
        period_df (pandas.DataFrame): The TDIS dataframe with at least columns 'PERLEN', 'NSTP', 'TSMULT'.

    Returns:
        (dict): Period number -> list of timestep end times measured from 0.0.
    """
    result = {}
    period_start = 0.0
    for number, record in enumerate(period_df.itertuples(index=False), start=1):
        steps = timestep_lengths(record.PERLEN, record.NSTP, record.TSMULT)
        offsets = [period_start + elapsed for elapsed in accumulate(steps)]
        period_start += record.PERLEN
        offsets[-1] = period_start  # Snap to the exact period end to avoid accumulating round off error
        result[number] = offsets
    return result


def get_all_timestep_date_times(period_df, start_date_time: datetime, units: str):
    """Maps each 1-based period number to the date/times at the end of that period's timesteps.

    Relative timestep times come from get_all_timestep_relative_times and are each offset from
    start_date_time via time_util.compute_end_date_py.

    Args:
        period_df (pandas.DataFrame): The TDIS dataframe with at least columns 'PERLEN', 'NSTP', 'TSMULT'.
        start_date_time (datetime): Starting date/time.
        units (str): 'YEARS', 'DAYS', 'HOURS', 'MINUTES', or 'SECONDS'

    Returns:
        (dict): Period number -> list of timestep date/times.
    """
    relative = get_all_timestep_relative_times(period_df)
    return {
        period: [time_util.compute_end_date_py(start_date_time, offset, units) for offset in offsets]
        for period, offsets in relative.items()
    }


def _are_units_defined(units):
    """Returns true if units are defined.

    Args:
        units (str): The time units string.

    Returns:
        See description.
    """
    return units and (units.upper() in MF6_TIME_UNITS) and units.upper() != 'UNKNOWN'


def _can_do_date_times(start_date_time, units) -> bool:
    """Return True if period times can be date/times."""
    if not start_date_time:
        return False
    date_time = time_util.datetime_from_arbitrary_string(start_date_time)
    return _are_units_defined(units) and date_time is not None and isinstance(date_time, datetime)
