"""This module defines data for the save points hidden component."""

# 1. Standard Python modules
import os
from pathlib import Path

# 2. Third party modules
import numpy as np
import xarray as xr

# 3. Aquaveo modules
from xms.components.bases.xarray_base import XarrayBase

# 4. Local modules
from xms.cmsflow.components.id_files import UNINITIALIZED_COMP_ID


def check_for_object_strings_dumb(dset, variables):
    """Coerce the named string variables of a Dataset back to a native string dtype, in place.

    xarray.where() switches the dtype of empty string variables to ``object``, and ``object`` dtype fails when
    serializing to NetCDF. Call this on a Dataset before writing it out to undo that conversion.

    Args:
        dset (:obj:`xarray.Dataset`): The Dataset to repair. Modified in place.
        variables (:obj:`Iterable`): Names of the string variables that may have been converted. Names that are not
            present in ``dset`` are skipped.
    """
    for var_name in variables:
        if var_name not in dset:
            continue  # Nothing to repair for variables this Dataset does not have.
        dset[var_name] = dset[var_name].astype(str)


class SavePointsData(XarrayBase):
    """Manages data file for the hidden save points component.

    Besides the ``info`` dataset (handled through the XarrayBase base class), the data file stores a ``points``
    table of per-save-point attributes indexed by ``comp_id`` and a ``general`` attributes-only dataset, written to
    the NetCDF groups 'points' and 'general' respectively.
    """
    def __init__(self, data_file: str | Path):
        """Initializes the data class.

        If ``data_file`` does not exist yet, it is created with default datasets and one default save point.

        Args:
            data_file: The netcdf file (with path) associated with this instance data. Probably the owning
                component's main file.
        """
        data_file = str(data_file)
        self._filename = data_file
        self._info = None
        self._points = None
        self._general = None
        self.deleted_comp_ids = set()  # Fill with component ids deleted from the coverage before vacuuming
        # Create the default file before calling super because we have our own attributes to write.
        created = self._get_default_datasets(data_file)
        super().__init__(data_file)
        self.info.attrs.pop('proj_dir', None)
        if created:
            self.add_save_point_atts()
            self.commit()

    def _get_default_datasets(self, data_file):
        """Create default datasets if needed.

        Args:
            data_file (str): Name of the data file. If it doesn't exist, it will be created.

        Returns:
            bool: True if default datasets were created and written, False if the file already existed.
        """
        # isfile() is False for missing paths too, so a separate exists() check is unnecessary.
        if os.path.isfile(data_file):
            return False

        info = {
            'FILE_TYPE': 'CMSFLOW_SAVE',
            'cov_uuid': '',
            'point_display_uuid': '',
            'next_comp_id': 0
        }
        self._info = xr.Dataset(attrs=info)

        point_table = {
            'name': ('comp_id', np.array([], dtype=object)),
            'hydro': ('comp_id', np.array([], dtype=np.int32)),
            'sediment': ('comp_id', np.array([], dtype=np.int32)),
            'salinity': ('comp_id', np.array([], dtype=np.int32)),
            'waves': ('comp_id', np.array([], dtype=np.int32)),
        }
        coords = {'comp_id': np.array([], dtype=int)}
        self._points = xr.Dataset(data_vars=point_table, coords=coords)

        general = {
            'HYDRO_VALUE': 0.5,
            'HYDRO_UNITS': 'hours',
            'SEDIMENT_VALUE': 0.5,
            'SEDIMENT_UNITS': 'hours',
            'SALINITY_VALUE': 0.5,
            'SALINITY_UNITS': 'hours',
            'WAVES_VALUE': 0.5,
            'WAVES_UNITS': 'hours',
        }
        self._general = xr.Dataset(attrs=general)
        self.commit()
        return True

    def commit(self):
        """Save current in-memory component parameters to data file.

        Only the points and general datasets that have been loaded (are not None) are rewritten.
        """
        super().commit()  # Recreates the NetCDF file if vacuuming
        if self._points is not None:
            # Object-dtype string columns cannot be serialized to NetCDF.
            check_for_object_strings_dumb(self._points, ['name'])
            self._points.close()
            self._drop_h5_groups(['points'])
            self._points.to_netcdf(self._filename, group='points', mode='a')
        if self._general is not None:
            self._general.close()
            self._general.to_netcdf(self._filename, group='general', mode='a')

    def vacuum(self):
        """Rewrite the info, points, and general datasets to a new/wiped file to reclaim disk space.

        Any dataset not yet loaded is read into memory first, then the existing file is deleted (best effort) and
        everything is written back out via commit().
        """
        if self._info is None:
            self._info = self.get_dataset('info', False)
        if self._points is None:
            self._points = self.get_dataset('points', False)
        if self._general is None:
            self._general = self.get_dataset('general', False)
        try:
            os.remove(self._filename)
        except OSError:
            # Best effort: a missing or locked file should not prevent rewriting the datasets.
            pass
        self.commit()  # Rewrite all datasets

    @property
    def general(self):
        """Load the general dataset from disk.

        Returns:
            xarray.Dataset: Dataset interface to the general datasets in the main file
        """
        if self._general is None:
            self._general = self.get_dataset('general', False)
        return self._general

    @general.setter
    def general(self, dset):
        """Setter for the general attribute. A value of None is ignored."""
        # The general dataset holds attrs only (no data_vars), so a truthiness test would always be False.
        if dset is not None:
            self._general = dset

    @property
    def points(self):
        """Load the points dataset from disk.

        Returns:
            xarray.Dataset: Dataset interface to the points datasets in the main file
        """
        if self._points is None:
            self._points = self.get_dataset('points', False)
        return self._points

    @points.setter
    def points(self, dset):
        """Setter for the points attribute. A value of None is ignored."""
        # Compare to None (like the general setter) rather than relying on xarray's Dataset truthiness.
        if dset is not None:
            self._points = dset

    def update_save_point(self, comp_id, new_atts):
        """Update the save point attributes of a save point.

        Args:
            comp_id (int): Component id of the save point to update
            new_atts (xarray.Dataset): The new attributes for the save point. The integer columns are converted with
                int(), so each is expected to hold a single value (i.e. one save point).
        """
        points = self.points
        points['name'].loc[dict(comp_id=[comp_id])] = new_atts['name']
        for column in ('hydro', 'sediment', 'salinity', 'waves'):
            points[column].loc[dict(comp_id=[comp_id])] = int(new_atts[column])

    def add_save_point_atts(self, dset=None):
        """Add the save point attribute dataset for a point.

        Args:
            dset (xarray.Dataset): The attribute dataset to concatenate. If not provided, a new Dataset of default
                attributes will be generated. If provided, it is modified in place: every entry's comp_id is set to
                the newly generated component id.

        Returns:
            int: The newly generated component id, or UNINITIALIZED_COMP_ID if the point could not be added.
        """
        try:
            # int() accepts both numpy scalars (attrs read back from NetCDF) and plain Python ints (the defaults
            # built in _get_default_datasets); numpy's .item() fails on the latter.
            new_comp_id = int(self.info.attrs['next_comp_id'])
            self.info.attrs['next_comp_id'] += 1  # Increment the unique XMS component id.
            if dset is None:  # Generate a new default Dataset
                dset = self._get_new_point_atts(new_comp_id)
            else:  # Update the component id of an existing Dataset
                dset.coords['comp_id'] = [new_comp_id for _ in dset.coords['comp_id']]
            self._points = xr.concat([self.points, dset], 'comp_id')
            return new_comp_id
        except Exception:
            return UNINITIALIZED_COMP_ID

    def concat_save_points(self, save_pts_data):
        """Adds the save points attributes from save_pts_data to this instance of SavePointsData.

        The incoming points are given fresh component ids starting at this instance's next_comp_id. save_pts_data
        itself is not modified.

        Args:
            save_pts_data (SavePointsData): another SavePointsData instance

        Returns:
            dict: The old ids of the save_pts_data as key and the new ids as the data
        """
        next_comp_id = self.info.attrs['next_comp_id']
        other_points = save_pts_data.points
        num_concat_points = other_points.sizes['comp_id']
        if not num_concat_points:
            return {}

        old_comp_ids = other_points.coords['comp_id'].data.astype('i4').tolist()
        # assign_coords returns a new Dataset, so the other instance keeps its original component ids.
        new_save_points = other_points.assign_coords(
            comp_id=[next_comp_id + idx for idx in range(num_concat_points)]
        )
        self.info.attrs['next_comp_id'] = next_comp_id + num_concat_points
        self._points = xr.concat([self.points, new_save_points], 'comp_id')
        new_comp_ids = new_save_points.coords['comp_id'].data.astype('i4').tolist()
        return dict(zip(old_comp_ids, new_comp_ids))

    @staticmethod
    def _get_new_point_atts(comp_id):
        """Get a new dataset with default attributes for a save point.

        Args:
            comp_id (int): The unique XMS component id to assign to the save point. Used as given.

        Returns:
            (xarray.Dataset): A new single-point dataset with an empty name and all output flags set to 0. Can later
                be concatenated to the persistent points dataset.
        """
        point_table = {
            'name': ('comp_id', np.array([''], dtype=object)),
            'hydro': ('comp_id', np.array([0], dtype=np.int32)),
            'sediment': ('comp_id', np.array([0], dtype=np.int32)),
            'salinity': ('comp_id', np.array([0], dtype=np.int32)),
            'waves': ('comp_id', np.array([0], dtype=np.int32)),
        }
        coords = {'comp_id': [comp_id]}
        return xr.Dataset(data_vars=point_table, coords=coords)
