"""Module for performing statistical analysis on SRH-2D model run results."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
import logging
import os
import uuid
import warnings

# 2. Third party modules
import numpy as np

# 3. Aquaveo modules
from xms.api.tree import tree_util
from xms.datasets.dataset_writer import DatasetWriter

# 4. Local modules
from xms.srh.components.sim_query_helper import SimQueryHelper


class Statistician:
    """Compute min, max, mean and standard deviation datasets across the runs of an SRH-2D simulation.

    Solution datasets are grouped by name. For each group holding at least two datasets, every time step is stacked
    across the datasets of the group and reduced along the run axis, ignoring values equal to ``SRH_NULL_VALUE``.
    The results are written with ``DatasetWriter`` and can be passed on to SMS with ``send()``.
    """
    SRH_NULL_VALUE = -999.0  # Value marking inactive locations; replaced by NaN while statistics are computed

    def __init__(self, datasets=None, temp_dir=None, query=None):
        """Initializes the class.

        Args:
            datasets (:obj:`dict`): Dict of the xms.data_objects.parameters.Dataset solutions. If not provided,
                will retrieve and send data to/from SMS. Key should be solution name value should be list of
                DatasetReaders.
            temp_dir (:obj:`str`): Path to the XMS temp directory. Only used for output files when no Query is
                provided.
            query (:obj:`xms.api.dmi.Query`): Object for communicating with SMS. Mutually exclusive with datasets. Pass
                in if not testing.
        """
        self._logger = logging.getLogger('xms.srh')
        self._stats = []  # DatasetWriters of the statistical datasets that have been written
        self._dset_name = ''  # Name of current solution dataset being processed
        self._query = query
        self._temp_dir = temp_dir
        self._datasets = datasets if datasets else {}  # All solution datasets, keyed by name
        self._mesh_uuid = None
        self._time_units = None
        self._ts_data = None  # Current timestep data
        self._stat_times = []  # Output dataset times
        self._stat_data = {}  # Output dataset data
        self._solution_tree_item = None
        self._mesh_tree_item = None
        self._stats_folder_name = None
        self._stats_folder_tree_item = None

        # For testing
        self.out_filenames = []
        self.out_uuids = []

        if not self._datasets and query:
            self._get_solutions()  # pragma: no cover

    def _get_out_filename(self):
        """Get the path of the next output XMDF file.

        Returns:
            (:obj:`str`): The last entry of ``out_filenames`` (removed from the list) if any were provided for
            testing. Otherwise a randomly named ``.h5`` file in the Query's XMS temp directory, or in ``temp_dir``
            when no Query was provided.

        Raises:
            RuntimeError: If a filename has to be generated but neither a Query nor a temp directory was provided.
        """
        if self.out_filenames:
            return self.out_filenames.pop()
        if self._query is not None:  # pragma: no cover
            out_dir = self._query.xms_temp_directory
        elif self._temp_dir:
            # No Query to ask, fall back on the directory passed to the constructor.
            out_dir = self._temp_dir
        else:
            raise RuntimeError('Cannot create an output file: neither a Query nor a temp directory was provided.')
        return os.path.join(out_dir, f'{uuid.uuid4()}.h5')

    def _get_out_uuid(self):
        """Get the UUID of the next output dataset.

        Returns:
            (:obj:`str`): The last entry of ``out_uuids`` (removed from the list) if any were provided for testing,
            otherwise a randomly generated UUID.
        """
        if self.out_uuids:
            return self.out_uuids.pop()
        else:  # pragma: no cover
            return str(uuid.uuid4())

    def _get_solutions(self):  # pragma: no cover
        """Query XMS for solution datasets and group by name.

        Also stores the solution and mesh tree items, builds the name of the statistics folder, and remembers the
        folder's tree item if a folder with that name already exists below the mesh.
        """
        helper = SimQueryHelper(self._query)
        solutions = helper.get_solution_data()
        if not solutions:
            return
        self._solution_tree_item = helper.solution_tree_item
        self._mesh_tree_item = helper.mesh_tree_item
        # see if the statistics folder already exists
        self._stats_folder_name = f'STATS: {self._solution_tree_item.name}'
        folders = tree_util.descendants_of_type(self._mesh_tree_item, xms_types='TI_FOLDER')
        for fold in folders:
            if fold.name == self._stats_folder_name:
                self._stats_folder_tree_item = fold
        for solution in solutions:
            # Group on the name with everything from the last '(' onward removed and whitespace trimmed.
            sol_name = solution.name
            pos = sol_name.rfind('(')
            if pos > 0:
                sol_name = sol_name[0:pos]
            sol_name = sol_name.strip()
            self._datasets.setdefault(sol_name, []).append(solution)

    def _load_timestep(self, ts_idx):
        """Load a solution dataset timestep into memory for all runs in the simulation.

        The values of every run are stacked into ``self._ts_data`` (one row per run) with null values replaced by
        NaN. The time of the step is appended to the output times; it is taken from the last run that has the step.

        Args:
            ts_idx (:obj:`int`): Timestep index to load.

        """
        ts_time = None
        rows = []
        for dset in self._datasets[self._dset_name]:
            if dset.num_times > ts_idx:
                ts_time = dset.times[ts_idx]
                rows.append(dset.values[ts_idx])
            else:  # This run did not converge, fill in with inactive values
                rows.append(np.full(dset.num_values, self.SRH_NULL_VALUE))
        self._stat_times.append(ts_time)
        ts_data = np.array(rows)
        if not np.issubdtype(ts_data.dtype, np.floating):
            # NaN cannot be stored in a non-floating array (e.g. integer-valued datasets), so upcast first.
            ts_data = ts_data.astype(float)
        ts_data[ts_data == self.SRH_NULL_VALUE] = np.nan
        self._ts_data = ts_data

    def _add_datasets(self):
        """Write statistical analysis XMDF files and create xms.data_objects Datasets to send back to SMS.

        One file is written for every entry of ``self._stat_data``. The writers are kept in ``self._stats``.
        """
        times = np.array(self._stat_times)
        for dset_name, dset_vals in self._stat_data.items():
            # Write the XMDF file
            xmdf_filename = self._get_out_filename()  # Random filename or hard-coded if testing
            dset_uuid = self._get_out_uuid()
            writer = DatasetWriter(
                xmdf_filename,
                name=dset_name,
                dset_uuid=dset_uuid,
                geom_uuid=self._mesh_uuid,
                null_value=self.SRH_NULL_VALUE,
                time_units=self._time_units
            )
            writer.write_xmdf_dataset(times=times, data=dset_vals)
            self._stats.append(writer)

    def _compute_min(self):
        """Compute the minimum value at each location across all runs for the current time step."""
        mins = np.nanmin(self._ts_data, axis=0)
        self._stat_data.setdefault(f'min_{self._dset_name}', []).append(mins)

    def _compute_max(self):
        """Compute the maximum value at each location across all runs for the current time step."""
        maxs = np.nanmax(self._ts_data, axis=0)
        self._stat_data.setdefault(f'max_{self._dset_name}', []).append(maxs)

    def _compute_mean(self):
        """Compute the mean value at each location across all runs for the current time step."""
        means = np.nanmean(self._ts_data, axis=0)
        self._stat_data.setdefault(f'mean_{self._dset_name}', []).append(means)

    def _compute_std_dev(self):
        """Compute the standard deviation value at each location across all runs for the current time step."""
        stdevs = np.nanstd(self._ts_data, axis=0)
        self._stat_data.setdefault(f'stdev_{self._dset_name}', []).append(stdevs)

    def run_statistics(self):
        """Compute statistical datasets from an SRH-2D solution.

        Nothing is computed, and a message is logged, if there are no solution datasets, if the statistics folder
        already exists, or if no dataset name has more than one solution. Dataset names with fewer than two
        solutions are dropped from the analysis.
        """
        if not self._datasets:
            self._logger.error(
                'Cannot run statistical analysis because no solution datasets were found for the simulation.'
            )
            return
        if self._stats_folder_tree_item is not None:
            # Build the full tree path of the existing folder (root first) for the message.
            tpath_names = [self._stats_folder_tree_item.name]
            tree_item = self._stats_folder_tree_item
            while tree_item.parent:
                tree_item = tree_item.parent
                tpath_names.append(tree_item.name)
            tpath_names.reverse()
            tpath = '/'.join(tpath_names)
            self._logger.warning(f'Statistics already exist for this simulation: ({tpath}). Aborting.')
            return

        # make sure we have multiple datasets for each dataset name
        # A new dict is built so that a dict passed in by the caller is not modified.
        self._datasets = {name: dsets for name, dsets in self._datasets.items() if len(dsets) > 1}
        if not self._datasets:
            msg = (
                'Statistical analysis requires multiple solutions '
                '(SRH Advanced simulation - Run type = "Scenarios"). Aborting.'
            )
            self._logger.error(msg)
            return

        with warnings.catch_warnings():  # Ignore numpy warning about all-nan slices.
            warnings.filterwarnings(action='ignore', message='All-NaN slice encountered')
            warnings.filterwarnings(action='ignore', message='Mean of empty slice')
            warnings.filterwarnings(action='ignore', message='Degrees of freedom <= 0 for slice')
            # Create statistical datasets for each of the solution datasets
            for name, dsets in self._datasets.items():
                # No statistics are computed for datasets with 'Velocity_' in their name
                # NOTE(review): presumably the velocity component datasets - confirm.
                if 'Velocity_' in name:
                    continue
                self._logger.info(f'Calculating statistics for data set: {name}.')
                self._dset_name = name
                self._mesh_uuid = dsets[0].geom_uuid
                self._time_units = dsets[0].time_units
                self._stat_times = []
                self._stat_data = {}
                # Runs with fewer time steps are padded with null values when loading the time step.
                num_ts = max(dset.num_times for dset in dsets)
                for i in range(num_ts):
                    self._load_timestep(i)
                    self._compute_min()
                    self._compute_max()
                    self._compute_mean()
                    self._compute_std_dev()
                self._add_datasets()

    def send(self):  # pragma: no cover
        """Send computed statistical data sets to SMS.

        Every written dataset is added to the Query under the statistics folder. Does nothing without a Query.
        """
        if self._query:
            for dset in self._stats:
                dset_kwargs = {
                    'do_dataset': dset,
                    'folder_path': self._stats_folder_name,
                }
                self._query.add_dataset(**dset_kwargs)
