"""CoverageBaseData class for managing the data of a feature coverage component."""

__copyright__ = "(C) Copyright Aquaveo 2023"
__license__ = "All rights reserved"
__all__ = ['table_to_text_file', 'CoverageBaseData']

# 1. Standard Python modules
from pathlib import Path
from typing import Callable, Iterable, Optional, Sequence, TextIO

# 2. Third party modules
import numpy as np
import xarray as xr

# 3. Aquaveo modules
from xms.components.bases.visible_coverage_component_base_data import VisibleCoverageComponentBaseData
from xms.guipy.data.target_type import TargetType

# 4. Local modules
from xms.gmi.data.generic_model import GenericModel, GmiName, Section
from xms.gmi.data_bases.base_data import BaseData

# Names of the variables in each per-target feature dataset (see `make_feature_dataset`).
COMP_ID_KEY = 'comp_id'  #: Name of the component ID coordinate, which indexes each feature dataset.
TYPE_KEY = 'type'  #: Name of the feature type (active group) `DataArray` in each feature dataset.
VALUE_KEY = 'data_dict'  #: Name of the feature values `DataArray` in each feature dataset.
# Names of the per-target feature datasets; `target_to_name` maps each `TargetType` to one of these.
_POINT_DATASET = 'point_data'  # Name of the point feature dataset.
_ARC_DATASET = 'arc_data'  # Name of the arc feature dataset.
_POLYGON_DATASET = 'polygon_data'  # Name of the polygon feature dataset.


