"""CoverageData class."""

__copyright__ = "(C) Copyright Aquaveo 2023"
__license__ = "All rights reserved"

# 1. Standard Python modules
from pathlib import Path
from typing import Callable, Iterable, Optional, Sequence

# 2. Third party modules
import numpy as np
import xarray as xr

# 3. Aquaveo modules
from xms.components.bases.visible_coverage_component_base_data import VisibleCoverageComponentBaseData
from xms.guipy.data.target_type import TargetType

# 4. Local modules
from xms.gmi.data.component_data import COMP_ID_KEY, ComponentData, TYPE_KEY, VALUE_KEY
from xms.gmi.data.generic_model import GenericModel, GmiName

# Names of the per-feature-kind datasets stored inside a `CoverageData` file.
_POINT_DATASET = 'point_data'      # Point features.
_ARC_DATASET = 'arc_data'          # Arc features.
_POLYGON_DATASET = 'polygon_data'  # Polygon features.


class CoverageData(ComponentData, VisibleCoverageComponentBaseData):
    """
    Manages data file for a feature coverage component.

    `CoverageData` derives from `ComponentData` to add more functionality:

    - Add and update feature values and types
    - Find features by feature type or `TargetType`
    - Slightly simplified access to `self.generic_model.*_parameters.group_names`

    Points, arcs, and polygons each live in their own dataset. Each one is indexed by the `COMP_ID_KEY`
    coordinate and carries a `TYPE_KEY` variable (the feature's active group) and a `VALUE_KEY` variable (the
    feature's values).
    """
    def __init__(self, main_file: Optional[str | Path] = None, generic_model: Optional[GenericModel] = None):
        """
        Initialize the data class.

        Args:
            main_file: The netcdf file (with path) associated with this instance data. Probably the owning
                       component's main file. If not provided, a default one will be chosen.
            generic_model: The model this coverage should use. If None, `self.generic_model` is not assigned here.
        """
        super().__init__(main_file)

        # Mark this data as a GMI coverage file in the file-level `info` attributes.
        self.info.attrs['FILE_TYPE'] = 'GMI_COV'

        if generic_model is not None:
            self.generic_model = generic_model

        # `_created_file` is provided by the base class. Presumably it is set when a new file was created rather
        # than opened; in that case the initial state is committed right away.
        if self._created_file:
            self.commit()

    @property
    def coverage_uuid(self):
        """The coverage's UUID, or an empty string if none has been set."""
        return self.info.attrs.get('coverage_uuid', '')

    @coverage_uuid.setter
    def coverage_uuid(self, value):
        """The coverage's UUID."""
        self.info.attrs['coverage_uuid'] = value

    def drop_unused_features(self, target: TargetType, used_ids: list[int]):
        """
        Drop any unused features.

        Args:
            target: The type of feature to drop.
            used_ids: Component IDs of features that are still in use and
                should be kept. IDs not in this list are dropped.
        """
        dataset = self._get_dataset(target)
        keep = dataset[COMP_ID_KEY].isin(used_ids)
        self._set_dataset(dataset.where(keep, drop=True), target)

    def add_features(self, target: TargetType, values: list[str], active_groups: list[GmiName]) -> Iterable[int]:
        """
        Add multiple new features at once.

        Args:
            target: The type of the features.
            values: List of values for the features, one element per feature.
            active_groups: List of active groups for the features, parallel to values parameter.

        Returns:
            Component IDs assigned to the new features, parallel to `values`
            and `active_groups`.
        """
        existing_features = self._get_dataset(target)
        new_features = make_feature_dataset(self._get_component_id, values, active_groups)
        new_component_ids = new_features[COMP_ID_KEY].values.astype(int)

        # The function that makes new feature datasets uses `object` as the dtype. If we have a populated dataset then
        # xarray will correctly guess that they're strings, but for empty ones it will guess float. If it guesses float
        # the first time, then string the second time, it will try to parse the new dataset's values as floats and
        # crash.
        #
        # It would be best to explicitly declare the variables as strings, but I tried that a while back and it failed
        # for some reason I can't remember anymore. Also, it might break existing files which have already been
        # released. Rather than risk that, we'll just throw out the old dataset if it's empty. The new one will be
        # guessed with the right types, and the old one had nothing of interest in it.
        #
        # This answer and comments below it suggest xarray stores variable-length strings as `object` by default, or
        # fixed-length strings if you request `str` instead, because it's built on numpy which doesn't support anything
        # else. Documentation for numpy suggests they finally added variable-length strings in version 2.0, but we don't
        # support that yet since it breaks a bunch of stuff. That might be why it was broken.
        # https://stackoverflow.com/a/49055035
        #
        # Emptiness is checked along COMP_ID_KEY, the dimension the concatenation runs along. `sizes` is keyed by
        # dimension name, so an arbitrary coordinate name (which may be a non-dimension coordinate) is not safe.
        if existing_features.sizes.get(COMP_ID_KEY, 0) > 0:
            combined = xr.concat([existing_features, new_features], COMP_ID_KEY)
        else:
            combined = new_features

        self._set_dataset(combined, target)
        return new_component_ids

    def add_feature(self, target: TargetType, values: str, active_group: GmiName) -> int:
        """
        Add a new feature.

        Args:
            target: The type of feature.
            values: Values to associate with the feature.
            active_group: The feature's currently active group.

        Returns:
            The component ID assigned to the new feature.
        """
        component_ids = self.add_features(target, [values], [active_group])
        return int(next(iter(component_ids)))

    def feature_type_values(self, target: TargetType, component_id: int) -> tuple[str, str]:
        """
        Get the type and values associated with a feature.

        Args:
            target: The type of feature to get values for.
            component_id: The ID of the feature to get values for.

        Returns:
            The type and values associated with the feature, or empty strings if there is no data.
        """
        dataset = self._get_dataset(target)
        feature_dataset = dataset.where(dataset[COMP_ID_KEY] == component_id, drop=True)
        try:
            return feature_dataset[TYPE_KEY].values[0], feature_dataset[VALUE_KEY].values[0]
        except IndexError:
            # No feature with this component ID.
            return '', ''

    def feature_values(self, target: TargetType, component_id: int) -> str:
        """
        Get the values associated with a feature.

        Args:
            target: The type of feature to get values for.
            component_id: The ID of the feature to get values for.

        Returns:
            The values associated with the feature, or an empty string if there are no associated values.
        """
        return self.feature_type_values(target, component_id)[1]

    def feature_type(self, target: TargetType, component_id: int) -> GmiName:
        """
        Get the type associated with a feature.

        Args:
            target: The type of feature to get values for.
            component_id: The ID of the feature to get values for.

        Returns:
            The type associated with the feature, or an empty string if there is no associated type.
        """
        return self.feature_type_values(target, component_id)[0]

    def update_feature_type_values(
        self, target: TargetType, component_id: int, feature_type: GmiName, values: str
    ) -> bool:
        """
        Update the type and values for a feature.

        If the feature is not present in the data, no changes are made.

        Args:
            target: The type of feature to get values for.
            component_id: The ID of the feature to get values for.
            feature_type: type of the feature.
            values: values for the feature.

        Returns:
            Whether the component ID existed.
        """
        dataset = self._get_dataset(target)
        matches = np.flatnonzero(dataset[COMP_ID_KEY].values == component_id)
        if matches.size == 0:
            return False
        idx = matches[0]  # First matching feature, as with `list.index`.

        # Build fresh `object` arrays instead of writing into `.values` in place. An in-place write only sticks if
        # `_get_dataset` returns the same object that gets saved, and it would truncate strings if a variable had a
        # fixed-width string dtype. Storing through `_set_dataset` matches the other mutators in this class.
        new_types = dataset[TYPE_KEY].values.astype(object)
        new_values = dataset[VALUE_KEY].values.astype(object)
        new_types[idx] = feature_type
        new_values[idx] = values

        updated = dataset.assign({
            # `copy(data=...)` keeps each variable's attrs/encoding while swapping in the new data.
            TYPE_KEY: dataset[TYPE_KEY].copy(data=new_types),
            VALUE_KEY: dataset[VALUE_KEY].copy(data=new_values),
        })
        self._set_dataset(updated, target)
        return True

    def component_ids(self, target: TargetType) -> Iterable[int]:
        """
        Get all the component IDs for a given target.

        Args:
            target: What type of elements to look for.

        Returns:
            An iterable of component IDs.
        """
        dataset = self._get_dataset(target)
        return dataset[COMP_ID_KEY].data.astype(int)

    def component_ids_with_type(self, target: TargetType, feature_type: GmiName) -> Sequence[int]:
        """
        Get the component IDs that have a given feature type.

        Args:
            target: What type of elements to look for.
            feature_type: The type to search for.

        Returns:
            An iterable of component IDs which have the given type.
        """
        dataset = self._get_dataset(target)
        values = dataset.where(dataset[TYPE_KEY] == feature_type, drop=True)
        return values[COMP_ID_KEY].data.astype(int)

    def arc_types(self) -> Sequence[GmiName]:
        """
        Get a list of the types of arcs in this coverage.

        Equivalent to the names of arc groups in `self.generic_model`.

        Returns:
            The types of arcs.
        """
        return self.generic_model.arc_parameters.group_names

    def point_types(self) -> Sequence[GmiName]:
        """
        Get a list of the types of points in this coverage.

        Equivalent to the names of point groups in `self.generic_model`.

        Returns:
            The types of points.
        """
        return self.generic_model.point_parameters.group_names

    def poly_types(self) -> Sequence[GmiName]:
        """
        Get a list of the types of polygons in this coverage.

        Equivalent to the names of polygon groups in `self.generic_model`.

        Returns:
            The types of polygons.
        """
        return self.generic_model.polygon_parameters.group_names

    @property
    def _main_file_name(self):
        """What to name the component's main file."""
        return 'gmi_coverage.nc'

    @property
    def _dataset_names(self) -> set[str]:
        """
        The names of datasets used by this data class.

        Derived classes can override it to add/remove names. If they add names, they should also override
        `self._create_dataset()`.
        """
        return super()._dataset_names | {_POINT_DATASET, _ARC_DATASET, _POLYGON_DATASET}

    def _create_dataset(self, name: str) -> xr.Dataset:
        """
        Create an empty dataset, given its name.

        Derived classes should override this to handle any names they add to `self._dataset_names`. Any names they
        don't need to specifically handle should be handled by `return super()._create_dataset(name)`.

        Args:
            name: The name of the dataset to create.

        Returns:
            A new dataset with the appropriate structure for this name.
        """
        if name in {_POINT_DATASET, _ARC_DATASET, _POLYGON_DATASET}:
            return make_feature_dataset()

        return super()._create_dataset(name)

    def _get_dataset(self, target: TargetType | str):
        """
        Get a dataset by its target type.

        Args:
            target: Target to get dataset for. A string dataset name is also accepted and passed through unchanged.

        Returns:
            Whatever the base class's `_get_dataset` returns for the mapped name.
        """
        name = target_to_name(target)
        return super()._get_dataset(name)

    def _set_dataset(self, dataset: xr.Dataset, target: TargetType | str):
        """
        Set a dataset by its target type.

        Args:
            dataset: dataset to set.
            target: Target to set dataset for. A string dataset name is also accepted and passed through unchanged.
        """
        name = target_to_name(target)
        super()._set_dataset(dataset, name)


