"""ArraysToDatasets class."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
from datetime import datetime

# 2. Third party modules

# 3. Aquaveo modules
from xms.api.dmi import Query
from xms.datasets.dataset_writer import DatasetWriter
from xms.guipy.dialogs.xms_parent_dlg import XmsDlg
from xms.guipy.testing import testing_tools

# 4. Local modules
from xms.mf6.components import dmi_util
from xms.mf6.data.array import Array
from xms.mf6.data.grid_info import GridInfo
from xms.mf6.file_io import io_util
from xms.mf6.gui import arrays_to_datasets_dialog
from xms.mf6.misc.util import XM_NODATA


def ask_and_create(package_dlg: XmsDlg) -> bool:
    """Run the arrays-to-datasets workflow for a package dialog.

    Convenience wrapper that builds an ArraysToDatasets for the dialog and delegates to it.

    Args:
        package_dlg: The package dialog calling this function.

    Returns:
        True if the selection dialog returned array names and datasets were created, else False.
    """
    return ArraysToDatasets(package_dlg).ask_and_create()


class ArraysToDatasets:
    """Creates UGrid datasets from package arrays chosen through a selection dialog."""
    def __init__(self, package_dlg: XmsDlg):
        """Store the dialog that supplies the package data and the query.

        Args:
            package_dlg: The package dialog calling this function.
        """
        self._package_dlg = package_dlg

    def ask_and_create(self) -> bool:
        """Ask which arrays to convert and create a dataset for each one.

        Returns:
            True if the selection dialog returned array names and datasets were created, else False.
        """
        package_data = self._package_dlg.dlg_input.data
        chosen_names = arrays_to_datasets_dialog.run(package_data, self._package_dlg)
        if not chosen_names:
            # Nothing came back from the selection dialog, so there is nothing to create.
            return False

        grid_uuid = dmi_util.ugrid_uuid_from_model_node(package_data.model.tree_node)
        self._create_scalar_datasets(chosen_names, grid_uuid, self._package_dlg.dlg_input.query)
        return True

    def _create_scalar_datasets(self, array_names: list[str], ugrid_uuid: str, query: Query) -> None:
        """Build the datasets for the given arrays and add them to the query.

        Args:
            array_names: List of arrays to create datasets from.
            ugrid_uuid: uuid of the UGrid.
            query: Object for communicating with GMS
        """
        package_data = self._package_dlg.dlg_input.data
        for dataset_writer in package_data.make_datasets(array_names, ugrid_uuid):
            query.add_dataset(dataset_writer)


def resize_values(cell_count: int, values: list[float]) -> list[float]:
    """Pad the values with XM_NODATA until there is one value per grid cell.

    The list is padded in place and the same list object is returned. A list that already
    has at least cell_count values is returned unchanged (it is never shortened).

    Args:
        cell_count: Number of cells in the grid.
        values: The array values.

    Returns:
        The same list that was passed in, padded if it was too short.
    """
    missing = cell_count - len(values)
    if missing > 0:
        # In-place concatenation keeps the caller's list object, just like list.extend().
        values += [XM_NODATA] * missing
    return values


def make_activity(values: list[float | int]) -> list[int] | None:
    """Return a list of the activity, if array is not the layer indicator array.

    Args:
        values: The array values.

    Returns:
        See description.
    """
    return [0 if value == XM_NODATA else 1 for value in values]


def create_dataset_writer(
    name: str, geom_uuid: str, ref_time: float | datetime | None, time_units: str | None, units: str
) -> DatasetWriter:
    """Build a DatasetWriter that writes to a new temporary .h5 file.

    The writer is given a new dataset uuid, is located on 'cells', and has
    use_activity_as_null turned on.

    Args:
        name: Dataset name.
        geom_uuid: Uuid of the UGrid.
        ref_time: Starting date/time of the dataset.
        time_units: Time units.
        units: Units of the values.

    Returns:
        The configured DatasetWriter.
    """
    return DatasetWriter(
        h5_filename=io_util.get_temp_filename(suffix='.h5'),
        name=name,
        dset_uuid=testing_tools.new_uuid(),
        geom_uuid=geom_uuid,
        ref_time=ref_time,
        time_units=time_units,
        units=units,
        use_activity_as_null=True,
        location='cells'
    )


def values_from_array(array: Array, grid_info: GridInfo) -> list:
    """Flatten every layer of the array into a single list of values.

    Layers are processed in order. A layer stored as 'CONSTANT' contributes its constant
    repeated shape[0] * shape[1] times. Any other layer is read through
    array.dataframe_from_layer() and contributes each entry multiplied by the layer's
    factor, walking the dataframe row by row.

    Args:
        array: array data.
        grid_info: grid info

    Returns:
        The values of all layers, one layer after another.
    """
    flat_values = []
    for index, layer in enumerate(array.layers):
        if layer.storage == 'CONSTANT':
            # The factor is not applied to constant layers.
            value_count = layer.shape[0] * layer.shape[1]
            flat_values.extend([layer.constant] * value_count)
            continue

        frame = array.dataframe_from_layer(index, grid_info)
        flat_values.extend(
            frame.iloc[r, c] * layer.factor for r in range(frame.shape[0]) for c in range(frame.shape[1])
        )
    return flat_values