class CoverageBaseData(BaseData, VisibleCoverageComponentBaseData):
    """
    Manages the data file for a feature coverage component.

    `CoverageBaseData` derives from `BaseData` and `VisibleCoverageComponentBaseData` to add feature functionality:

    - Add features and look up their values and types
    - Find features by feature type or `TargetType`
    - Drop features that are no longer in use
    - Export the stored data to a text file for comparing against test baselines

    Point, arc, and polygon features are each stored in their own dataset, indexed by component ID.
    """
    def __init__(self, main_file: str | Path):
        """
        Initialize the data class.

        Args:
            main_file: The netcdf file (with path) associated with this instance data. Probably the owning
                       component's main file.
        """
        super().__init__(main_file)

        self.info.attrs['FILE_TYPE'] = 'GMI_COVERAGE_BASE'

        self.commit()

    @property
    def coverage_uuid(self):
        """The coverage's UUID, or an empty string if it has not been set."""
        return self.info.attrs.get('coverage_uuid', '')

    @coverage_uuid.setter
    def coverage_uuid(self, value):
        """The coverage's UUID."""
        self.info.attrs['coverage_uuid'] = value

    def drop_unused_features(self, target: TargetType, used_ids: list[int]):
        """
        Drop any unused features.

        Args:
            target: The type of feature to drop.
            used_ids: Component IDs of features that are still in use and should be kept. Features whose IDs are not
                in this list are dropped.
        """
        dataset = self._get_dataset(target)
        keep = dataset[COMP_ID_KEY].isin(used_ids)
        # `drop=True` removes the unused entries entirely instead of masking them with missing values.
        self._set_dataset(dataset.where(keep, drop=True), target)

    def add_features(self, target: TargetType, values: list[str], active_groups: list[GmiName]) -> Iterable[int]:
        """
        Add multiple new features at once.

        Args:
            target: The type of the features.
            values: List of values for the features, one element per feature.
            active_groups: List of active groups for the features, parallel to values parameter.

        Returns:
            Component IDs assigned to the new features, parallel to `values`
            and `active_groups`.
        """
        # We use the `object` dtype for the values to get variable-length strings, but xarray has a tendency to impose
        # a maximum length on the string, which is usually approximately the size of the longest one in the dataset. If
        # we read a dataset with a short maximum length and concatenate it with one where the strings exceed that
        # length, then the new dataset's strings get truncated. This comment suggested dropping the encoding of the
        # dataset to fix the issue, which seems to work:
        # https://github.com/pydata/xarray/issues/9037#issuecomment-2121196387
        existing_features = self._get_dataset(target).drop_encoding()
        new_features = make_feature_dataset(self._get_component_id, values, active_groups)
        new_component_ids = new_features[COMP_ID_KEY].values.astype(int)

        combined = xr.concat([existing_features, new_features], COMP_ID_KEY)

        self._set_dataset(combined, target)
        return new_component_ids

    def add_feature(self, target: TargetType, values: str, active_group: GmiName) -> int:
        """
        Add a new feature.

        Args:
            target: The type of feature.
            values: Values to associate with the feature.
            active_group: The feature's currently active group.

        Returns:
            The component ID assigned to the new feature.
        """
        component_ids = self.add_features(target, [values], [active_group])
        return int(next(iter(component_ids)))

    def feature_type_values(self, target: TargetType, component_id: int = 0, feature_id: int = 0) -> tuple[str, str]:
        """
        Get the type and values associated with a feature.

        Exactly one of feature_id and component_id must be passed.

        Args:
            target: The type of feature to get values for.
            component_id: The component ID of the feature to get values for.
            feature_id: The feature ID of the feature to get values for. If passed, the data manager's component_id_map
                must have been initialized, typically by passing the component that holds this data manager to
                Query.load_component_ids.

        Returns:
            The type and values associated with the feature, or empty strings if there is no data.
        """
        assert (component_id == 0) != (feature_id == 0), 'Pass exactly one of component_id and feature_id.'
        if component_id == 0:
            # A feature ID missing from the map becomes -1 (presumably never a real component ID), so the lookup
            # below falls through to the "no data" result.
            component_id = self.component_id_map[target].get(feature_id, -1)

        dataset = self._get_dataset(target)

        try:
            feature_dataset = dataset.sel({COMP_ID_KEY: component_id})
            return feature_dataset[TYPE_KEY].item(), feature_dataset[VALUE_KEY].item()
        except (IndexError, KeyError):
            return '', ''

    def feature_values(self, target: TargetType, component_id: int = 0, feature_id: int = 0) -> str:
        """
        Get the values associated with a feature.

        Exactly one of feature_id and component_id must be passed.

        Args:
            target: The type of feature to get values for.
            component_id: The component ID of the feature to get values for.
            feature_id: The feature ID of the feature to get values for. If passed, the data manager's component_id_map
                must have been initialized, typically by passing the component that holds this data manager to
                Query.load_component_ids.

        Returns:
            The values associated with the feature, or an empty string if there are no associated values.
        """
        return self.feature_type_values(target, component_id, feature_id)[1]

    def feature_type(self, target: TargetType, component_id: int = 0, feature_id: int = 0) -> GmiName:
        """
        Get the type associated with a feature.

        Exactly one of feature_id and component_id must be passed.

        Args:
            target: The type of feature to get values for.
            component_id: The component ID of the feature to get values for.
            feature_id: The feature ID of the feature to get values for. If passed, the data manager's component_id_map
                must have been initialized, typically by passing the component that holds this data manager to
                Query.load_component_ids.

        Returns:
            The type associated with the feature, or an empty string if there is no associated type.
        """
        return self.feature_type_values(target, component_id, feature_id)[0]

    def component_ids(self, target: TargetType) -> Iterable[int]:
        """
        Get all the component IDs for a given target.

        Args:
            target: What type of elements to look for.

        Returns:
            An integer NumPy array of component IDs.
        """
        dataset = self._get_dataset(target)
        return dataset[COMP_ID_KEY].data.astype(int)

    def component_ids_with_type(self, target: TargetType, feature_type: GmiName) -> Sequence[int]:
        """
        Get the component IDs that have a given feature type.

        Args:
            target: What type of elements to look for.
            feature_type: The type to search for.

        Returns:
            An integer NumPy array of the component IDs which have the given type.
        """
        dataset = self._get_dataset(target)
        matches = dataset.where(dataset[TYPE_KEY] == feature_type, drop=True)
        return matches[COMP_ID_KEY].data.astype(int)

    @property
    def _dataset_names(self) -> set[str]:
        """
        The names of datasets used by this data class.

        Derived classes can override it to add/remove names. If they add names, they should also override
        `self._create_dataset()`.
        """
        return super()._dataset_names | {_POINT_DATASET, _ARC_DATASET, _POLYGON_DATASET}

    def _create_dataset(self, name: str) -> xr.Dataset:
        """
        Create an empty dataset, given its name.

        Derived classes should override this to handle any names they add to `self._dataset_names`. Any names they
        don't need to specifically handle should be handled by `return super()._create_dataset(name)`.

        Args:
            name: The name of the dataset to create.

        Returns:
            A new dataset with the appropriate structure for this name.
        """
        if name in {_POINT_DATASET, _ARC_DATASET, _POLYGON_DATASET}:
            return make_feature_dataset()

        return super()._create_dataset(name)

    def _get_dataset(self, target: TargetType | str):
        """
        Get a dataset by its target type.

        Args:
            target: Target to get dataset for. Names that are not a feature `TargetType` are passed through as-is.
        """
        name = target_to_name(target)
        return super()._get_dataset(name)

    def _set_dataset(self, dataset: xr.Dataset, target: TargetType | str):
        """
        Set a dataset by its target type.

        Args:
            dataset: dataset to set.
            target: Target to set dataset for. Names that are not a feature `TargetType` are passed through as-is.
        """
        name = target_to_name(target)
        super()._set_dataset(dataset, name)

    def to_text_file(self, model: GenericModel, file: TextIO):
        """
        Export the data manager's content to a text file in a format suitable for comparing to test baselines.

        Args:
            model: Model describing the data in the manager.
            file: Where to export to.
        """
        self._mapping_to_text_file(file)

        for target in [TargetType.point, TargetType.arc, TargetType.polygon]:
            file.write(f'{target.name}\n')
            section = model.section_from_target_type(target)
            self._target_to_text_file(section, target, file)

    def _mapping_to_text_file(self, file: TextIO):
        """Write the feature ID -> component ID map to a text file, one row per mapped feature."""
        file.write('mapping\n')
        if not self.component_id_map:
            file.write('[]\n\n')
            return

        table = [('target', 'feature_id', 'component_id')]

        for target in [TargetType.point, TargetType.arc, TargetType.polygon]:
            for feature_id, component_id in self.component_id_map.get(target, {}).items():
                table.append((target.name, feature_id, component_id))
        table_to_text_file(table, file)

    def _target_to_text_file(self, section: Section, target: TargetType, file: TextIO):
        """
        Write all the component data for a given target type to a text file.

        Args:
            section: Groups and parameters to write.
            target: Target to write.
            file: Where to write to.
        """
        header, order = _text_file_header_and_order(section)
        table = [header]
        for component_id in self.component_ids(target):
            # `component_ids` yields NumPy scalars. Convert to a plain int so the repr written to the file is the same
            # on every NumPy version (NumPy 2 writes `np.int64(1)` where older versions wrote `1`).
            component_id = int(component_id)
            type_, values = self.feature_type_values(target, component_id)
            # Load this feature's values into the section so its groups and parameters can be read back below.
            section.restore_values(values)
            active_groups = tuple(sorted(section.active_group_names))
            row = [component_id, type_, active_groups]
            for group_name, parameter_name in order:
                value = section.group(group_name).parameter(parameter_name).value
                row.append(value)
            table.append(tuple(row))
        table_to_text_file(table, file)


