"""ObsTargetData class."""

__copyright__ = "(C) Copyright Aquaveo 2020"
__license__ = "All rights reserved"

# 1. Standard Python modules
import datetime
import os

# 2. Third party modules
import numpy as np
import pkg_resources
import xarray as xr

# 3. Aquaveo modules
from xms.components.bases.xarray_base import XarrayBase
from xms.core.filesystem import filesystem
from xms.guipy.time_format import ISO_DATETIME_FORMAT

# 4. Local modules


def check_for_object_strings_dumb(dset, variables):
    """Cast object-dtype string variables of a Dataset to a unicode string dtype, in place.

    This check is needed because xarray.where() switches the dtype of empty string variables to object, and
    object dtype fails when serializing to NetCDF.

    Args:
        dset (xr.Dataset): The Dataset to check for bad string variables. Offending variables are replaced in
            this Dataset; nothing is returned.
        variables (Iterable): The names of the string variables that are potentially bad. Names that are not
            present in ``dset`` are skipped.
    """
    for variable in variables:
        if variable in dset and dset[variable].dtype == object:
            # np.str_ is the unicode string type. The np.unicode_ alias for it was removed in NumPy 2.0, so
            # np.str_ is used to work on both NumPy 1.x and 2.x.
            dset[variable] = dset[variable].astype(np.str_)


class ObsTargetData(XarrayBase):
    """Class for storing the Generic Coverages Observation Targets properties.

    Two datasets are managed by this class: the ``info`` dataset, whose attrs hold coverage-level metadata, and
    the ``targets`` dataset, which holds the observation target variables indexed by ``comp_id``. The targets
    dataset is loaded lazily from the ``targets`` group of the file and written back by ``commit()``.
    """
    OBS_CATEGORY_POINT = 0  # Display list category ID for observation target points
    OBS_CATEGORY_ARC = 1  # Display list category ID for observation target arcs
    OBS_CATEGORY_ARC_GROUP = 2  # Display list category ID for observation target arc groups
    OBS_CATEGORY_POLY = 3  # Display list category ID for observation target polygons
    UNINITIALIZED_COMP_ID = -1  # Map coverage component ID for observation target points that have no data

    def __init__(self, filename):
        """Constructor.

        Args:
            filename (str): Path to the file backing this data. If no file exists at this path yet, the info
                attrs are initialized to their defaults.
        """
        # Define our attributes before the base class constructor runs.
        self._info = None
        self._targets = None
        super().__init__(filename)
        self._get_default_data()

    def _get_default_data(self):
        """Create default datasets if this is the first initialization of the data."""
        # os.path.isfile() is already False for paths that do not exist, so a separate exists() check is redundant.
        if not os.path.isfile(self._filename):
            self._default_info()

    def _default_info(self):
        """Initializes the attrs of the info Dataset to their default values."""
        # Go through the ``info`` accessor (as the reftime accessors and commit() do) rather than touching
        # ``_info`` directly, so the dataset is available even if it has not been loaded yet.
        attrs = self.info.attrs
        # NOTE(review): 'COVERGES' looks like a misspelling of 'COVERAGES', but this identifier is persisted with
        # the data, so it is left unchanged here - confirm nothing reads it before correcting.
        attrs['FILE_TYPE'] = 'GENERIC_COVERGES_OBS_DATA'
        attrs['VERSION'] = pkg_resources.get_distribution('xmscoverage').version
        attrs['cov_uuid'] = ''
        # One display list UUID per observation target category.
        for category in (
            ObsTargetData.OBS_CATEGORY_POINT,
            ObsTargetData.OBS_CATEGORY_ARC,
            ObsTargetData.OBS_CATEGORY_ARC_GROUP,
            ObsTargetData.OBS_CATEGORY_POLY,
        ):
            attrs[f'display_uuid_{category}'] = ''
        attrs['feature_type'] = ObsTargetData.UNINITIALIZED_COMP_ID  # All feature types allowed by default
        attrs['reftime'] = '1950-01-01 00:00:00'  # strftime/strptime format is '%Y-%m-%d %H:%M:%S'
        attrs['dset_uuid'] = ''  # Uuid of the dset associated with the observations

    def _default_targets(self):
        """Initializes the targets dataset as an empty Dataset with every variable on the comp_id dimension."""
        targets = {
            'feature_type': ('comp_id', np.array([], dtype=np.int32)),
            'category': ('comp_id', np.array([], dtype=np.int32)),  # Matches CategoryDisplayOption.id in JSON file
            'name': ('comp_id', np.array([], dtype=object)),
            'time': ('comp_id', np.array([], dtype="datetime64[ns]")),
            'interval': ('comp_id', np.array([], dtype=np.float64)),
            'observed': ('comp_id', np.array([], dtype=np.float64)),
            'computed': ('comp_id', np.array([], dtype=np.float64)),
        }
        coords = {'comp_id': []}
        self._targets = xr.Dataset(data_vars=targets, coords=coords)

    @property
    def reftime(self):
        """Get the observation reference datetime.

        Notes:
            The times of the observations in the targets Dataset must already be absolute datetimes. This reference
            date will be used by XMS when rendering observation whisker plots to find the closest observations to the
            current solution dataset timestep if the dataset does not have a reference datetime defined. Only applicable
            if the observations are transient.

        Returns:
            datetime.datetime: See description
        """
        return datetime.datetime.strptime(self.info.attrs['reftime'], ISO_DATETIME_FORMAT)

    @reftime.setter
    def reftime(self, obs_reftime):
        """Set the observation reference datetime.

        Args:
            obs_reftime (datetime.datetime): Reference datetime of the observations. Stored in the info attrs as a
                string formatted with ISO_DATETIME_FORMAT.
        """
        self.info.attrs['reftime'] = obs_reftime.strftime(ISO_DATETIME_FORMAT)

    @property
    def targets(self):
        """Get the targets component data parameters.

        The dataset is read from the 'targets' group of the file on first access. If it cannot be read, an empty
        default dataset is created instead.

        Returns:
            xarray.Dataset: The targets list dataset
        """
        if self._targets is None:
            self._targets = self.get_dataset('targets', False)
            if self._targets is None:
                self._default_targets()
        return self._targets

    @targets.setter
    def targets(self, dset):
        """Setter for the targets dataset.

        Args:
            dset (xarray.Dataset): The new targets dataset. None is ignored and leaves the current dataset as is.
        """
        # Compare against None explicitly: the truth value of an xarray Dataset depends on whether it has data
        # variables, so a bare truthiness test would silently discard a valid Dataset that has none.
        if dset is not None:
            self._targets = dset

    def commit(self):
        """Save in memory datasets to the NetCDF file."""
        # Ensure all data is loaded from disk, and always vacuum the file. Targets dataset is small enough.
        _ = self.info
        _ = self.targets
        # Close the datasets before deleting the file they may have been loaded from.
        self._info.close()
        self._targets.close()
        filesystem.removefile(self._filename)
        super().commit()  # Recreates the file
        # Object dtype string variables cannot be serialized to NetCDF, so fix them up before writing.
        check_for_object_strings_dumb(self._targets, ['name'])
        self._targets.to_netcdf(self._filename, group='targets', mode='a')
