"""Xarray data class for the EWN Feature coverage component."""
__copyright__ = "(C) Copyright Aquaveo 2020"
__license__ = "All rights reserved"

# 1. Standard Python modules
import os

# 2. Third party modules
import numpy as np
import pandas as pd
import pkg_resources
import xarray as xr

# 3. Aquaveo modules
from xms.components.bases.xarray_base import XarrayBase
from xms.core.filesystem import filesystem as xmf

# 4. Local modules
from xms.ewn.data import ewn_cov_data_consts as ewn_consts


class EwnCovData(XarrayBase):
    """Class for storing the EWN coverage properties.

    Polygon and arc feature attributes are kept in two xarray Datasets ('polygons' and 'arcs'), each indexed by a
    'comp_id' coordinate. Each Dataset is read from the file on first access of its property, and commit() only
    writes back the Datasets that have been loaded into memory.
    """
    def __init__(self, filename):
        """Construct the data class.

        Args:
            filename (:obj:`str`): Path of the file backing this data (handled by XarrayBase)
        """
        super().__init__(filename)
        self._polygons = None  # polygon data set, loaded on first access of self.polygons
        self._arcs = None  # arc data set, loaded on first access of self.arcs
        # Version of the installed xmsewn package; stamped into the info attrs on creation and on commit.
        self._cur_version = pkg_resources.get_distribution('xmsewn').version
        self._ensure_info_exists()

    def _ensure_info_exists(self):
        """Make sure all the info Dataset attrs are initialized."""
        self.info.attrs['FILE_TYPE'] = 'EWN_COVERAGE_DATA'
        if 'cov_uuid' not in self.info.attrs:
            self.info.attrs['cov_uuid'] = ''  # gets set later
        if 'display_uuid' not in self.info.attrs:
            self.info.attrs['display_uuid'] = ''
        if 'VERSION' not in self.info.attrs:
            self.info.attrs['VERSION'] = self._cur_version

    @property
    def meta_data(self):
        """Get ewn meta data from the 'EWN Table.csv' file in the package resources folder.

        The CSV is re-read on every access, so callers always receive a fresh DataFrame they may modify.

        Returns:
            (:obj:`pandas.DataFrame`): the data
        """
        file_name = os.path.join(os.path.dirname(__file__), 'resources', 'EWN Table.csv')
        return pd.read_csv(file_name, float_precision='round_trip')

    @property
    def polygons(self):
        """Get the polygons data parameters.

        The Dataset is read from the file on first access. If the file has no 'polygons' group, a new Dataset
        holding one default record is created.

        Returns:
            (:obj:`xarray.Dataset`): The polygons dataset

        """
        if self._polygons is None:
            self._polygons = self.get_dataset('polygons', False)
            if self._polygons is None:
                self._polygons = self._new_polygon_dataset()
        return self._polygons

    @polygons.setter
    def polygons(self, dset):
        """Set the polygons data parameters.

        Args:
            dset (:obj:`xarray.Dataset`): The new polygons Dataset. None is ignored.

        """
        if dset is not None:
            self._polygons = dset

    @property
    def arcs(self):
        """Get the arcs data parameters.

        The Dataset is read from the file on first access. If the file has no 'arcs' group, a new Dataset
        holding one default record is created.

        Returns:
            (:obj:`xarray.Dataset`): The arcs dataset

        """
        if self._arcs is None:
            self._arcs = self.get_dataset('arcs', False)
            if self._arcs is None:
                self._arcs = self._new_arc_dataset()
        return self._arcs

    @arcs.setter
    def arcs(self, dset):
        """Set the arcs data parameters.

        Args:
            dset (:obj:`xarray.Dataset`): The new arcs Dataset. None is ignored.

        """
        if dset is not None:
            self._arcs = dset

    @property
    def has_polys(self):
        """Check if the polygons dataset is loaded into memory.

        Returns:
            (:obj:`bool`): True if the polygons dataset is loaded into memory
        """
        return self._polygons is not None

    @property
    def has_arcs(self):
        """Check if the arcs dataset is loaded into memory.

        Returns:
            (:obj:`bool`): True if the arcs dataset is loaded into memory
        """
        return self._arcs is not None

    @staticmethod
    def _next_comp_id(dset):
        """Compute the next unused component id for a feature Dataset.

        Args:
            dset (:obj:`xarray.Dataset`): A feature Dataset with a 'comp_id' coordinate, or None

        Returns:
            (:obj:`int`): One more than the largest existing component id, or 0 if dset is None or has no records
        """
        if dset is None or len(dset.comp_id) == 0:
            return 0
        return int(dset.comp_id.values.max()) + 1

    @staticmethod
    def _select_rows(dset, mask):
        """Select the records of a feature Dataset where mask is True.

        Positional selection is used instead of ``Dataset.where(..., drop=True)`` because ``where`` promotes
        integer variables to float and string variables to object, even for the records it keeps.

        Args:
            dset (:obj:`xarray.Dataset`): A feature Dataset indexed by 'comp_id'
            mask (:obj:`xarray.DataArray`): Boolean array along 'comp_id'

        Returns:
            (:obj:`xarray.Dataset`): The selected records, keeping the 'comp_id' dimension and all variable dtypes
        """
        return dset.isel(comp_id=np.flatnonzero(np.asarray(mask)))

    def _new_polygon_dataset(self):
        """Creates a new polygon feature Dataset holding a single record of default values.

        The record's component id is one past the largest id in the loaded polygons Dataset (0 if none is loaded
        or it is empty).

        Returns:
            (:obj:`xarray.Dataset`): A one-record polygon feature dataset with default attributes
        """
        # Construct the Dataset
        coords = {'comp_id': np.array([self._next_comp_id(self._polygons)], dtype=np.int32)}
        poly_table = {
            'polygon_name': ('comp_id', np.array([''], dtype=np.str_)),
            'classification': ('comp_id', np.array([ewn_consts.FEATURE_TYPE_UNASSIGNED], dtype=np.int32)),
            'insert_feature': ('comp_id', np.array([1], dtype=np.int32)),
            'elevation_method': ('comp_id', np.array([ewn_consts.ELEVATION_METHOD_CONSTANT], dtype=np.int32)),
            'elevation': ('comp_id', np.array([0.0], dtype=np.float64)),
            'specify_slope': ('comp_id', np.array([0], dtype=np.int32)),
            'slope': ('comp_id', np.array([0.0], dtype=np.float64)),
            'maximum_slope_distance': ('comp_id', np.array([0.0], dtype=np.float64)),
            'manning_n': ('comp_id', np.array([0.0], dtype=np.float64)),
            'transition_method': ('comp_id', np.array([ewn_consts.TRANSITION_METHOD_FACTOR], dtype=np.int32)),
            'transition_distance': ('comp_id', np.array([0.5], dtype=np.float64)),
            'quadtree_refinement_length': ('comp_id', np.array([10], dtype=np.float64)),
        }
        return xr.Dataset(data_vars=poly_table, coords=coords)

    def _new_arc_dataset(self):
        """Creates a new arc feature Dataset holding a single record of default values.

        The record's component id is one past the largest id in the loaded arcs Dataset (0 if none is loaded or
        it is empty).

        Returns:
            (:obj:`xarray.Dataset`): A one-record arc feature dataset with default attributes
        """
        # Construct the Dataset
        coords = {'comp_id': np.array([self._next_comp_id(self._arcs)], dtype=np.int32)}
        arc_table = {
            'arc_name': ('comp_id', np.array([''], dtype=np.str_)),
            'insert_feature': ('comp_id', np.array([0], dtype=np.int32)),
            'elevation': ('comp_id', np.array([0.0], dtype=np.float64)),
            'use_slope': ('comp_id', np.array([0], dtype=np.int32)),
            'side_slope': ('comp_id', np.array([0.5], dtype=np.float64)),
            'maximum_slope_distance': ('comp_id', np.array([0.5], dtype=np.float64)),
            'top_width': ('comp_id', np.array([2.0], dtype=np.float64)),
            'transition_method': ('comp_id', np.array([ewn_consts.TRANSITION_METHOD_FACTOR], dtype=np.int32)),
            'transition_distance': ('comp_id', np.array([0.5], dtype=np.float64)),
        }
        return xr.Dataset(data_vars=arc_table, coords=coords)

    def get_poly_atts(self, comp_id):
        """Gets the polygon record for a component id, or a default record if the id is not known.

        Args:
            comp_id (:obj:`int`): The polygon's component id

        Returns:
            (:obj:`xarray.Dataset`): The matching record(s) from the polygons dataset, or a new default record if
            comp_id is uninitialized or not present
        """
        if comp_id == ewn_consts.UNINITIALIZED_COMP_ID:
            return self._new_polygon_dataset()  # Polygon's component id has not been set yet, return default data.

        polygons = self.polygons
        poly_data = self._select_rows(polygons, polygons.comp_id == comp_id)
        if len(poly_data.comp_id) == 0:  # Polygon's component id not found, return default data.
            return self._new_polygon_dataset()

        # Polygon's component id found in Dataset.
        return poly_data

    def get_arc_atts(self, comp_id):
        """Gets the arc record for a component id, or a default record if the id is not known.

        Args:
            comp_id (:obj:`int`): The arc's component id

        Returns:
            (:obj:`xarray.Dataset`): The matching record(s) from the arcs dataset, or a new default record if
            comp_id is uninitialized or not present
        """
        if comp_id == ewn_consts.UNINITIALIZED_COMP_ID:
            return self._new_arc_dataset()  # Arc's component id has not been set yet, return default data.

        arcs = self.arcs
        arc_data = self._select_rows(arcs, arcs.comp_id == comp_id)
        if len(arc_data.comp_id) == 0:  # Arc's component id not found, return default data.
            return self._new_arc_dataset()

        # Arc's component id found in Dataset.
        return arc_data

    def add_polygon(self, poly_data):
        """Append the polygon feature attributes for a new polygon to the dataset.

        Args:
            poly_data (:obj:`xarray.Dataset`): Single-record polygon attribute dataset. Its comp_id is replaced
                with a newly generated id.

        Returns:
            (:obj:`int`): The newly generated component id, or UNINITIALIZED_COMP_ID if appending failed
        """
        try:
            new_comp_id = self._next_comp_id(self.polygons)
            # Reset the comp_id coord for this data so we can append it to our dataset.
            poly_data = poly_data.assign_coords({'comp_id': np.array([new_comp_id], dtype=np.int32)})
            self._polygons = xr.concat([self.polygons, poly_data], 'comp_id')
            return new_comp_id
        except Exception:  # pragma no cover
            # Best effort: report failure through the sentinel id instead of raising to the caller.
            return ewn_consts.UNINITIALIZED_COMP_ID

    def add_arc(self, arc_data):
        """Append the arc feature attributes for a new arc to the dataset.

        Args:
            arc_data (:obj:`xarray.Dataset`): Single-record arc attribute dataset. Its comp_id is replaced with a
                newly generated id.

        Returns:
            (:obj:`int`): The newly generated component id, or UNINITIALIZED_COMP_ID if appending failed
        """
        try:
            new_comp_id = self._next_comp_id(self.arcs)
            # Reset the comp_id coord for this data so we can append it to our dataset.
            arc_data = arc_data.assign_coords({'comp_id': np.array([new_comp_id], dtype=np.int32)})
            self._arcs = xr.concat([self.arcs, arc_data], 'comp_id')
            return new_comp_id
        except Exception:  # pragma no cover
            # Best effort: report failure through the sentinel id instead of raising to the caller.
            return ewn_consts.UNINITIALIZED_COMP_ID

    def remove_poly_comp_ids(self, delete_comp_ids):
        """Removes polygon records whose component ids are no longer used.

        Args:
            delete_comp_ids (:obj:`iterable`): Iterable (list, set, numpy array, generator, ...) of the component
                ids to delete

        """
        # Materialize first: truth-testing a numpy array raises, and an empty generator would be truthy.
        delete_ids = list(delete_comp_ids)
        if not delete_ids:
            return
        # Drop unused component ids from the datasets, keeping the variable dtypes intact.
        polygons = self.polygons
        keep = ~polygons.comp_id.isin(delete_ids)
        self.polygons = self._select_rows(polygons, keep)

    def remove_arc_comp_ids(self, delete_comp_ids):
        """Removes arc records whose component ids are no longer used.

        Args:
            delete_comp_ids (:obj:`iterable`): Iterable (list, set, numpy array, generator, ...) of the component
                ids to delete

        """
        # Materialize first: truth-testing a numpy array raises, and an empty generator would be truthy.
        delete_ids = list(delete_comp_ids)
        if not delete_ids:
            return
        # Drop unused component ids from the datasets, keeping the variable dtypes intact.
        arcs = self.arcs
        keep = ~arcs.comp_id.isin(delete_ids)
        self.arcs = self._select_rows(arcs, keep)

    def close(self):
        """Closes the H5 file and does not write any data that is in memory."""
        super().close()
        if self._polygons is not None:
            self._polygons.close()
        if self._arcs is not None:
            self._arcs.close()

    def commit(self):
        """Save in memory datasets to the NetCDF file.

        Only the polygon/arc Datasets that have been loaded into memory are rewritten; each group is dropped from
        the file before the in-memory Dataset is appended in its place.
        """
        # Store the current package version.
        self.info.attrs['VERSION'] = self._cur_version
        self.close()
        super().commit()
        if self._polygons is not None:  # Only write polygon Dataset to disk if it has been loaded into memory.
            self._polygons.close()
            self._drop_h5_groups(['polygons'])
            self._polygons.to_netcdf(path=self._filename, group='polygons', mode='a')
        if self._arcs is not None:  # Only write arc Dataset to disk if it has been loaded into memory.
            self._arcs.close()
            self._drop_h5_groups(['arcs'])
            self._arcs.to_netcdf(path=self._filename, group='arcs', mode='a')

    def vacuum(self):
        """Rewrite all coverage data to a new/wiped file to reclaim disk space.

        The info, polygons, and arcs datasets are all loaded into memory before the existing file is deleted,
        because commit() only writes datasets that are loaded.
        """
        _ = self.info  # Ensure all data is loaded into memory.
        _ = self.polygons
        _ = self.arcs  # Without this, arc data that was never accessed would be lost when the file is deleted.
        xmf.removefile(self._filename)  # Delete the existing NetCDF file.
        self.commit()  # Rewrite all datasets.