def make_feature_dataset(
    get_comp_id: Optional[Callable[[], int]] = None,
    values: Optional[list] = None,
    active_groups: Optional[list] = None
):
    """
    Make a new feature dataset.

    An empty default dataset can be created by not passing any parameters.

    Args:
        get_comp_id: A callable that returns a new component ID each time it is called. It is called once per feature,
            so it is not called at all when `values` and `active_groups` are empty.
        values: Values for the features, one feature per element.
        active_groups: Active groups for the features. Parallel to `values`.

    Returns:
        A dataset indexed by component ID, holding one type (active group) and one value per feature.

    Raises:
        ValueError: If `values` and `active_groups` have different lengths, or if features were given without a
            `get_comp_id` to assign their component IDs. Both checks happen before any component IDs are requested.
    """
    # Normalize missing inputs the same way for both lists. Comparing against None rather than relying on truthiness
    # also tolerates array-like inputs, whose truth value is ambiguous.
    values = [] if values is None else values
    active_groups = [] if active_groups is None else active_groups

    # The lists are parallel. Validate that up front so a mismatch doesn't consume component IDs and then surface as
    # an obscure xarray dimension-size error.
    feature_count = len(active_groups)
    if len(values) != feature_count:
        raise ValueError(
            f'values and active_groups must be parallel, but have lengths {len(values)} and {feature_count}.'
        )
    if feature_count and get_comp_id is None:
        raise ValueError('get_comp_id is required to assign component IDs to new features.')

    # One new component ID per feature; the callable is never invoked when there are no features.
    comp_ids = [get_comp_id() for _ in range(feature_count)]

    # COMP_ID_KEY is our lookup key. It's an array of type int, containing the values in comp_ids.
    coords = {COMP_ID_KEY: np.array(comp_ids, dtype=int)}
    data_vars = {
        # The values of these are looked up by COMP_ID_KEY. They're both arrays of objects (so strings can have any
        # length), containing the active groups and the values, respectively.
        TYPE_KEY: (COMP_ID_KEY, np.array(active_groups, dtype=object)),
        VALUE_KEY: (COMP_ID_KEY, np.array(values, dtype=object)),
    }
    return xr.Dataset(coords=coords, data_vars=data_vars)


