"""ComponentData class."""

__copyright__ = "(C) Copyright Aquaveo 2023"
__license__ = "All rights reserved"

# 1. Standard Python modules
from contextlib import suppress
import os
from pathlib import Path
from typing import Callable, Optional, Sequence, TypeAlias

# 2. Third party modules
import numpy as np
from packaging.version import parse as parse_version
import pandas as pd
import xarray as xr

# 3. Aquaveo modules
from xms.components.bases.xarray_base import XarrayBase
from xms.core.filesystem.filesystem import removefile

# 4. Local modules
from xms.gmi.__version__ import version
from xms.gmi.components.utils import new_component_dir
from xms.gmi.data.generic_model import GenericModel

CURVE_DATASET = 'curves'  #: Name of the default `ComponentData` curve dataset (also its NetCDF group name).
COMP_ID_KEY = 'comp_id'  #: Name of the component ID `DataArray` in a `ComponentData` dataset.
TYPE_KEY = 'type'  #: Name of the feature type `DataArray` in a `ComponentData` dataset.
VALUE_KEY = 'data_dict'  #: Name of the values `DataArray` in a `ComponentData` dataset.

#: Type alias for a callable that gets curves.
#:
#: `ComponentData.get_curve` should match this.
#:
#: Args:
#:     curve_id: ID of the curve to get. This should have come from a `CurveAdder`.
#:     use_dates: Whether the curve is expected to use dates.
#:
#: Returns:
#:     A tuple of (x, y) where `x` is a `Sequence` of X values,
#:     and `y` is a `Sequence` of Y values.
#:     Values in `x` will be of type `np.datetime64` if `use_dates` is True, else `float`.
CurveGetter: TypeAlias = Callable[[int, bool], tuple[Sequence[float | np.datetime64], Sequence[float]]]

#: Type alias for a callable that adds curves.
#:
#: `ComponentData.add_curve` should match this.
#:
#: Args:
#:     x: A `Sequence` of either `float` (if `use_dates` is False) or `np.datetime64` (if `use_dates` is True)
#:         representing the X values in the curve.
#:     y: A `Sequence` of `float` representing the Y values of the curve. Should have the same length as `x`.
#:     use_dates: Whether this curve uses dates.
#:
#: Returns:
#:     An integer identifying the newly added curve. This can be passed to a `CurveGetter` to retrieve the curve later.
CurveAdder: TypeAlias = Callable[[Sequence[float | np.datetime64], Sequence[float], bool], int]


def make_xy_dataset(
    curve_id: int = 0,
    x: Sequence[float | np.datetime64] | None = None,
    y: Sequence[float] | None = None,
    use_dates: bool = False
) -> xr.Dataset:
    """
    Make a new curve dataset.

    An empty default dataset can be created by not passing any parameters.

    Args:
        curve_id: ID of the new curve to create. Every point of the curve is tagged with this ID along the `xy_id`
            dimension.
        x: X values for the curve. Should be of type `np.datetime64` if using dates, else `float`. Any sequence is
            accepted, including a NumPy array such as the ones returned by `ComponentData.get_curve`.
        y: Y values for the curve. Should have the same length as `x`.
        use_dates: Whether the curve uses dates or floats. Dates are stored in a variable named `dt`, floats in a
            variable named `x`.

    Returns:
        A dataset holding the X variable (`dt` or `x`) and `y` along the `xy_id` dimension, whose `xy_id` coordinate
        is `curve_id` repeated once per point.
    """
    # Compare against None instead of relying on truthiness: `x or []` raises ValueError for a multi-element NumPy
    # array, and silently replaces a one-element array with a falsy value (e.g. `[0.0]`) with an empty list.
    x = [] if x is None else x
    y = [] if y is None else y

    x_name, x_dtype = ('dt', np.datetime64) if use_dates else ('x', np.float64)
    data_vars = {
        x_name: ('xy_id', np.array(x, dtype=x_dtype)),
        'y': ('xy_id', np.array(y, dtype=np.float64)),
    }
    coords = {'xy_id': np.full(len(x), curve_id, dtype=int)}
    return xr.Dataset(data_vars=data_vars, coords=coords)