def make_feature_dataset(
    get_comp_id: Optional[Callable[[], int]] = None,
    values: Optional[list] = None,
    active_groups: Optional[list] = None
):
    """
    Make a new feature dataset.

    An empty default dataset can be created by not passing any parameters.

    Args:
        get_comp_id: A callable that returns a new component ID. It is called once per element of `active_groups`,
            so it is never called when `active_groups` is empty.
        values: Values for the features, one feature per element.
        active_groups: Active groups for the features. Parallel to `values`.

    Returns:
        A dataset indexed by an integer `COMP_ID_KEY` coordinate, with `TYPE_KEY` (active groups) and `VALUE_KEY`
        (values) variables of `object` dtype.

    Raises:
        ValueError: If `values` and `active_groups` differ in length, or if features are given without a
            `get_comp_id` callable to assign their IDs.
    """
    # Use explicit None checks: `x or []` raises for NumPy arrays ("truth value is ambiguous").
    if active_groups is None:
        active_groups = []
    if values is None:
        values = []

    # Validate before drawing any component IDs, so a bad call doesn't consume IDs before failing.
    if len(values) != len(active_groups):
        raise ValueError(
            f'values and active_groups must be parallel, got {len(values)} values and {len(active_groups)} groups.'
        )
    if len(active_groups) > 0 and get_comp_id is None:
        raise ValueError('get_comp_id is required when features are provided.')

    comp_ids = [get_comp_id() for _ in range(len(active_groups))] if get_comp_id is not None else []

    # COMP_ID_KEY is our lookup key. It's an array of type int, containing the values in comp_ids.
    coords = {COMP_ID_KEY: np.array(comp_ids, dtype=int)}
    data_vars = {
        # The values of these are looked up by COMP_ID_KEY. They're both arrays of objects, containing values
        # from active_groups and values, respectively.
        TYPE_KEY: (COMP_ID_KEY, np.array(active_groups, dtype=object)),
        VALUE_KEY: (COMP_ID_KEY, np.array(values, dtype=object)),
    }
    return xr.Dataset(coords=coords, data_vars=data_vars)


def target_to_name(target: TargetType | str) -> str:
    """
    Map a feature target to the name of the dataset that stores it.

    Point, arc, and polygon targets map to their dataset names. Any other value, including a string that is already
    a dataset name, is returned unchanged, so callers may pass either form.

    Args:
        target: A `TargetType` or a dataset name.

    Returns:
        The dataset name for `target`.
    """
    dataset_names = {
        TargetType.point: _POINT_DATASET,
        TargetType.arc: _ARC_DATASET,
        TargetType.polygon: _POLYGON_DATASET,
    }
    return dataset_names.get(target, target)
