"""This module defines data for the GenCade points hidden component."""

# 1. Standard Python modules
import datetime
import os

# 2. Third party modules
import numpy as np
import pandas as pd
import xarray as xr


# 3. Aquaveo modules
from xms.components.bases.xarray_base import XarrayBase
from xms.core.filesystem import filesystem as io_util

# 4. Local modules
from xms.gencade.components.id_files import UNINITIALIZED_COMP_ID

# File name of the points component's main NetCDF file.
POINT_MAIN_FILE = 'point_comp.nc'
# Values of a point's 'point_type' attribute; new points default to the first entry.
POINT_TYPES = [
    'Attribute Modification',
    'Wave Gage',
    'Tidal Currents',
]


class PointsData(XarrayBase):
    """Manages data file for the hidden points component.

    The NetCDF file holds an 'info' group, a 'points' group with one entry per point along the 'comp_id'
    coordinate, and one group per wave table ('wave_table/<id>') and per tide table ('tide_table/<id>').
    """

    # Schema of the 'points' dataset as (variable name, dtype, value given to a newly added point). Every variable is
    # indexed by the 'comp_id' coordinate. The order here is the variable order of the datasets this class builds.
    _POINT_ATTS = (
        ('set_k1', int, 0),
        ('k1_val', float, 0.4),
        ('set_k2', int, 0),
        ('k2_val', float, 0.25),
        ('set_ang_amplify', int, 0),
        ('ang_amplify_val', float, 0.0),
        ('set_ang_adjust', int, 0),
        ('ang_adjust_val', float, 0.0),
        ('set_hgt_amplify', int, 0),
        ('hgt_amplify_val', float, 1.0),
        ('set_berm_height', int, 0),
        ('berm_height_val', float, 1.0),
        ('set_depth_closure', int, 0),
        ('depth_closure_val', float, 10.0),
        ('tide_table', int, 0),
        ('wave_table', int, 0),
        ('wave_depth', float, 0.0),
        ('point_type', object, 'Attribute Modification'),
    )

    def __init__(self, data_file):
        """Initializes the data class.

        Args:
            data_file (:obj:`str`): The netcdf file (with path) associated with this instance data. Probably the owning
                component's main file.
        """
        self._filename = data_file
        self._info = None
        self._points = None
        self._wave_table = {}  # table id -> xarray.Dataset (None means "remove the group" on commit)
        self._tide_table = {}  # table id -> xarray.Dataset (None means "remove the group" on commit)
        # Set of deleted component ids in the coverage, set by callers before vacuuming. Not read anywhere in this
        # class; presumably consumed by XarrayBase or callers - confirm.
        self.deleted_comp_ids = set()
        # Create the default file before calling super because we have our own attributes to write.
        created = self._get_default_datasets(data_file)
        super().__init__(data_file)
        if created:
            # A brand new file starts out with one point holding the default attributes.
            self.add_point_atts()
            self.commit()

    def _get_default_datasets(self, data_file):
        """Create default datasets if needed.

        Args:
            data_file (:obj:`str`): Name of the data file. If it doesn't exist, it will be created.

        Returns:
            (:obj:`bool`): Returns True if datasets were created.
        """
        # isfile() is False both for a missing path and for a path that is not a regular file.
        if os.path.isfile(data_file):
            return False

        info = {
            'FILE_TYPE': 'GENCADE_POINTS',
            # 'VERSION': pkg_resources.get_distribution('xmsgencade').version,
            'cov_uuid': '',
            'next_comp_id': 0,
            'point_display_uuid': '',
        }
        self._info = xr.Dataset(attrs=info)

        # Empty (zero point) dataset with the full points schema.
        point_table = {name: ('comp_id', np.array([], dtype=dtype)) for name, dtype, _ in self._POINT_ATTS}
        coords = {'comp_id': np.array([], dtype=int)}
        self._points = xr.Dataset(data_vars=point_table, coords=coords)

        self.commit()
        return True

    def update_point(self, comp_id, new_atts):
        """Update the point attributes of a point.

        Every variable of the points schema is copied from ``new_atts``, so it must provide all of them.

        Args:
            comp_id (:obj:`int`): Component id of the point to update
            new_atts (:obj:`xarray.Dataset`): The new attributes for the point
        """
        selection = dict(comp_id=[comp_id])
        for name, _, _ in self._POINT_ATTS:
            self.points[name].loc[selection] = new_atts[name]

    @property
    def points(self):
        """Load the points dataset from disk.

        Returns:
            (:obj:`xarray.Dataset`): Dataset interface to the points datasets in the main file
        """
        if self._points is None:
            self._points = self.get_dataset('points', False)
        return self._points

    @points.setter
    def points(self, dset):
        """Setter for the points attribute. Assigning None is ignored."""
        if dset is not None:
            self._points = dset

    def add_point_atts(self, dset=None):
        """Add the point attribute dataset for a point.

        Args:
            dset (:obj:`xarray.Dataset`): The attribute dataset to concatenate. If not provided, a new Dataset
                of default attributes will be generated. It is not modified; if it holds several entries along
                'comp_id', they all receive the same new component id.

        Returns:
            (:obj:`int`): The newly generated component id, or UNINITIALIZED_COMP_ID if the point could not be added.
        """
        try:
            # int() accepts both the plain Python int written by _get_default_datasets and a numpy scalar read back
            # from the NetCDF file; numpy's .item() only works for the latter.
            new_comp_id = int(self.info.attrs['next_comp_id'])
            if dset is None:  # Generate a new default Dataset
                dset = self._get_new_point_atts(new_comp_id)
            else:  # Renumber a copy of the given Dataset; assign_coords leaves the caller's Dataset untouched.
                dset = dset.assign_coords(comp_id=[new_comp_id] * dset.sizes['comp_id'])
            self._points = xr.concat([self.points, dset], 'comp_id')
        except Exception:
            # Best effort: callers check for the sentinel instead of handling an exception.
            return UNINITIALIZED_COMP_ID
        # Only consume the id once the point was actually added.
        self.info.attrs['next_comp_id'] += 1  # Increment the unique XMS component id.
        return new_comp_id

    @staticmethod
    def _get_new_point_atts(comp_id):
        """Get a new dataset with default attributes for a point.

        Args:
            comp_id (:obj:`int`): The unique XMS component id to assign to the new point.

        Returns:
            (:obj:`xarray.Dataset`): A new single-point dataset of default attributes. Can later be concatenated to
            the persistent dataset.
        """
        point_table = {
            name: ('comp_id', np.array([default], dtype=dtype)) for name, dtype, default in PointsData._POINT_ATTS
        }
        return xr.Dataset(data_vars=point_table, coords={'comp_id': [comp_id]})

    def concat_points(self, pts_data):
        """Adds the points attributes from pts_data to this instance of PointsData.

        The added points get fresh component ids from this instance's 'next_comp_id'; ``pts_data`` itself is not
        modified.

        NOTE(review): only the points dataset is copied. Wave/tide tables referenced by the added points are not
        copied from ``pts_data`` - confirm callers handle that.

        Args:
            pts_data (:obj:`PointsData`): another PointsData instance

        Returns:
            (:obj:`dict`): The old ids of the pts_data as key and the new ids as the data
        """
        next_comp_id = self.info.attrs['next_comp_id']
        other_points = pts_data.points
        num_concat_points = other_points.sizes['comp_id']
        if not num_concat_points:
            return {}

        old_comp_ids = other_points.coords['comp_id'].data.astype('i4').tolist()
        new_comp_ids = [int(next_comp_id) + idx for idx in range(num_concat_points)]
        # assign_coords returns a new Dataset, so pts_data keeps its own component ids.
        new_points = other_points.assign_coords(comp_id=new_comp_ids)
        self._points = xr.concat([self.points, new_points], 'comp_id')
        self.info.attrs['next_comp_id'] = next_comp_id + num_concat_points
        return dict(zip(old_comp_ids, new_comp_ids))

    def commit(self):
        """Save current in-memory component parameters to data file."""
        super().commit()  # Recreates the NetCDF file if vacuuming
        if self._points is not None:
            self._points.close()
            self._drop_h5_groups(['points'])
            self._points.to_netcdf(self._filename, group='points', mode='a')
        self._commit_tables(self._wave_table, self._wave_table_group_name)
        self._commit_tables(self._tide_table, self._tide_table_group_name)

    def _commit_tables(self, tables, group_name_method):
        """Rewrite the group of every table held in memory; a None entry only removes the group.

        Args:
            tables (:obj:`dict`): Table id -> xarray.Dataset (or None).
            group_name_method (:obj:`method`): Maps a table id to its h5 group name.
        """
        for table_id, data in tables.items():
            grp = group_name_method(table_id)
            self._drop_h5_groups([grp])
            if data is not None:
                data.to_netcdf(self._filename, group=grp, mode='a')

    def vacuum(self):
        """Rewrite all data to a new/wiped file to reclaim disk space.

        The info and points datasets, plus every wave/tide table referenced by a point, are loaded into memory
        before the existing file is deleted. Tables no point references are not reloaded, so they are only
        rewritten if they are already held in memory.
        """
        if self._info is None:
            self._info = self.get_dataset('info', False)
        if self._points is None:
            self._points = self.get_dataset('points', False)
        self._load_referenced_tables('wave_table', self._wave_table, self._wave_table_group_name)
        self._load_referenced_tables('tide_table', self._tide_table, self._tide_table_group_name)
        io_util.removefile(self._filename)  # Delete the existing NetCDF file
        self.commit()  # Rewrite all datasets

    def _load_referenced_tables(self, column, tables, group_name_method):
        """Load from the file every table referenced by the points that is not already in memory.

        Args:
            column (:obj:`str`): Name of the points variable holding table ids ('wave_table' or 'tide_table').
            tables (:obj:`dict`): Table id -> xarray.Dataset cache to fill in.
            group_name_method (:obj:`method`): Maps a table id to its h5 group name.
        """
        for table_id in self._points[column].values:
            table_id = int(table_id)
            # Ids <= 0 are skipped (presumably "no table assigned"; new points default to 0).
            if table_id > 0 and tables.get(table_id) is None:
                tables[table_id] = self.get_dataset(group_name_method(table_id), False)

    def wave_table_from_id(self, table_id, create_default=True):
        """Gets the wave table from the table id.

        Args:
            table_id (:obj:`int`): table id
            create_default (:obj:`bool`): True if a default table should be created if none is found.

        Returns:
            (:obj:`xarray.Dataset`): The wave table dataset
        """
        return self._curve_from_dictionary(table_id, self._wave_table, self.default_wave_table,
                                           self._wave_table_group_name, create_default)

    def tide_table_from_id(self, table_id, create_default=True):
        """Gets the tide table from the table id.

        Args:
            table_id (:obj:`int`): table id
            create_default (:obj:`bool`): True if a default table should be created if none is found.

        Returns:
            (:obj:`xarray.Dataset`): The tide table dataset
        """
        return self._curve_from_dictionary(table_id, self._tide_table, self.default_tide_table,
                                           self._tide_table_group_name, create_default)

    def set_wave_table(self, table_id, table):
        """Set the wave table.

        Args:
            table_id (:obj:`int`): The id of the table.
            table (:obj:`xarray.Dataset`): The dataset representing the table.
        """
        self._wave_table[table_id] = table

    def set_tide_table(self, table_id, table):
        """Set the tide table.

        Args:
            table_id (:obj:`int`): The id of the table.
            table (:obj:`xarray.Dataset`): The dataset representing the table.
        """
        self._tide_table[table_id] = table

    @staticmethod
    def _tide_table_group_name(table_id):
        """Gets the h5 group where the table is stored.

        Args:
            table_id (:obj:`int`): table id

        Returns:
            (:obj:`str`): The h5 group name
        """
        return f'tide_table/{table_id}'

    @staticmethod
    def _wave_table_group_name(table_id):
        """Gets the h5 group where the table is stored.

        Args:
            table_id (:obj:`int`): table id

        Returns:
            (:obj:`str`): The h5 group name
        """
        return f'wave_table/{table_id}'

    @staticmethod
    def default_tide_table():
        """Creates a xarray.Dataset for a tide table.

        Returns:
            (:obj:`xarray.Dataset`): A one-row table with 'tide_date' and 'current_speed' variables.
        """
        default_data = {
            'tide_date': [datetime.datetime(2000, 1, 1)],
            'current_speed': [0.0],
        }
        return pd.DataFrame(default_data).to_xarray()

    @staticmethod
    def default_wave_table():
        """Creates a xarray.Dataset for a wave table.

        Returns:
            (:obj:`xarray.Dataset`): A one-row table with 'wave_date', 'wave_height', 'wave_period' and
            'wave_direction' variables.
        """
        default_data = {
            'wave_date': [datetime.datetime(2000, 1, 1)],
            'wave_height': [0.0],
            'wave_period': [1.0],
            'wave_direction': [0.0],
        }
        return pd.DataFrame(default_data).to_xarray()

    def _curve_from_dictionary(self, curve_id, curve_dict, default_curve_method, curve_group_name_method,
                               create_default=True):
        """Gets the curve from the curve id.

        A curve not yet in ``curve_dict`` is loaded from the file and cached, even when it is missing (cached as
        None). A None curve is replaced with a default one whenever ``create_default`` is True, including a None
        cached by an earlier ``create_default=False`` lookup.

        Args:
            curve_id (:obj:`int`): curve id
            curve_dict (:obj:`dict`): A dictionary of curve ids to xr.Dataset
            default_curve_method (:obj:`method`): The method to call to get a default curve.
            curve_group_name_method (:obj:`method`): The method to call to get the name of the curve group.
            create_default (:obj:`bool`): True if a default curve should be created if none is found.

        Returns:
            (:obj:`xarray.Dataset`): The curve dataset, or None if none exists and create_default is False.
        """
        curve_id = int(curve_id)
        if curve_id not in curve_dict:
            # load from the file if the curve exists
            dset = self.get_dataset(curve_group_name_method(curve_id), False)
            if dset is not None:
                # Variable names stored with '%slash%' are restored to contain '/'.
                # NOTE(review): commit() does not do the reverse encoding - confirm where '/' names are encoded.
                rename_dict = {col: col.replace('%slash%', '/') for col in dset.keys() if '%slash%' in col}
                if rename_dict:
                    dset = dset.rename(rename_dict)
            curve_dict[curve_id] = dset
        if curve_dict[curve_id] is None and create_default:
            # create a default curve
            curve_dict[curve_id] = default_curve_method()
        return curve_dict[curve_id]