class ComponentData(XarrayBase):
    """
    Data class for components.

    `ComponentData` manages the data for a component. By default, it manages a single curve dataset. Derived classes can
    add new datasets by overriding `self._dataset_names` and `self._create_dataset`.

    `ComponentData` can also store a `GenericModel` template.

    Derived classes are likely to be more useful than `ComponentData` alone.
    """
    def __init__(self, main_file: Optional[str | Path] = None):
        """
        Initialize the class.

        After `XarrayBase` is initialized, if the main file exists, every dataset named in `self._dataset_names` is
        loaded from it. Otherwise the file is created immediately by committing the current data.

        Args:
            main_file: Path to the component's data file. If not provided, a file named `self._main_file_name` inside a
                new directory from `new_component_dir()` is used.
        """
        if not main_file:
            main_file = new_component_dir() / self._main_file_name

        super().__init__(str(main_file))  # Not sure what XarrayBase will think of non-strings.

        self.main_file = Path(main_file)
        # The component's UUID is taken from the name of the directory containing the main file.
        self.uuid = os.path.basename(os.path.dirname(main_file))

        self._migrate_data()
        self.info.attrs['VERSION'] = version
        # In-memory datasets keyed by name. Filled by `self._get_dataset()` and written out by `self.commit()`.
        self._datasets: dict[str, xr.Dataset] = {}

        self._created_file = False
        if os.path.exists(main_file):
            self._load_datasets()
        else:
            self._created_file = True
            self.commit()

    def touch(self):
        """Ensure the data manager's main-file exists."""
        # This is here for compatibility with VisibleCoverageComponentBaseData, but we don't need to do anything
        # because the constructor runs first and already ensured the main-file exists.
        pass

    @property
    def _main_file_name(self):
        """What to name the data manager's main file."""
        return 'gmi_component.nc'

    def _migrate_data(self):
        """
        Migrate data written by older versions of this package to the current layout.

        Files written by version 0.2.15 or earlier may store the model template under `GENERIC_MODEL`; it now lives
        under `GENERIC_MODEL_TEMPLATE`. Files without a `VERSION` attribute (including brand-new ones) are treated as
        version `0.0.0`.
        """
        attrs = self.info.attrs
        current_version = parse_version(attrs.get('VERSION', '0.0.0'))

        if current_version <= parse_version('0.2.15'):
            if 'GENERIC_MODEL' in attrs:
                attrs['GENERIC_MODEL_TEMPLATE'] = attrs.pop('GENERIC_MODEL')
            else:
                # Nothing to migrate. Keep the key present (as before), but never clobber an existing template with ''.
                attrs.setdefault('GENERIC_MODEL_TEMPLATE', '')

    def _get_component_id(self) -> int:
        """
        Get a new component ID.

        IDs start at 1 and the next one is tracked in the `next_component_id` attribute of `self.info`.
        """
        next_id = int(self.info.attrs.get('next_component_id', 1))
        self.info.attrs['next_component_id'] = next_id + 1
        return next_id

    @property
    def generic_model(self) -> GenericModel:
        """
        The generic model.

        - This only gets the model's template. Values are not included.
        - This is mainly useful for models with multiple or dynamic templates. Derived models should typically just have
          a function that makes their model template and always use that to get it.
        """
        definitions = self.info.attrs.get('GENERIC_MODEL_TEMPLATE', None)
        generic_model = GenericModel(definitions=definitions)
        return generic_model

    @generic_model.setter
    def generic_model(self, generic_model: GenericModel):
        """
        The generic model.

        This only sets the model's template. Values are not included.
        """
        self.info.attrs['GENERIC_MODEL_TEMPLATE'] = generic_model.to_template().to_string()

    def get_curve(self, curve_id: int, use_dates: bool) -> tuple[Sequence[float | np.datetime64], Sequence[float]]:
        """
        Get a curve.

        If the requested curve does not exist, or the curve dataset has no X variable of the requested kind (`dt` for
        dates, `x` for floats), a default single-point curve is returned instead. Its X value is
        `pd.to_datetime(0.0)` (the Unix epoch, as a `pd.Timestamp`) if `use_dates` is True, else `0.0`, and its Y value
        is `0.0`. Signature matches `CurveGetter`.

        Args:
            curve_id: The curve's ID.
            use_dates: Whether to retrieve dates or floats for the X values.

        Returns:
            A tuple of (X, Y) representing the requested curve. For a stored curve these are NumPy arrays: X holds
            `np.datetime64` values if `use_dates` is True, else floats, and Y holds floats. The default curve is
            returned as a pair of one-element lists.
        """
        curves = self._get_dataset(CURVE_DATASET)
        curve = curves.where(curves['xy_id'] == curve_id, drop=True)
        if len(curve['xy_id']) == 0:
            return [pd.to_datetime(0.0) if use_dates else 0.0], [0.0]

        # If we don't have the required data, return a default curve
        if 'dt' not in curve and use_dates:
            return [pd.to_datetime(0.0)], [0.0]
        elif 'x' not in curve and not use_dates:
            return [0.0], [0.0]

        x = curve['dt'].values if use_dates else curve['x'].values
        y = curve['y'].values
        return x, y

    def add_curve(self, x: Sequence[float | np.datetime64], y: Sequence[float], use_dates: bool) -> int:
        """
        Add a new curve.

        Signature matches `CurveAdder`.

        Args:
            x: X values for the curve. Should be of type `np.datetime64` if `use_dates` is True, else `float`.
            y: Y values for the curve. Should have the same length as `x`.
            use_dates: Whether to use dates for the X values.

        Returns:
            The ID assigned to the new curve. IDs start at 1 and are never reused, even for curves with no points.
        """
        existing_curves = self._get_dataset(CURVE_DATASET)

        # A curve with no points stores no rows, so deriving the next ID from the stored rows alone would hand the same
        # ID out again. Keep a persistent counter (like `_get_component_id`), but still look at the stored IDs so files
        # written before the counter existed continue to get fresh IDs.
        curve_id = int(self.info.attrs.get('next_curve_id', 1))
        stored_ids = np.asarray(existing_curves['xy_id'].values)
        if stored_ids.size > 0:
            curve_id = max(curve_id, int(stored_ids.max()) + 1)
        self.info.attrs['next_curve_id'] = curve_id + 1

        new_curve = make_xy_dataset(curve_id, x, y, use_dates)
        combined = xr.concat([existing_curves, new_curve], 'xy_id')
        self._set_dataset(combined, CURVE_DATASET)
        return curve_id

    @property
    def _dataset_names(self) -> set[str]:
        """
        The names of datasets used by this data class.

        Derived classes can override it to add/remove names. If they add names, they should also override
        `self._create_dataset()`.
        """
        return {CURVE_DATASET}

    def _create_dataset(self, name: str) -> xr.Dataset:
        """
        Create an empty dataset, given its name.

        Derived classes should override this to handle any names they add to `self._dataset_names`. Any names they
        don't need to specifically handle should be handled by `return super()._create_dataset(name)`.

        Args:
            name: The name of the dataset to create.

        Returns:
            A new dataset with the appropriate structure for this name.

        Raises:
            AssertionError: If `name` is not a name this class knows how to create.
        """
        if name == CURVE_DATASET:
            return make_xy_dataset()
        else:
            # If this happens, it's likely a derived class added a new name to `self._dataset_names` and forgot to
            # override this function with one that handles the name.
            raise AssertionError(f'Unknown dataset name: {name}')

    def _get_dataset(self, name: str) -> xr.Dataset:
        """
        Get a dataset by its name.

        The dataset is looked up, in order, in the in-memory cache, as a NetCDF group of the main file, and finally by
        creating an empty one with `self._create_dataset()`. The result is cached for later calls.

        Args:
            name: name of the dataset to get.

        Returns:
            The requested dataset.
        """
        if name in self._datasets:
            return self._datasets[name]

        # A missing file or group is expected to surface as an OSError (FileNotFoundError is a subclass of it). In that
        # case fall through and create a fresh dataset instead.
        with suppress(OSError):
            # `load_dataset` reads everything into memory and closes the file, so `commit()` can safely delete it.
            self._datasets[name] = xr.load_dataset(self._filename, group=name)
            return self._datasets[name]

        self._datasets[name] = self._create_dataset(name)
        return self._datasets[name]

    def _set_dataset(self, dataset: xr.Dataset, name: str):
        """
        Set a dataset by its name.

        The dataset is only stored in memory; it is written to the main file by `self.commit()`.

        Args:
            dataset: The dataset to set.
            name: name of the dataset to overwrite.
        """
        self._datasets[name] = dataset

    def _load_datasets(self):
        """Load all the datasets named in `self._dataset_names` into the in-memory cache."""
        for name in self._dataset_names:
            self._get_dataset(name)

    def commit(self):
        """
        Save current in-memory component data to main file.

        The main file is deleted and recreated from scratch. Each dataset in the in-memory cache is then appended as its
        own NetCDF group named after the dataset.
        """
        # Always load all data into memory and vacuum the file. Data is small enough and prevents filesystem bloat.
        self._info.close()
        for dataset in self._datasets.values():
            dataset.close()
        removefile(self._filename)
        super().commit()  # Recreates the NetCDF file
        for name, dataset in self._datasets.items():
            dataset.to_netcdf(self._filename, group=name, mode='a')