"""Module for the TransientDatasetExtractor class."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"
__all__ = ['TransientDatasetExtractor']

# 1. Standard Python modules
from functools import cached_property
from itertools import count
from typing import cast, Optional

# 2. Third party modules
import numpy as np

# 3. Aquaveo modules
from xms.constraint import UGrid2d
from xms.datasets.dataset_reader import DatasetReader
from xms.extractor import UGrid2dDataExtractor

# 4. Local modules


class TransientDatasetExtractor:
    """
    Class for extracting all the data for a set of locations across multiple time steps of a transient dataset.

    Should be constructed, then `self.add_locations()` called at least once, then `self.extract()` called. Afterward,
    data can be extracted with `self.scalars_for()`.
    """
    def __init__(self, ugrid: UGrid2d, dataset: DatasetReader, time_indexes: list[int]):
        """
        Initialize the extractor.

        Args:
            ugrid: Geometry to extract from.
            dataset: Dataset defining scalars on the geometry to extract from. May be transient.
            time_indexes: Indexes of time steps to extract. A copy is kept, so later changes to the caller's list do
                not affect extraction.
        """
        self._ugrid = ugrid
        self._dataset = dataset
        # Maps identifiers to row indexes in self._locations. It stays empty while every location uses its default
        # positional identifier; in that case scalars_for() indexes positionally instead of looking anything up.
        self._mapping: dict[str | int, int] = {}
        self._locations: list[tuple[float, float, float]] = []

        # Once loaded, both arrays are laid out as (component, time_step, node).
        self._values: Optional[np.ndarray] = None
        self._activity: Optional[np.ndarray] = None
        self._indexes: list[int] = list(time_indexes)

        self._extracted_scalars: Optional[np.ndarray] = None
        self._extracted = False

    def add_locations(
        self, locations: list[tuple[float, float, float]], identifiers: Optional[list[str] | list[int]] = None
    ):
        """
        Add some locations to be extracted.

        May not be called after `self.extract()`.

        Args:
            locations: List of (x, y, z) tuples defining locations to extract.
            identifiers: How to identify the locations. If omitted, each location's identifier will be its index in the
                list of locations that were added. If this is called multiple times, subsequent calls' identifiers will
                start after previous calls.

        Raises:
            AssertionError: If called after `self.extract()`.
            ValueError: If `identifiers` is given but its length differs from the length of `locations`.
        """
        if self._extracted:
            raise AssertionError('Cannot add locations after extracting.')

        # This call's locations occupy rows starting here, no matter how earlier calls identified their locations.
        start = len(self._locations)
        if identifiers:
            identifiers = list(identifiers)
            if len(identifiers) != len(locations):
                # zip() would otherwise silently drop the unmatched tail.
                raise ValueError(
                    f'Got {len(identifiers)} identifiers for {len(locations)} locations; the counts must match.'
                )
            if not self._mapping and start:
                # Earlier calls relied on positional identifiers, which scalars_for() only honors while the mapping is
                # empty. Register them explicitly so they stay reachable once the mapping is in use.
                self._mapping.update((index, index) for index in range(start))
            self._mapping.update(zip(identifiers, count(start)))
        elif self._mapping:
            # The mapping is already in use, so positional identifiers must be registered to be reachable.
            self._mapping.update((index, index) for index in range(start, start + len(locations)))
        self._locations.extend(locations)

    @cached_property
    def _extractor(self):
        """The extractor, created on first use from the wrapped UGrid."""
        extractor = UGrid2dDataExtractor(self._ugrid.ugrid)
        return extractor

    def _load_values(self):
        """
        Load the values from a dataset.

        The loaded data will be restructured to be more convenient for internal use: `self._values` ends up with axes
        (component, time_step, node).
        """
        # self._dataset.values is some h5py wrapper around a numpy array. It doesn't support comparison the same way
        # numpy does.
        values = np.array(self._dataset.values)

        # Scalar datasets have no component axis; add one so scalar and vector data share one layout.
        if self._dataset.num_components == 1:
            values = np.expand_dims(values, -1)
        # Reorder axes from (time_step, node, component) to (component, time_step, node).
        # The class' user will want all the time steps for a node, which would suggest putting time_step last so it's
        # easy to get them. These are put into the extractor though, which wants all the nodes for a time step, so we
        # put them in that order here instead.
        values = values.transpose([2, 0, 1])
        self._values = values

    def _load_activity(self):
        """
        Load/create the activity for the dataset.

        Uses the dataset's own activity if it has one, otherwise derives it by comparing values against the dataset's
        null value, otherwise treats everything as active. The result is laid out as (component, time_step, ...) to
        match `self._values`. Must be called after `self._load_values()`.
        """
        num_components = self._dataset.num_components
        if self._dataset.activity is not None:
            # self._dataset.activity is some h5py wrapper around a numpy array. It doesn't support comparison the same
            # way numpy does.
            activity = np.array(self._dataset.activity)
            # The dataset stores activity without a component axis, so every component shares the same activity.
            activity = np.repeat(np.expand_dims(activity, -1), num_components, axis=-1)
            # Reorder axes from (time_step, node, component) to (component, time_step, node) to match self._values.
            activity = activity.transpose([2, 0, 1])
        elif self._dataset.null_value is not None:
            activity = self._values != self._dataset.null_value
        else:
            activity = np.ones(self._values.shape)

        self._activity = activity
        # The last axis may legitimately differ: point datasets can carry cell activity (see _set_up_extractor_for).
        assert self._activity.shape[:2] == self._values.shape[:2]

    def _set_up_output(self):
        """Set up the output array with axes (component, time_step, location)."""
        num_components = self._dataset.num_components
        num_indexes = len(self._indexes)
        num_locations = len(self._locations)
        # Sized by the real component count so every slot gets filled by extract(). Asking for an unsupported
        # component then raises instead of returning uninitialized memory.
        self._extracted_scalars = np.empty((num_components, num_indexes, num_locations), dtype=float)

    def extract(self):
        """
        Run the extractor and extract all the scalars.

        Once this succeeds, `self.scalars_for()` may be called and `self.add_locations()` may no longer be.
        """
        self._load_values()
        self._load_activity()
        self._set_up_output()
        self._extractor.extract_locations = self._locations
        for component in range(self._dataset.num_components):
            for dest_index, source_index in enumerate(self._indexes):
                self._set_up_extractor_for(component, source_index)
                self._extract_time_step(component, dest_index)

        # This is initialized with axes (component, time_step, node), which is convenient for use with the internal
        # extractor that expects to get all the nodes for a time step, but it's inconvenient for our own user, which
        # wants all the time steps for a node. We move time_step to the end here for the convenience of
        # self.scalars_for().
        self._extracted_scalars = self._extracted_scalars.transpose([0, 2, 1])
        # Only mark as extracted after everything succeeded, so a failed extraction cannot unlock scalars_for() on a
        # partially filled (uninitialized) output array.
        self._extracted = True

    def scalars_for(self, identifier: str | int, component: int = 0) -> np.ndarray[float]:
        """
        Get the scalars for a location.

        May not be called before `self.extract()`.

        Args:
            identifier: Identifier for the location. If not passed to self.add_locations, will be the location's index
                in the list of added locations.
            component: Component of the dataset to get the values for. Scalar datasets only support values of 0.
                Vector datasets support 0 and 1.

        Returns:
            Sequence of all the values extracted for the given location, one per requested time step.

        Raises:
            AssertionError: If called before `self.extract()`.
        """
        if not self._extracted:
            raise AssertionError('Cannot get scalars before extracting.')

        # An empty mapping means every location uses its positional identifier.
        index = self._mapping[identifier] if self._mapping else identifier
        scalars = self._extracted_scalars[component][index]
        scalars = cast(np.ndarray[float], scalars)
        return scalars

    def _set_up_extractor_for(self, component: int, index: int):
        """
        Set up the extractor for a time step.

        Args:
            component: Component of the dataset to extract.
            index: Index of the time step to set up for.
        """
        activity = self._activity[component][index]
        scalars = self._values[component][index]

        if self._dataset.location == 'points' and len(activity) == len(scalars):
            self._extractor.set_grid_point_scalars(scalars, activity, 'points')
        elif self._dataset.location == 'points':
            # Point scalars whose activity has a different length; presumably the activity is per-cell, which the
            # last argument tells the extractor. NOTE(review): confirm against the UGrid2dDataExtractor API.
            self._extractor.set_grid_point_scalars(scalars, activity, 'cells')
        else:
            self._extractor.set_grid_cell_scalars(scalars, activity, 'cells')

    def _extract_time_step(self, component: int, index: int):
        """
        Extract the scalars for a time step.

        Args:
            component: Component of the dataset to extract.
            index: Index of the time step in the output array (not the dataset's time step index).
        """
        extracted_scalars = self._extractor.extract_data()
        self._extracted_scalars[component][index] = extracted_scalars
