"""BaseData class."""

__copyright__ = "(C) Copyright Aquaveo 2023"
__license__ = "All rights reserved"

# 1. Standard Python modules
from contextlib import suppress
import os
from pathlib import Path

# 2. Third party modules
import xarray as xr

# 3. Aquaveo modules
from xms.components.bases.xarray_base import XarrayBase
from xms.core.filesystem.filesystem import removefile

# 4. Local modules

# Version number that BaseData.__init__ writes into `info.attrs['GMI_VERSION']`.
CURRENT_VERSION = 1


class BaseData(XarrayBase):
    """
    Data class for components.

    Provides very basic functionality for GMI-based data managers. Unlikely to be useful alone. See derived ones.

    Derived classes declare extra NetCDF groups by overriding `_dataset_names` and `_create_dataset`. Each group is
    cached in `self._datasets` once accessed, and `commit()` writes every cached group back to the main file.
    """
    def __init__(self, main_file: str | Path):
        """
        Initialize the class.

        If `main_file` already exists, every dataset named by `self._dataset_names` is loaded into memory.
        Otherwise the main file is created immediately by calling `self.commit()`.

        Args:
            main_file: Path to the component's data file. The name of the directory containing it is stored as
                `self.uuid`.
        """
        # XarrayBase is always handed a str; whether it accepts Path objects has not been verified.
        super().__init__(str(main_file))

        self.main_file = Path(main_file)
        # `uuid` is taken from the name of the directory that contains the main file.
        self.uuid = os.path.basename(os.path.dirname(main_file))

        # Set unconditionally, so it replaces any GMI_VERSION value XarrayBase may have read from an existing file.
        self.info.attrs['GMI_VERSION'] = CURRENT_VERSION
        # In-memory datasets keyed by NetCDF group name; filled by `_get_dataset()` and `_set_dataset()`.
        self._datasets: dict[str, xr.Dataset] = {}

        if os.path.exists(main_file):
            self._load_datasets()
        else:
            self.commit()

    def touch(self) -> None:
        """
        Ensure the data manager's main-file exists.

        Intentionally a no-op: the constructor has already either found the main-file or created it via `commit()`.
        """
        # Present only for interface compatibility with VisibleCoverageComponentBaseData.

    def _get_component_id(self) -> int:
        """
        Get a new component ID.

        IDs start at 1 and increase by one on every call. The counter is kept in
        `self.info.attrs['next_component_id']` (presumably persisted along with the info attributes on `commit()`).

        Returns:
            The next unused component ID.
        """
        next_id = int(self.info.attrs.get('next_component_id', 1))
        self.info.attrs['next_component_id'] = next_id + 1
        return next_id

    @property
    def _dataset_names(self) -> set[str]:
        """
        The names of datasets used by this data class.

        Derived classes can override it to add/remove names. If they add names, they should also override
        `self._create_dataset()`.

        Returns:
            The set of NetCDF group names this class manages. The base class manages none.
        """
        return set()

    def _create_dataset(self, name: str) -> xr.Dataset:
        """
        Create an empty dataset, given its name.

        Derived classes should override this to handle any names they add to `self._dataset_names`. Any names they
        don't need to specifically handle should be handled by `return super()._create_dataset(name)`.

        Args:
            name: The name of the dataset to create.

        Returns:
            A new dataset with the appropriate structure for this name.

        Raises:
            AssertionError: Always, in the base class, since it knows no dataset names.
        """
        # If this happens, it's likely a derived class added a new name to `self._dataset_names` and forgot to
        # override this function with one that handles the name.
        raise AssertionError(f'Unknown dataset name: {name}')  # pragma: nocover

    def _get_dataset(self, name: str) -> xr.Dataset:
        """
        Get a dataset by its name.

        Lookup order: the in-memory cache, then group `name` in the main file, then a fresh dataset from
        `self._create_dataset(name)`. Whatever is found or created is cached for later calls.

        Args:
            name: name of the dataset to get.

        Returns:
            The cached, loaded, or newly created dataset.
        """
        if name in self._datasets:
            return self._datasets[name]

        # Any OSError while reading the group (FileNotFoundError is a subclass) is treated as "not stored yet" and
        # falls through to creating a new dataset. Note this also hides unreadable or corrupt files.
        with suppress(OSError):
            self._datasets[name] = xr.load_dataset(self._filename, group=name)
            return self._datasets[name]

        self._datasets[name] = self._create_dataset(name)
        return self._datasets[name]

    def _set_dataset(self, dataset: xr.Dataset, name: str) -> None:
        """
        Set a dataset by its name.

        Only the in-memory cache is updated; the dataset reaches the main file on the next `commit()`.

        Args:
            dataset: The dataset to set.
            name: name of the dataset to overwrite.
        """
        self._datasets[name] = dataset

    def _load_datasets(self) -> None:
        """Load every dataset named by `self._dataset_names` into the in-memory cache."""
        for name in self._dataset_names:
            self._get_dataset(name)

    def commit(self) -> None:
        """
        Save current in-memory component data to main file.

        The main file is deleted and rebuilt from scratch. `super().commit()` recreates it, then each dataset in
        `self._datasets` is appended as a NetCDF group named after its key. Groups that are not in the cache are
        therefore not written back.
        """
        # Always load all data into memory and vacuum the file. Data is small enough and prevents filesystem bloat.
        # Close the info dataset and every cached dataset before the file is deleted below.
        self._info.close()
        for dataset in self._datasets.values():
            dataset.close()
        removefile(self._filename)
        super().commit()  # Recreates the NetCDF file
        for name, dataset in self._datasets.items():
            dataset.to_netcdf(self._filename, group=name, mode='a')