"""Array class."""
from __future__ import annotations  # So we can use Array type hint in a static method

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
import copy
from distutils.text_file import TextFile
import typing
# from typing_extensions import Self  # This doesn't seem to work in Python 3.10

# 2. Third party modules
import pandas as pd

# 3. Aquaveo modules

# 4. Local modules
from xms.mf6.data.array_layer import ArrayLayer
from xms.mf6.data.grid_info import DisEnum, GridInfo
from xms.mf6.file_io.array_layer_reader import ArrayLayerReader
from xms.mf6.file_io.array_layer_writer import ArrayLayerWriter
from xms.mf6.misc import log_util, util


class Array:
    """Class that holds array data, such as is used in array-based packages (EVT, RCH).

    The data lives in a list of ArrayLayer objects: one per layer when layered, otherwise a single
    layer covering the whole grid.
    """
    def __init__(self, array_name='', storage='CONSTANT', time_array_series='', layered=False):
        """Initializes the class.

        Args:
            array_name (str): Name for this array (e.g. 'RECHARGE')
            storage (str): 'UNDEFINED', 'CONSTANT', 'ARRAY', 'TIME-ARRAY SERIES'
            time_array_series (str): Name of the time-array series, if one.
            layered (bool): True if layered.
        """
        self.array_name = array_name
        self._storage = storage  # 'UNDEFINED', 'CONSTANT', 'ARRAY', 'TIME-ARRAY SERIES'
        self.time_array_series = time_array_series
        self.layered = layered  # True if layered
        self._layers: list[ArrayLayer] = []  # list of ArrayLayer. Size of num layers if layered, else just 1.
        self.defined = True  # True if this array is defined.
        self._numeric_type = ''  # 'int' or 'float' - use numeric_type @property below
        self._log = log_util.get_logger()

    @property
    def storage(self):
        """Returns the storage.

        Returns:
            storage (str): 'UNDEFINED', 'CONSTANT', 'ARRAY', 'TIME-ARRAY SERIES'
        """
        return self._storage

    @storage.setter
    def storage(self, value):
        """Storage setter. Needed so that array gets updated appropriately.

        Args:
            value (str): The storage ('UNDEFINED', 'CONSTANT', 'ARRAY', 'TIME-ARRAY SERIES')
        """
        self._storage = value
        # Only CONSTANT and ARRAY are propagated to the first layer; other storage kinds stay on the Array only.
        if self._layers and value in ('CONSTANT', 'ARRAY'):
            self._layers[0].storage = value

    @property
    def layers(self) -> list[ArrayLayer]:
        """Return the list of ArrayLayer objects.

        Returns:
            See description.
        """
        return self._layers

    @property
    def numeric_type(self) -> str:
        """Return the numeric type for the array ('int', 'float').

        If no type has been set explicitly, the type of the first layer is cached and returned.

        Returns:
            See description.
        """
        if not self._numeric_type and self._layers:
            # We currently store the numeric_type in the Array class; lazily pick it up from layer 0.
            self._numeric_type = self._layers[0].numeric_type
        return self._numeric_type

    @numeric_type.setter
    def numeric_type(self, type_str: str) -> None:
        """Set the numeric type for this array and all of its layers.

        Args:
            type_str: The numeric type ('int', 'float').
        """
        self._numeric_type = type_str
        # Keep every layer in sync with the Array-level type.
        for layer in self._layers:
            layer.numeric_type = type_str

    def layer(self, layer_idx: int | None = None) -> ArrayLayer | None:
        """Return the ArrayLayer for the given layer index (0-based) or the whole grid.

        If layer_idx is not provided, or we are not layered, the first layer is returned.

        Args:
            layer_idx: The layer idx (0-based).

        Returns:
            The ArrayLayer, or None if there are no layers or layer_idx is out of range
            (negative indices are treated as out of range rather than wrapping around).
        """
        if not self._layers:
            return None
        if not self.layered or layer_idx is None:
            return self._layers[0]
        return self._layers[layer_idx] if 0 <= layer_idx < len(self._layers) else None

    def set_layers(self, layers, layered) -> None:
        """Sets the layers.

        Args:
            layers (list[ArrayLayer]): The layers.
            layered (bool): True if LAYERED.
        """
        self._layers = layers
        self.layered = layered

    def get_values(self, apply_factor: bool = False) -> list:
        """Returns the values of all layers concatenated in layer order.

        Args:
            apply_factor: If True, the values are multiplied by the factor.

        Returns:
            (list): See description.
        """
        values = []
        for array_layer in self._layers:
            values.extend(array_layer.get_values(apply_factor))
        return values

    def set_values(self, values, shape: tuple[int, int], combine: bool) -> None:
        """Sets the values; when not layered, the first layer must already exist.

        Args:
            values: The values.
            shape: tuple containing the first and second dimensions.
            combine: If True, combines new values with existing values, overwriting existing where new is not
             util.XM_NODATA.
        """
        if combine:
            values = self.combine_values(values)
        if self.layered:
            self._set_layered_array_values(values, shape)
        else:
            # An unlayered array holds everything in its first layer, so drop any extras.
            if len(self._layers) > 1:
                del self._layers[1:]
            self._layers[0].set_values(values, shape)
            self.storage = self._layers[0].storage

    def _set_layered_array_values(self, values, shape: tuple[int, int]) -> None:
        """Splits values into consecutive per-layer chunks and sets each layer, appending layers as needed.

        Args:
            values: The values.
            shape: tuple containing the first and second dimensions.
        """
        num_values_per_layer = shape[0] * shape[1]
        layer_count = len(values) // num_values_per_layer
        for lay in range(layer_count):
            start = lay * num_values_per_layer
            values_lay = values[start:start + num_values_per_layer]
            if lay >= len(self._layers):
                # Append a new layer; use the property so an unset type falls back to layer 0's type.
                array_layer = ArrayLayer(name=f'{self.array_name}_{lay + 1}', numeric_type=self.numeric_type)
                self._layers.append(array_layer)
            self._layers[lay].set_values(values_lay, shape, use_constant=True)

    def combine_values(self, new_values):
        """Combines new values with existing values, overwriting existing where new is not util.XM_NODATA.

        Args:
            new_values (list(float or int)): The new values.

        Returns:
            (list): The combined values.

        Raises:
            RuntimeError: If the number of new values differs from the number of existing values.
        """
        values = self.get_values()
        if len(values) != len(new_values):
            raise RuntimeError('Array.combine_values: sizes are not equal.')
        return [old if new == util.XM_NODATA else new for old, new in zip(values, new_values)]

    def to_layered(self, shape: tuple[int, int]) -> None:
        """Converts data from not layered to layered.

        Args:
            shape: New shape of the array layers.
        """
        if self.layered or not self._layers:
            return
        self.layered = True

        # Since it's initially unlayered, all values are in the first layer to start
        values = self._layers[0].get_values()
        self._set_layered_array_values(values, shape)

    def to_unlayered(self) -> None:
        """Converts data from layered to not layered.

        All layers are concatenated into a single layer of shape (nvalues, 1), built from a deep copy of
        the first layer with its external file names cleared.

        Raises:
            ValueError: If a non-CONSTANT layer does not hold the expected number of values.
        """
        if not self.layered or not self._layers:
            return

        # Get sizes
        layered_shape = self._layers[0].shape
        num_values_per_layer = layered_shape[0] * layered_shape[1]
        nvalues = len(self._layers) * num_values_per_layer
        new_shape = (nvalues, 1)

        # Copy each layer's values into one list of values for the whole grid
        values = []
        for array_layer in self._layers:
            if array_layer.storage == 'CONSTANT':
                # Expand the constant into every slot for this layer
                layer_values = [array_layer.constant] * num_values_per_layer
            else:
                layer_values = array_layer.get_values()
                if len(layer_values) != num_values_per_layer:
                    raise ValueError(
                        f'Array.to_unlayered: layer has {len(layer_values)} values, '
                        f'expected {num_values_per_layer}.'
                    )
            values.extend(layer_values)

        # Create a new layer by copying the first layer, then give it the whole-grid values
        new_array = copy.deepcopy(self._layers[0])
        new_array.external_filename = ''
        new_array.temp_external_filename = ''
        new_array.shape = new_shape
        new_array.set_values(values, new_shape, use_constant=True)

        # Replace all layers with just this one; only flip the flag once everything succeeded
        self._layers = [new_array]
        self.layered = False

    def dataset_to_array(self, ts_data, shape: tuple[int, int], keep_const_layers: bool):
        """Applies dataset values to the array, without changing layered, and possibly keeping CONSTANT layers.

        Args:
            ts_data (list of float): Values from an XMS dataset.
            shape: Shape of the array layer.
            keep_const_layers: If True, CONSTANT layers are filled with their constant instead of ts_data.
        """
        if self.layered:
            num_values_per_layer = shape[0] * shape[1]
            for layer_idx, array_layer in enumerate(self.layers):
                if array_layer.storage == 'CONSTANT' and keep_const_layers:
                    new_values = [array_layer.constant] * num_values_per_layer
                else:
                    start = layer_idx * num_values_per_layer
                    end = start + num_values_per_layer
                    new_values = ts_data[start:end]
                array_layer.set_values(new_values, shape, use_constant=True)
        else:
            array_layer = self.layer(0)
            array_layer.set_values(ts_data, shape, use_constant=True)

    def dataset_to_layer(self, ts_data, layer: int, shape: tuple[int, int]):
        """Applies the dataset values belonging to one layer to that layer.

        Args:
            ts_data (list of float): Values from an XMS dataset for all layers; the slice for `layer` is used.
            layer: The layer (1-based).
            shape: Shape of the array layer.

        Raises:
            IndexError: If layer is less than 1 or greater than the number of layers.
        """
        if layer < 1 or layer > len(self.layers):
            raise IndexError('Layer does not exist.')

        array_layer = self.layer(layer - 1)
        num_values_per_layer = shape[0] * shape[1]
        start = (layer - 1) * num_values_per_layer
        end = start + num_values_per_layer
        array_layer.set_values(ts_data[start:end], shape, use_constant=True)

    def ensure_layer_exists(self, make_int=False, name='', constant=1.0):
        """If array.layers is empty, create an array_layer and put it in layers[0].

        Args:
            make_int (bool): If True and a new ArrayLayer is created, its numeric_type is set to 'int'.
            name (str): If a new ArrayLayer is created, this will be its name.
            constant (float): Constant value to assign to the new array_layer if one is created.

        Returns:
            The last array_layer in layers.
        """
        if not self._layers:
            array_layer = ArrayLayer(name=name, constant=constant)
            if make_int:
                array_layer.numeric_type = 'int'
            self._layers.append(array_layer)
        return self._layers[-1]

    def dump_to_temp_files(self, temp_file_list: list[str] | None = None) -> None:
        """Writes the layer arrays that have storage == 'ARRAY' to temp files.

        Args:
            temp_file_list: list of temporary file names
        """
        for array_layer in self._layers:
            if array_layer.storage != 'ARRAY':
                continue
            ArrayLayerWriter.write_array_layer(array_layer, array_layer.shape, temp_file_list)

    def dataframe_from_layer(self, layer_idx: int, grid_info: GridInfo) -> pd.DataFrame | None:
        """Return the dataframe for the array layer, with a 1-based row index.

        Columns are named with 1-based column numbers when the array is layered (or is TOP) on a DIS grid;
        otherwise there is a single column named ''.

        Args:
            layer_idx: The layer idx (0-based).
            grid_info: Info about the grid.

        Returns:
            (DataFrame): The dataframe, or None if the layer does not exist.
        """
        array_layer = self.layer(layer_idx)
        if not array_layer:
            return None

        # Create column names
        if (self.layered or self.array_name.upper() == 'TOP') and grid_info.dis_enum == DisEnum.DIS:
            column_names = [str(col + 1) for col in range(grid_info.ncol)]
        else:
            column_names = ['']

        df = array_layer.to_dataframe()
        df.columns = column_names
        df.index += 1
        return df

    def __repr__(self):
        """Returns a string representation of the object to aid in debugging.

        Returns:
            (str): A string representation of the class object.
        """
        return (
            f'array_name: {self.array_name}, '
            f'storage: {self.storage}, '
            f'time_array_series: {self.time_array_series}, '
            f'layered: {self.layered}, '
            f'layers: {self._layers}'
        )

    def __eq__(self, other):
        """Returns True if self == other; NotImplemented if other is not an Array."""
        if not isinstance(other, Array):
            return NotImplemented
        # yapf: disable
        return (
            self.array_name == other.array_name
            and self.storage == other.storage
            and self.time_array_series == other.time_array_series
            and self.layered == other.layered
            and self._layers == other._layers  # Uses ArrayLayer.__eq__()
            and self.defined == other.defined
        )
        # yapf: enable

    def __deepcopy__(self, memo):
        """Returns a deep copy of this object.

        Args:
            memo: The deepcopy memo dict; this copy is registered in it and it is forwarded to the layers.

        Returns:
            See description.
        """
        new_array = Array(self.array_name, self._storage, self.time_array_series, self.layered)
        memo[id(self)] = new_array
        new_array.defined = self.defined
        new_array._numeric_type = self._numeric_type
        new_array._layers = copy.deepcopy(self._layers, memo)
        return new_array

    @staticmethod
    def read(
        fp: typing.TextIO | TextFile, line: str, valid_names: list[str], sizes_callback, base_filepath,
        numeric_types: dict[str, str], importing: bool
    ) -> Array | None:
        """Read and return the Array.

        Args:
            fp: File object.
            line: Line in file that is the start of the array data.
            valid_names: Valid array names.
            sizes_callback: Callback that, given the name and layered, returns nvalues, layers_to_read, and shape.
            base_filepath: Filepath of package being read, used to resolve relative paths.
            numeric_types: Dict of names to numeric types (e.g. {'DELR': 'float', 'IDOMAIN': 'int'})
            importing: Flag to tell if we need to read external file data.

        Returns:
            See description.

        Raises:
            ValueError: If the line is empty, the name is not valid, or a layer cannot be read.
        """
        words = line.split()
        if not words:
            raise ValueError('Could not find array name.')
        if words[0].upper() not in valid_names:
            raise ValueError(f'{words[0]} is not a valid array name.')

        # Get name and layered
        name = words[0].upper()
        layered = len(words) > 1 and words[1].upper() == 'LAYERED'

        # Create Array and read all the layers
        array = Array(array_name=name, layered=layered)
        array.numeric_type = numeric_types[name]
        _read_layers(fp, array, sizes_callback, base_filepath, importing)
        return array


