"""GriddataBase class."""
from __future__ import annotations  # So we can use Array type hint in a static method

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules

# 2. Third party modules
import numpy as np

# 3. Aquaveo modules
from xms.datasets.dataset_writer import DatasetWriter

# 4. Local modules
from xms.mf6.data.array import Array
from xms.mf6.data.array_layer import ArrayLayer
from xms.mf6.data.base_file_data import BaseFileData
from xms.mf6.data.block import Block
from xms.mf6.data.grid_info import DisEnum
from xms.mf6.gui import units_util


class GriddataBase(BaseFileData):
    """Data class to hold the info from GRIDDATA blocks (DIS, DISV, DISU, IC, NPF, STO).

    Subclasses are expected to override get_units, is_int_array, is_required_array and dialog_title, which
    raise NotImplementedError here. can_be_layered, can_be_dataset, array_size_and_layers and get_tool_tip
    have usable default implementations that subclasses may also override.
    """
    def __init__(self, **kwargs):
        """Initializes the class.

        Args:
            **kwargs: Passed unchanged to BaseFileData.__init__.
        """
        super().__init__(**kwargs)

        self._blocks: dict[str, Block] = {}  # 'GRIDDATA' -> Block etc.

    def block(self, name: str) -> Block | None:
        """Return the block with the given name (or None if not found).

        Args:
            name: Block name (e.g. 'GRIDDATA').

        Returns:
            See description.
        """
        return self._blocks.get(name)

    @property
    def blocks(self) -> dict[str, Block]:
        """Return the blocks, keyed by block name.

        Note that this is the internal dict itself, not a copy, so changes to it affect this object.

        Returns:
            See description.
        """
        return self._blocks

    def add_block(self, name: str, array_names: list[str]) -> None:
        """Adds and initializes the block, replacing any existing block with the same name.

        Only arrays for which is_required_array() is True get a default Array created up front.

        Args:
            name: Block name (e.g. 'GRIDDATA')
            array_names: List of array names (e.g. ['DELR', 'BOTM'])
        """
        # Default to layered unless doing DISU
        layered_default = True
        try:  # If DIS*, grid_info won't be set yet
            grid_info = self.grid_info()
            if grid_info and grid_info.dis_enum == DisEnum.DISU:
                layered_default = False
        except RuntimeError:
            pass

        block = Block(name, array_names)
        for array_name in array_names:
            if not self.is_required_array(array_name):
                continue
            # Arrays that can't use the LAYERED option are always created un-layered
            layered = layered_default if self.can_be_layered(array_name) else False
            block.add_array(self.new_array(array_name, layered=layered))
        self._blocks[name] = block

    def new_array(self, name: str, layered: bool) -> Array:
        """Initialize and return a new Array filled with a constant default value.

        Integer arrays (per is_int_array) get numeric_type 'int' and a constant of 1.0; all others get
        numeric_type 'float' and a constant of 0.0. A layered array gets one ArrayLayer per grid layer
        when grid info with nlay >= 1 is available; otherwise a single ArrayLayer is created.

        Args:
            name: The Array name, like 'DELR' or 'BOTM'.
            layered: True if layered.

        Returns:
            See description.
        """
        array = Array(
            array_name=name,
            layered=layered,
        )
        is_int = self.is_int_array(name)
        array.numeric_type = 'int' if is_int else 'float'
        constant = 1.0 if is_int else 0.0

        # Get grid_info if possible. This is intentionally best-effort (broad catch): if DIS*, grid_info
        # won't be set yet, and we just fall back to a single layer.
        grid_info = None
        try:
            grid_info = self.grid_info()
        except Exception:
            pass

        # Determine nlay
        nlay = 1 if not layered or not grid_info or grid_info.nlay < 1 else grid_info.nlay

        _, _, shape = self.array_size_and_layers(name, layered)
        for _ in range(nlay):
            array.layers.append(ArrayLayer(name=name, numeric_type=array.numeric_type, constant=constant, shape=shape))
        return array

    def add_array(self, block_name: str, array_name: str, layered: bool, replace_existing=True) -> Array:
        """Adds Array object with name array_name to the named block and returns the block's array.

        Args:
            block_name: Name of the block to add the array to.
            array_name: array name (e.g. 'IRCH', 'RECHARGE').
            layered: True if array should be layered.
            replace_existing: True to replace any data already associated with the array_name. If False and
                the block already has array_name, the existing array is returned untouched.

        Returns:
            (Array): The array data.

        Raises:
            KeyError: If no block named block_name has been added.
        """
        block = self._blocks.get(block_name)
        if block is None:
            raise KeyError(f"Block '{block_name}' not found.")
        if replace_existing or array_name not in block.existing_names:
            block.add_array(self.new_array(array_name, layered))
        return block.array(array_name)

    def get_units(self, array_name: str) -> str:
        """Returns the units string for the units that belong to the array, like 'L' or 'L^3/T'.

        You should override this method.

        Args:
            array_name (str): The name of an array.

        Returns:
            (str): See description.
        """
        raise NotImplementedError()

    def is_int_array(self, array_name) -> bool:
        """Returns True if the array is integers.

        You should override this method.

        Args:
            array_name (str): The name of an array.

        Returns:
            (bool): True or False
        """
        raise NotImplementedError()

    def is_required_array(self, array_name: str) -> bool:
        """Returns True if the array is required.

        You should override this method.

        Args:
            array_name (str): The name of an array.

        Returns:
            (bool): True or False
        """
        raise NotImplementedError()

    def can_be_layered(self, array_name: str) -> bool:
        """Returns True if the array can have the LAYERED option.

        The base implementation always returns True; override to restrict specific arrays.

        Args:
            array_name (str): The name of an array.

        Returns:
            (bool): True or False
        """
        return True

    def can_be_dataset(self, array_name: str) -> bool:
        """Returns True if the array can be converted to a dataset.

        The base implementation always returns True; override to restrict specific arrays.

        Args:
            array_name (str): The name of an array.

        Returns:
            (bool): True or False
        """
        return True

    def array_size_and_layers(self, array_name: str, layered: bool) -> tuple[int, int, tuple]:
        """Returns the number of values and layers to read for an array, plus the layer shape.

        The base implementation does not use array_name; subclasses may override this.

        Args:
            array_name (str): The name of an array.
            layered (bool): True if layered keyword was read.

        Returns:
            (tuple): tuple containing:
                - nvalues (int): Number of values to read (from ArrayLayer.number_of_values_and_shape).
                - layers_to_read (int): Number of layers to read. 0 if there is no grid info; grid nlay if
                  layered and nlay > 0; otherwise 1.
                - shape: Layer shape (from ArrayLayer.number_of_values_and_shape).
        """
        grid_info = self.grid_info()
        nvalues, shape = ArrayLayer.number_of_values_and_shape(layered, grid_info)
        layers_to_read = 0
        if grid_info:
            layers_to_read = grid_info.nlay if layered and grid_info.nlay > 0 else 1
        return nvalues, layers_to_read, shape

    def dialog_title(self) -> str:
        """Returns the title to show in the dialog.

        You should override this method.

        Returns:
            (str): The dialog title.
        """
        raise NotImplementedError()

    def nparray_from_array(self, array_name):
        """Returns a numpy array (dtype 'f', i.e. float32) of the values of a GRIDDATA array.

        Args:
            array_name (str): Name of the array (e.g. 'DELR', 'DELC')

        Returns:
            The numpy array, or None if there is no GRIDDATA block or the array is not found/empty.
        """
        griddata = self.block('GRIDDATA')
        if griddata is None:
            return None
        array = griddata.array(array_name)
        np_array = None
        if array:
            np_array = np.array(array.get_values(), dtype='f')
        return np_array

    def get_tool_tip(self, tab_text: str) -> str:
        """Returns the tool tip that goes with the tab with the tab_name.

        The base implementation returns an empty string.

        Args:
            tab_text: Text of the tab

        Returns:
            (str): The tool tip.
        """
        return ''

    def make_datasets(self, array_names: list[str], ugrid_uuid: str) -> list[DatasetWriter]:
        """Create and return DatasetWriters for each array listed in array_names.

        Each array in the GRIDDATA block becomes a single-time-step (time 0.0) dataset on the UGrid.

        Args:
            array_names: Names of the arrays to create datasets from.
            ugrid_uuid: uuid of the UGrid.

        Returns:
            List of DatasetWriter, one per entry in array_names, in the same order.
        """
        from xms.mf6.components import arrays_to_datasets as atd  # Imported here to avoid circular import

        cell_count = self.grid_info().cell_count()
        griddata = self.block('GRIDDATA')
        writers: list[DatasetWriter] = []
        for name in array_names:
            array = griddata.array(name)

            # Make sure the size of the dataset is equal to the number of grid cells
            values = atd.resize_values(cell_count, array.get_values(apply_factor=True))

            # Create activity array if necessary (make_activity may return None)
            activity = atd.make_activity(values)
            activity_list = [activity] if activity is not None else None

            units = units_util.string_from_units(self, self.get_units(name))
            writer = atd.create_dataset_writer(name, ugrid_uuid, ref_time=None, time_units='None', units=units)
            writer.write_xmdf_dataset(times=[0.0], data=[values], activity=activity_list)
            writers.append(writer)
        return writers