def target_to_name(target: TargetType | str) -> str:
    """
    Map a feature `TargetType` to the name of the dataset that stores it.

    Anything without a mapping (for example, a dataset name) is handed back unchanged, so callers may pass either a
    target type or a name.

    Args:
        target: A target type, or a name that needs no conversion.

    Returns:
        The dataset name for `target`, or `target` itself when it has no mapping.
    """
    dataset_names = {
        TargetType.point: _POINT_DATASET,
        TargetType.arc: _ARC_DATASET,
        TargetType.polygon: _POLYGON_DATASET,
    }
    # Fall back to the input itself so names pass straight through.
    return dataset_names.get(target, target)


def _text_file_header_and_order(section: Section) -> tuple[tuple[str, ...], list[tuple[str, str]]]:
    """
    Build the column header and parameter order for writing a section to a text file.

    Group names, and parameter names within each group, are sorted so the columns come out in a stable order for
    testing.

    Args:
        section: The section whose groups and parameters become columns.

    Returns:
        A `(header, order)` pair. `header` is a tuple of column titles: the fixed `id`, `declared_type`, and
        `active_groups` columns, followed by one `group/parameter` column per parameter. `order` is a list of
        `(group_name, parameter_name)` pairs, one per parameter column, in the same order (the three fixed columns
        have no entry in it).
    """
    # Sorted so the output is predictable no matter how the model generates its names.
    order = [
        (group_name, parameter_name)
        for group_name in sorted(section.group_names)
        for parameter_name in sorted(section.group(group_name).parameter_names)
    ]
    columns = ['id', 'declared_type', 'active_groups']
    columns.extend(f'{group_name}/{parameter_name}' for group_name, parameter_name in order)
    return tuple(columns), order


def table_to_text_file(table: list[tuple], file: TextIO):
    """
    Write a list of tuples to a text file.

    Each cell is written as its `repr()` followed by a comma, padded with spaces so that columns line up nicely. The
    first row is treated as a header. A table with no rows below the header, or with no rows at all, is written as an
    empty list (`[]`).

    Args:
        table: List of tuples to write. The first tuple is the header and determines the number of columns.
        file: Where to write to.
    """
    if len(table) <= 1:
        # Only a header (or nothing at all): write an empty list instead of indexing a missing header row.
        file.write('[]\n\n')
        return

    # Render each cell once; the text is needed both to size the columns and to write them.
    rendered = [[repr(cell) for cell in row] for row in table]

    # Each column is as wide as its longest cell, plus room for the comma and at least one space.
    widths = [0] * len(table[0])
    for row in rendered:
        for index, text in enumerate(row):
            widths[index] = max(widths[index], len(text))
    widths = [width + 2 for width in widths]

    lines = ['[\n']
    for row in rendered:
        cells = ''.join((text + ', ').ljust(width) for text, width in zip(row, widths))
        lines.append(f'({cells}),\n')
    lines.append(']\n\n')
    file.write(''.join(lines))