def _read_layers(
    fp: typing.TextIO | TextFile, array: Array, sizes_callback, base_filepath, importing: bool
) -> None:
    """Reads the array's layers from the file and appends them to array.layers.

    The number of layers, values per layer, and layer shape come from sizes_callback. Each layer read
    is given the array's numeric type, name, and the shape from the callback.

    Args:
        fp: File object.
        array: Array object to which the layers are appended (modified in place).
        sizes_callback: Callback that, given the name and layered, returns nvalues, layers_to_read, and shape.
        base_filepath: Filepath of package being read, used to resolve relative paths.
        importing: Flag to tell if we need to read external file data.

    Raises:
        ValueError: If a layer cannot be read; the underlying exception is chained as the cause.
    """
    nvalues, nlay, shape = sizes_callback(array.array_name, array.layered)
    for layer_idx in range(nlay):
        reader = ArrayLayerReader()
        try:
            array_layer, _ = reader.read(base_filepath, nvalues, fp, importing)
        except Exception as err:
            # Keep the original error as __cause__ so the real reason for the failure is not lost.
            raise ValueError(f'Could not read "{array.array_name}" array layer {layer_idx + 1}') from err
        array_layer.numeric_type = array.numeric_type
        array_layer.name = array.array_name
        array_layer.shape = shape
        array.layers.append(array_layer)
