"""ArrayPackageData class."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
import copy
from datetime import datetime
import os

# 2. Third party modules

# 3. Aquaveo modules
from xms.datasets.dataset_writer import DatasetWriter

# 4. Local modules
from xms.mf6.data.array import Array
from xms.mf6.data.base_file_data import BaseFileData
from xms.mf6.data.block import Block
from xms.mf6.data.grid_info import GridInfo
from xms.mf6.file_io import io_util
from xms.mf6.gui import units_util
from xms.mf6.misc import util
from xms.mf6.misc.util import XM_NODATA


class ArrayPackageData(BaseFileData):
    """Data class to hold the info from an array-based package file (e.g. RCH or EVT with READASARRAYS)."""
    def __init__(self, **kwargs):
        """Initializes the class.

        Args:
            **kwargs: Arbitrary keyword arguments.

        Keyword Args:
            ftype (str): The file type used in the GWF name file (e.g. 'WEL6')
            mfsim (MfsimData): The simulation.
            model (GwfData or GwtData): The GWF/GWT model. Will be None for TDIS, IMS, Exchanges (things below mfsim)
            grid_info (GridInfo): Information about the grid. Only used when testing individual packages. Otherwise,
             it comes from model and dis
        """
        super().__init__(**kwargs)
        self._griddata_names: list[str] | None = None  # The standard array names ('IRCH', 'RECHARGE'), never aux
        self.period_data: dict[int, Block] = {}  # Dict of 1-based stress periods and Block
        self._layer_indicator = ''  # Name of the layer indicator array (e.g. 'IRCH'); '' if none

    def tab_names(self, use_aux: bool = True) -> list[str]:
        """Return the list of the standard array names ('IRCH', 'RECHARGE') and, optionally, aux variables.

        Args:
            use_aux: True to append the names from the AUXILIARY option.

        Returns:
            (list[str]): A new list of names; the standard names come first.
        """
        names = list(self._griddata_names or [])  # Tolerate names not yet having been set by a subclass
        aux = self.options_block.get('AUXILIARY', []) if use_aux else []
        return names + aux

    def get_units(self, array_name: str) -> str:
        """Returns the units string for the units that belong to the array, like 'L' or 'L^3/T'.

        You should override this method.

        Args:
            array_name (str): The name of an array.

        Returns:
            (str): See description.
        """
        raise NotImplementedError()

    def is_required_array(self, array_name: str) -> bool:
        """Returns True if the array is required.

        The base implementation always returns False.

        Args:
            array_name: The name of an array.

        Returns:
            (bool): True or False
        """
        return False

    def dialog_title(self) -> str:
        """Returns the title to show in the dialog.

        You should override this method.

        Returns:
            (str): The dialog title.
        """
        raise NotImplementedError()

    def is_layer_indicator(self, array_name: str) -> bool:
        """Returns True if the array is a layer indicator ('IEVT', 'IRCH').

        You should override this method.

        Args:
            array_name: The name of an array.

        Returns:
            (bool): True or False
        """
        raise NotImplementedError()

    def array_supports_time_array_series(self, array_name: str) -> bool:
        """Returns True if the array can have time-array series.

        You should override this method.

        Args:
            array_name: The name of an array.

        Returns:
            (bool): True or False
        """
        raise NotImplementedError()

    def get_array(self, sp: int, array_name: str) -> Array | None:
        """Returns the Array object at sp with name array_name, or None.

        Args:
            sp: 1-based stress period.
            array_name: array name (e.g. 'IRCH', 'RECHARGE').

        Returns:
            See description.
        """
        block = self.period_data.get(sp)
        if block:
            return block.array(array_name)
        return None

    def add_transient_array(self, sp: int, array_name: str, replace_existing=True) -> Array | None:
        """Adds Array object at sp with name array_name.

        Args:
            sp: 1-based stress period.
            array_name: array name (e.g. 'IRCH', 'RECHARGE').
            replace_existing (bool): True to replace any data already associated with the array_name.

        Returns:
            (Array | None): The array data, or None if the stress period has not been added (see add_period).
        """
        block = self.period_data.get(sp)
        if not block:
            return None
        if replace_existing or not block.has(array_name):
            block.add_array(Array(array_name))
        return block.array(array_name)

    def add_period(self, sp: int, replace_existing=True) -> None:
        """Adds an empty block for the given stress period.

        Args:
            sp: The 1-based stress period.
            replace_existing (bool): True to replace any data already associated with the period.
        """
        if replace_existing or sp not in self.period_data:
            self.period_data[sp] = Block(name=f'PERIOD {sp}', array_names=[])

    def copy_period(self, from_sp: int, to_sp: int) -> None:
        """Copies a stress period.

        The Block is deep-copied. Each array with 'ARRAY' storage then gets a fresh temporary copy of its external
        file (its temp file if it has one, otherwise its external file). If that file doesn't exist, the copy's
        temp filename is cleared.

        Args:
            from_sp: The stress period being copied.
            to_sp: The stress period being copied to.
        """
        self.period_data[to_sp] = copy.deepcopy(self.period_data[from_sp])

        # Copy external files
        for array in self.period_data[to_sp].arrays:
            if array.storage != 'ARRAY':
                continue
            layer = array.layer(0)
            from_filename = layer.temp_external_filename or layer.external_filename
            # NOTE(review): a relative external filename is checked against the current working directory here,
            # not the package file's directory - confirm this is intended.
            # Guard against None/'' because os.path.isfile(None) raises TypeError.
            if from_filename and os.path.isfile(from_filename):
                layer.temp_external_filename = io_util.copy_file_to_temp(from_filename)
            else:
                layer.temp_external_filename = ''

    def fill_missing_periods(self, sp_count: int) -> None:
        """Fills in missing stress periods (where an undefined period means use the previously defined one).

        A period that is defined but empty stops the carry-forward, so later undefined periods stay undefined until
        another non-empty period is found.

        Args:
            sp_count: Number of stress periods.
        """
        prev_period = None  # 1-based stress period number to carry forward, or None
        for sp in range(1, sp_count + 1):
            if sp in self.period_data:
                # Defined and not empty: carry it forward. Defined but empty: stop carrying forward.
                prev_period = sp if self.period_data[sp] else None
            elif prev_period is not None:
                # prev_period is always sp - 1 here, so each undefined period is filled from the one before it
                self.copy_period(prev_period, sp)
                prev_period = sp

    def get_tool_tip(self, tab_text: str) -> str:
        """Returns the tool tip that goes with the tab with the tab_name.

        The base implementation returns an empty string.

        Args:
            tab_text: Text of the tab

        Returns:
            (str): The tool tip.
        """
        return ''

    def make_datasets(self, array_names: list[str], ugrid_uuid: str) -> list[DatasetWriter]:
        """Create and return DatasetWriters for each array listed in array_names.

        A writer is created only for names that have a defined (not 'UNDEFINED') array in at least one stress
        period, and each such period becomes one time step. Arrays other than the layer indicator have their
        values moved to the layer given by that period's layer indicator array, if one is defined.

        Args:
            array_names: Names of the arrays to create datasets from.
            ugrid_uuid: uuid of the UGrid.

        Returns:
            List of DatasetWriter.
        """
        from xms.mf6.components import arrays_to_datasets as atd  # Imported here to avoid circular import

        start_date_time, period_times, time_units = self._get_time_info()
        grid_info = self.grid_info()
        cell_count = grid_info.cell_count()

        writers: dict[str, DatasetWriter] = {}
        # Sort so time steps are appended chronologically. period_data keys are not guaranteed to be in order,
        # because fill_missing_periods/copy_period insert missing periods after the ones already present.
        for period in sorted(self.period_data):
            if not 1 <= period <= len(period_times):
                continue  # No TDIS time for this period (0 would wrongly pick the last time; too large would raise)
            time = period_times[period - 1]
            indicator_layers: list[int] | None = None  # Layer indicator values, loaded once per period on demand

            for name in array_names:
                array = self.get_array(period, name)
                if not array or array.storage == 'UNDEFINED':
                    continue

                is_layer_indicator = name == self._layer_indicator

                # If the current array is not the layer indicator array, get the layer indicator values
                if is_layer_indicator:
                    layers: list[int] = []
                else:
                    if indicator_layers is None:
                        indicator_layers = self._dset_layers(period)
                    layers = indicator_layers

                # Get the list of values and make it the same size as the grid
                values = atd.resize_values(cell_count, array.get_values(apply_factor=True))

                # Put values in the layer specified by layer indicator array
                if layers:
                    self._put_values_in_indicated_layer(grid_info, layers, values)

                # Create activity array if not the layer indicator array
                activity = None if is_layer_indicator else atd.make_activity(values)

                # Append the time step, creating the writer the first time this name is seen
                writer = writers.get(name)
                if writer is None:
                    units = units_util.string_from_units(self, self.get_units(name))
                    writer = atd.create_dataset_writer(name, ugrid_uuid, start_date_time, time_units, units)
                    writers[name] = writer
                writer.append_timestep(time=time, data=values, activity=activity)

        # Finish all datasets
        for writer in writers.values():
            writer.appending_finished()
        return list(writers.values())

    def _put_values_in_indicated_layer(self, grid_info: GridInfo, layers: list[int], values: list[float]) -> None:
        """Move the values to the layer indicated by layers, in place.

        Entry i of layers is the (1-based) layer for 2D cell i+1. The value at index i (its layer-1 position) is
        moved to that layer's cell, and index i is set to XM_NODATA. Layer numbers are rounded and clamped to
        [1, nlay].

        Args:
            grid_info: GridInfo object.
            layers: The layers the values should be in (one per 2D cell).
            values: The array values, sized to the full grid. Modified in place.
        """
        for i, raw_layer in enumerate(layers):
            # Clamp BEFORE testing for layer 1. Otherwise an indicator that clamps to layer 1 (e.g. 0, or 2 when
            # nlay == 1) would map back to index i and the value would be overwritten with XM_NODATA.
            layer = util.clamp(int(round(raw_layer)), 1, grid_info.nlay)
            if layer == 1:
                continue  # Already in the right place
            cell_idx = grid_info.cell_index_from_lay_cell2d(layer, i + 1, one_based=True)
            values[cell_idx] = values[i]
            values[i] = XM_NODATA

    def _dset_layers(self, period: int) -> list[int]:
        """Return the layer indicator array's values for the period as a list.

        Args:
            period: The stress period.

        Returns:
            The values, or an empty list if there is no layer indicator array or it is 'UNDEFINED'.
        """
        layers: list[int] = []
        layer_indicator_array = self.get_array(period, self._layer_indicator)
        if layer_indicator_array and layer_indicator_array.storage != 'UNDEFINED':
            layers = layer_indicator_array.get_values()
        return layers

    def _get_time_info(self) -> tuple[datetime | float, list[float], str]:
        """Return the start date/time, period times, and time units.

        Returns:
            See description.
        """
        # Get time info from TDIS
        tdis = self.mfsim.tdis
        start_date_time = tdis.get_start_date_time()
        tdis_time_units = tdis.get_time_units()
        time_units = units_util.dataset_time_units_from_tdis_time_units(tdis_time_units)
        period_times_df = tdis.get_period_times(as_date_times=False)
        period_times = period_times_df['Time'].to_list()
        return start_date_time, period_times, time_units