"""TdisReader class."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules

# 2. Third party modules
import pandas

# 3. Aquaveo modules

# 4. Local modules
from xms.mf6.file_io import io_factory
from xms.mf6.file_io.package_reader import PackageReader


def time_units_and_start_date_time(tdis_main_file):
    """Read a TDIS main file and extract its TIME_UNITS and START_DATE_TIME options.

    Args:
        tdis_main_file (str): File path to TDIS mainfile.

    Returns:
        (tuple[str, str]): The (TIME_UNITS, START_DATE_TIME) option values, in that order. Either
        entry is '' when the corresponding option is not present in the options block.
    """
    options = io_factory.reader_from_ftype('TDIS6').read(tdis_main_file).options_block
    return options.get('TIME_UNITS', ''), options.get('START_DATE_TIME', '')


class TdisReader(PackageReader):
    """Reads a TDIS package file."""
    def __init__(self):
        """Initializes the class."""
        super().__init__(ftype='TDIS6')
        self._dtypes = []  # dtype of each PERIODDATA column that is read from the file
        self._row = 0  # 0-based index of the next period_df row to fill

    def _read_options(self, line):
        """Reads a line from the options block.

        TIME_UNITS values are upper-cased and stored back into the options block.

        Args:
            line (str): A line from the file.

        Returns:
            (tuple(str, str)): tuple containing:
            - key: str
            - value: str or may be list(str) or may be something else
        """
        key, value = super()._read_options(line)
        # Make sure time units are upper case
        if key == 'TIME_UNITS':
            value = value.upper()
            self._data.options_block.set(key, True, value)
        return key, value

    def _read_dimensions(self, line):
        """Reads a line from the dimensions block.

        When the line is an NPER entry, stores the number of stress periods and (re)creates the
        period DataFrame with that many rows.

        Args:
            line (str): A line from the file.
        """
        words = line.split()
        # MODFLOW 6 keywords are not case sensitive, so accept 'nper' as well as 'NPER'.
        if len(words) > 1 and words[0].upper() == 'NPER':
            self._data._nper = int(words[1])
            self.initialize_period_dataframe()

    def initialize_period_dataframe(self):
        """Initializes period data with the right number of rows and columns.

        Also rebuilds the per-column dtypes used to cast PERIODDATA words and resets the row
        counter, so calling this again (e.g. a repeated NPER line) does not leave stale state.
        """
        columns, defaults = self._data.get_column_info('')
        data = [defaults for _ in range(self._data._nper)]
        self._data.period_df = pandas.DataFrame(data=data, columns=columns)
        self._data.period_df.index += 1  # 1-based index: one row per stress period

        # Record the dtype of each column that is read from the file so _read_perioddata can cast
        # the strings it reads. Only the first num_columns_in_file columns are read from the
        # file; any remaining columns keep their default values.
        period_df = self._data.period_df
        self._dtypes = [
            period_df[period_df.columns[column]].dtype
            for column in range(self._data.num_columns_in_file)
        ]
        self._row = 0

    def _read_perioddata(self, line):
        """Reads one line of the PERIODDATA block into the next row of the period DataFrame.

        Commas are treated as whitespace. Lines with fewer than 3 words (such as blank lines) are
        ignored. The first num_columns_in_file words are cast to their column's dtype and stored.

        Args:
            line (str): A line from the file.
        """
        words = line.strip().replace(',', ' ').split()
        if len(words) >= 3:
            period_df = self._data.period_df
            for column in range(self._data.num_columns_in_file):
                # Columns past num_columns_in_file are not in the file and keep their defaults.
                period_df.iloc[self._row, column] = self._dtypes[column].type(words[column])
            self._row += 1
