"""Dataset Transient Calculator Algorithm."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
from logging import Logger
import time
from typing import Any

# 2. Third party modules
import numpy as np
from numpy import dtype, ndarray

# 3. Aquaveo modules
from xms.constraint.ugrid_activity import active_points_from_cells
from xms.core.filesystem import filesystem
from xms.datasets.dataset_reader import DatasetReader
from xms.datasets.dataset_writer import DatasetWriter
from xms.grid.ugrid.ugrid import UGrid
from xms.tool_core import dataset_argument

# 4. Local modules

# Statistic name (as given by an operation's ``description``) -> NaN-aware numpy reduction.
# The ``nan*`` variants ignore NaN entries, which is how inactive/null values are excluded.
operation_dict = {
    "Minimum": np.nanmin,
    "Maximum": np.nanmax,
    "Standard Deviation": np.nanstd,
    "Mean": np.nanmean,
    "Median": np.nanmedian,
}


class TransientDatasetStatisticCalculator:
    """
    Compute per-value statistics over all time steps of a transient dataset.

    For every value index of the input dataset, the values of all time steps are gathered,
    inactive/null entries are replaced with NaN, and each requested statistic from
    ``operation_dict`` is evaluated. Statistic ``i`` is written to ``output_datasets[i]`` as a
    single time step.
    """

    # Approximate number of values held in memory per cached block (10,000,000 float32's is ~40 MB).
    _CACHE_VALUE_BUDGET = 10_000_000

    def __init__(self, dataset: DatasetReader | None, operations: list[dataset_argument] | None,
                 output_datasets: list[DatasetWriter] | None, ugrid: UGrid | None, logger: Logger | None):
        """
        Initializer for the transient statistical calculator.

        Args:
            dataset: a dataset reader object
            operations: operation list to be performed; each item's ``description`` must be a key
                of ``operation_dict``
            output_datasets: list of dataset writer objects, one per operation (same order).
            ugrid: dataset ugrid (used to convert cell activity to point activity)
            logger: logger
        """
        self.logger = logger
        self._dataset = dataset
        self._operations = operations
        self._output_datasets = output_datasets
        self._ugrid = ugrid
        self._activity = None  # per-value activity for all time steps, or None when not used
        self._ds_vals_cache = None  # currently loaded block of dataset values
        self._ds_activity_cache = None  # activity matching the currently loaded block (lazy)
        self._cache_size = None  # number of value columns per block; set in calculate_dataset
        self._cache_start = 0  # index of the first value in the currently loaded block
        self.start_time = 0.0  # last time a progress message was logged

    def _ds_vals_from_idx(self, idx: int) -> ndarray[Any, dtype[Any]]:
        """
        Get the values of all time steps at one value index, loading blocks on demand.

        Dataset values are read in blocks of ``self._cache_size`` columns so the whole dataset never
        has to be in memory at once. Assumes the values are laid out as (time step, value).

        Args:
            idx: index of the dataset value

        Returns:
            ndarray: the value at ``idx`` for every time step
        """
        ix = idx - self._cache_start
        if self._ds_vals_cache is None or not 0 <= ix < self._cache_size:
            # Load the block that contains idx. Blocks are aligned to multiples of the cache size, so
            # sequential access walks through them in order and any other access pattern (including
            # a second pass over the data) still lands on the correct block.
            self._cache_start = (idx // self._cache_size) * self._cache_size
            self._ds_vals_cache = self._dataset.values[:, self._cache_start:self._cache_start + self._cache_size]
            # Activity for the new block is (re)loaded lazily by _ds_activity_from_idx.
            self._ds_activity_cache = None
            ix = idx - self._cache_start
        return self._ds_vals_cache[:, ix]

    def _ds_activity_from_idx(self, idx: int) -> ndarray[Any, dtype[Any]]:
        """
        Get the activity of all time steps at one value index from the cached block.

        Must be called after ``_ds_vals_from_idx(idx)`` so the block containing ``idx`` is loaded.

        Args:
            idx: index of the dataset value

        Returns:
            ndarray: the activity at ``idx`` for every time step
        """
        if self._ds_activity_cache is None:
            self._ds_activity_cache = self._activity[:, self._cache_start:self._cache_start + self._cache_size]
        return self._ds_activity_cache[:, idx - self._cache_start]

    def _fill_inactive_with_nan(self, vals_all_ts: ndarray[Any, dtype[Any]],
                                idx: int) -> ndarray[Any, dtype[Any]]:
        """
        Replace null or inactive values with NaN so the NaN-aware statistics ignore them.

        A dataset null value takes precedence; otherwise the preprocessed activity is used.

        Args:
            vals_all_ts: values of all time steps at ``idx``
            idx: index of the dataset value (used to look up activity)

        Returns:
            ndarray: the values with null/inactive entries set to NaN
        """
        if self._dataset.null_value is not None:
            return np.where(vals_all_ts == self._dataset.null_value, np.nan, vals_all_ts)
        # Compare against None explicitly: the truth value of a multi-element array is ambiguous.
        if self._activity is not None:
            return np.where(self._ds_activity_from_idx(idx), vals_all_ts, np.nan)
        return vals_all_ts

    def _timed_logging(self, interval: float, message: str, idx: int) -> None:
        """
        Log a progress message if more than ``interval`` seconds passed since the last one.

        Args:
            interval: minimum number of seconds between messages
            message: the message to log
            idx: zero-based index/time step; logged one-based
        """
        if time.time() - self.start_time > interval:
            self.start_time = time.time()
            self.logger.info(f'{message} {idx + 1}.')

    def _preprocess_activity(self) -> None:
        """
        Resolve the per-value activity for all time steps, if it is needed.

        Does nothing when the dataset has a null value (nulls are used instead of activity) or has
        no activity. When there is one activity value per dataset value, the dataset activity is
        used directly. Otherwise (point data with cell activity) each time step's cell activity is
        converted to point activity, written to a temporary .h5 file, and read back from there.
        """
        if self._dataset.null_value is not None:
            return
        elif self._dataset.activity is not None:
            if self._dataset.num_values == self._dataset.num_activity_values:
                self._activity = self._dataset.activity
            else:  # process activity for all timesteps and store in a temp h5 file
                num_ts = self._dataset.num_times
                msg = (f'Point data with cell activity requires activity to be preprocessed for all time steps '
                       f'({num_ts}).')
                self.logger.info(msg)
                temp_file = filesystem.temp_filename(suffix='.h5')
                writer = DatasetWriter(temp_file)
                self.start_time = time.time()
                for ts in range(num_ts):
                    self._timed_logging(3.5, 'Preprocessing activity, current time step', ts)
                    new_act = active_points_from_cells(self._ugrid, self._dataset.activity[ts])
                    writer.append_timestep(float(ts), new_act)
                writer.appending_finished()
                reader = DatasetReader(h5_filename=temp_file, dset_name="Dataset")
                self._activity = reader.values
                self.logger.info('Preprocessing activity complete.')

    def calculate_dataset(self) -> list[list[float]]:
        """
        Compute every requested statistic for every dataset value.

        Returns:
            result(list): one list per operation (same order as ``operations``) holding one value per
            dataset value. Statistics that are NaN (e.g. a value inactive at every time step) are
            replaced with the corresponding output dataset's null value.

        Raises:
            ValueError: if an operation's description is not a key of ``operation_dict``.
        """
        operations = []
        for operation in self._operations:
            func = operation_dict.get(operation.description)
            if func is None:
                raise ValueError(f'Unsupported statistic operation: {operation.description!r}.')
            operations.append(func)

        # Size blocks to hold about _CACHE_VALUE_BUDGET values, but always at least one column: this
        # avoids a zero cache size with very many time steps and a division by zero with none.
        self._cache_size = max(1, self._CACHE_VALUE_BUDGET // max(1, self._dataset.num_times))
        # Reset the block caches so the calculation can safely run more than once per instance.
        self._cache_start = 0
        self._ds_vals_cache = None
        self._ds_activity_cache = None
        self._preprocess_activity()

        result = [[] for _ in operations]
        self.start_time = time.time()
        self.logger.info(f'Processing dataset values, number of values {self._dataset.num_values}.')
        for idx in range(self._dataset.num_values):
            self._timed_logging(3.5, 'Processing dataset values, current index', idx)
            vals_all_ts = self._fill_inactive_with_nan(self._ds_vals_from_idx(idx), idx)
            for values, operation in zip(result, operations):
                values.append(operation(vals_all_ts))

        # NaN statistics (no active/non-null values at any time step) become the output null value.
        for i, values in enumerate(result):
            if np.any(np.isnan(values)):
                result[i] = np.where(np.isnan(values), self._output_datasets[i].null_value, values).tolist()
        return result

    def calculate(self) -> None:
        """Compute all statistics and write each one as a single time step to its output dataset."""
        results = self.calculate_dataset()

        for i, values in enumerate(results):
            writer = self._output_datasets[i]
            writer.append_timestep(0.0, values)
            writer.appending_finished()
