"""This module defines data for the GenCade Structure hidden component."""

# 1. Standard Python modules
import datetime
import os

# 2. Third party modules
import numpy as np
import pandas as pd
import xarray as xr

# 3. Aquaveo modules
from xms.components.bases.xarray_base import XarrayBase

# 4. Local modules
from xms.gencade.components.id_files import UNINITIALIZED_COMP_ID

# Names of the available structure types, in display order.
STRUCT_TYPES = [
    'Generic',
    'Initial Shoreline',
    'Regional Contour',
    'Breakwater',
    'Seawall',
    'Groin',
    'Inlet',
    'Left Jetty on Inlet',
    'Right Jetty on Inlet',
    'Bypass Event',
    'Beach Fill Event',
    'Attachment Bar',
    'SBAS Polygon',
    'SBAS Flux',
    'Reference Line',
]

# Names of the available breakwater transmission methods.
TRANSMISSION_TYPES = [
    "Constant",
    "Ahren's",
    "Seabrook & Hall",
    "d'Angremond",
]

# Names of the shoal choices; the first entry is the "nothing chosen" placeholder.
SHOAL_TYPES = [
    '(none selected)',
    'Left Attachment',
    'Left Bypass',
    'Ebb',
    'Flood',
    'Right Bypass',
    'Right Attachment',
]


class StructData(XarrayBase):
    """Manages the data file for the hidden GenCade structure arcs component.

    The file holds an ``info`` dataset of component attributes (e.g. ``next_comp_id``), an ``arcs`` dataset
    (h5 group ``arcs``) with one row of structure attributes per component id, and optional dredging, bypass,
    and beach fill tables, each stored in its own h5 group.
    """

    # Per-arc attribute schema: (variable name, numpy dtype, default value for a newly added arc).
    # The order here is the order of the variables in the arcs dataset.
    _ARC_SCHEMA = (
        ('struct_type_cbx', object, 'Generic'),
        ('depth1', float, 0.1),
        ('depth2', float, 0.1),
        ('transmission_type_cbx', object, 'Constant'),
        ('transmission_const', float, 0.0),
        ('height', float, 0.0),
        ('width', float, 0.0),
        ('seaward_slope', float, 0.0),
        ('shoreward_slope', float, 0.0),
        ('bw_permeability', float, 0.0),
        ('armor_d50', float, 0.0),
        ('groin_permeability', float, 0.0),
        ('diffracting_chk', int, 0),
        ('seaward_depth', float, 1.0),
        ('name', object, '(none selected)'),
        ('ebb_shoal_init', float, 0.0),
        ('ebb_shoal_equil', float, 0.0),
        ('flood_shoal_init', float, 0.0),
        ('flood_shoal_equil', float, 0.0),
        ('left_bypass_init', float, 0.0),
        ('left_bypass_equil', float, 0.0),
        ('left_bypass_coeff', float, 0.0),
        ('left_attach_init', float, 0.0),
        ('left_attach_equil', float, 0.0),
        ('right_bypass_init', float, 0.0),
        ('right_bypass_equil', float, 0.0),
        ('right_bypass_coeff', float, 0.0),
        ('right_attach_init', float, 0.0),
        ('right_attach_equil', float, 0.0),
        ('use_dredging_table', int, 0),
        ('use_bypass_table', int, 0),
        ('use_beach_fill_table', int, 0),
        ('dredging_table', int, 0),
        ('bypass_table', int, 0),
        ('beach_fill_table', int, 0),
    )

    # Variables whose incoming values update_struct_arc() converts with int() before storing.
    _INT_CAST_ARC_VARIABLES = frozenset({'diffracting_chk', 'dredging_table', 'bypass_table', 'beach_fill_table'})

    # Placeholder written in place of '/' in dataset variable names so they can be stored in h5.
    _SLASH_TOKEN = '%slash%'

    def __init__(self, data_file):
        """Initializes the data class.

        Args:
            data_file (:obj:`str`): The netcdf file (with path) associated with this instance data. Probably the owning
                component's main file.
        """
        self._filename = data_file
        self._info = None
        self._arcs = None
        # Tables loaded or set in memory, keyed by integer table id. A value of None means "no table".
        self._dredging_tables = dict()
        self._bypass_tables = dict()
        self._beach_fill_tables = dict()
        self.deleted_comp_ids = set()  # Set this to be a list of deleted component ids in the coverage before vacuuming
        # Create the default file before calling super because we have our own attributes to write.
        created = self._get_default_datasets(data_file)
        super().__init__(data_file)
        if created:
            # A brand-new file gets one default arc attribute row.
            self.add_struct_arc_atts()
            self.commit()

    def _get_default_datasets(self, data_file):
        """Create default datasets if needed.

        Args:
            data_file (:obj:`str`): Name of the data file. If it is not an existing regular file, it will be created.

        Returns:
            (:obj:`bool`): Returns True if datasets were created.
        """
        # os.path.isfile() is False for missing paths, so this also covers "does not exist".
        if os.path.isfile(data_file):
            return False

        info = {
            'FILE_TYPE': 'GENCADE_STRUCT',
            # 'VERSION': pkg_resources.get_distribution('gencade').version,
            'cov_uuid': '',
            'arc_display_uuid': '',
            'next_comp_id': 0,
            'proj_dir': os.path.dirname(os.environ.get('XMS_PYTHON_APP_PROJECT_PATH', '')),
        }
        self._info = xr.Dataset(attrs=info)

        # One empty variable per schema entry, all indexed by component id.
        arc_table = {name: ('comp_id', np.array([], dtype=dtype)) for name, dtype, _ in self._ARC_SCHEMA}
        coords = {
            'comp_id': np.array([], dtype=int)
        }
        self._arcs = xr.Dataset(data_vars=arc_table, coords=coords)
        self.commit()
        return True

    @staticmethod
    def _clean_dataset_keys(dataset):
        """Removes '/' from the dataset keys for safe storage in h5.

        Args:
            dataset (:obj:`xarray.Dataset`): The dataset that may have bad keys.

        Returns:
            (:obj:`xarray.Dataset`): A new xr.Dataset with sanitized keys, or the input dataset if no key contains '/'.
        """
        rename_dict = {
            col: col.replace('/', StructData._SLASH_TOKEN) for col in dataset.keys() if '/' in col
        }
        return dataset.rename(rename_dict) if rename_dict else dataset

    @staticmethod
    def _restore_dataset_keys(dataset):
        """Reverses _clean_dataset_keys by turning slash placeholders in the dataset keys back into '/'.

        Args:
            dataset (:obj:`xarray.Dataset`): A dataset as read from the file.

        Returns:
            (:obj:`xarray.Dataset`): A new xr.Dataset with restored keys, or the input dataset if nothing needed
                restoring.
        """
        token = StructData._SLASH_TOKEN
        rename_dict = {col: col.replace(token, '/') for col in dataset.keys() if token in col}
        return dataset.rename(rename_dict) if rename_dict else dataset

    def commit(self):
        """Save current in-memory component parameters to data file.

        The arcs dataset and every table held in memory are written to their h5 groups, replacing any
        previously stored copy.
        """
        super().commit()  # Recreates the NetCDF file if vacuuming
        if self._arcs is not None:
            self._arcs.close()
            self._drop_h5_groups(['arcs'])
            self._arcs.to_netcdf(self._filename, group='arcs', mode='a')
        self._write_tables(self._dredging_tables, self.dredging_table_group_name)
        self._write_tables(self._bypass_tables, self.bypass_table_group_name)
        self._write_tables(self._beach_fill_tables, self.beach_fill_table_group_name)

    def _write_tables(self, tables, group_name_method):
        """Write each in-memory table to its h5 group in the data file.

        Args:
            tables (:obj:`dict`): Table id -> xr.Dataset (or None).
            group_name_method (:obj:`method`): Maps a table id to its h5 group name.
        """
        for table_id, data in tables.items():
            grp = group_name_method(table_id)
            # The group is dropped unconditionally, so a None entry leaves no table stored for that id.
            self._drop_h5_groups([grp])
            if data is not None:
                # Sanitize '/' in variable names; _curve_from_dictionary restores them when the table is read back.
                self._clean_dataset_keys(data).to_netcdf(self._filename, group=grp, mode='a')

    def vacuum(self):
        """Rewrite all StructData to a new/wiped file to reclaim disk space.

        All Struct datasets that need to be written to the file must be loaded into memory before calling this method.
        The info and arcs datasets, and every table referenced by an arc (id > 0), are loaded here if needed.
        """
        if self._info is None:
            self._info = self.get_dataset('info', False)
        if self._arcs is None:
            self._arcs = self.get_dataset('arcs', False)
        self._load_referenced_tables('dredging_table', self._dredging_tables, self.dredging_table_group_name)
        self._load_referenced_tables('bypass_table', self._bypass_tables, self.bypass_table_group_name)
        self._load_referenced_tables('beach_fill_table', self._beach_fill_tables, self.beach_fill_table_group_name)
        try:
            os.remove(self._filename)
        except OSError:
            # Best effort: if the file cannot be removed, still rewrite everything below.
            pass
        self.commit()  # Rewrite all datasets

    def _load_referenced_tables(self, table_var, tables, group_name_method):
        """Load every table referenced by the arcs dataset that is not already held in memory.

        Args:
            table_var (:obj:`str`): Name of the arcs variable holding table ids (ids <= 0 mean "no table").
            tables (:obj:`dict`): Table id -> xr.Dataset (or None); updated in place.
            group_name_method (:obj:`method`): Maps a table id to its h5 group name.
        """
        for table in self._arcs[table_var]:
            table = int(table)
            if table > 0 and tables.get(table) is None:
                tables[table] = self.get_dataset(group_name_method(table), False)

    @property
    def arcs(self):
        """Load the arcs dataset from disk.

        Returns:
            (:obj:`xarray.Dataset`): Dataset interface to the arcs datasets in the main file
        """
        if self._arcs is None:
            self._arcs = self.get_dataset('arcs', False)
        return self._arcs

    @arcs.setter
    def arcs(self, dset):
        """Setter for the arcs attribute. Passing None keeps the current dataset."""
        # Explicit None check: the truthiness of an xarray.Dataset reflects whether it has data variables,
        # not whether a dataset was supplied.
        if dset is not None:
            self._arcs = dset

    def update_struct_arc(self, comp_id, new_atts):
        """Update the Structure arc attributes of a Structure arc.

        Args:
            comp_id (:obj:`int`): Component id of the Structure arc to update
            new_atts (:obj:`xarray.Dataset`): The new attributes for the Structure arc. Must contain every arc
                variable in the schema.
        """
        arcs = self.arcs
        for name, _dtype, _default in self._ARC_SCHEMA:
            value = new_atts[name]
            if name in self._INT_CAST_ARC_VARIABLES:
                value = int(value)
            arcs[name].loc[dict(comp_id=[comp_id])] = value

        # changed to deal with strings with variable length
        df = arcs.to_dataframe()
        self.arcs = df.to_xarray()

    def add_struct_arc_atts(self, dset=None):
        """Add the structure arc attribute dataset for an arc.

        Args:
            dset (:obj:`xarray.Dataset`): The attribute dataset to concatenate. If not provided, a new Dataset of
                default attributes will be generated. If provided, all of its rows are assigned the new component id.

        Returns:
            (:obj:`int`): The newly generated component id, or UNINITIALIZED_COMP_ID if the arc could not be added.
        """
        try:
            # int() accepts both a plain Python int (freshly created info) and a numpy scalar (info read from disk).
            new_comp_id = int(self.info.attrs['next_comp_id'])
            self.info.attrs['next_comp_id'] += 1  # Increment the unique XMS component id.
            if dset is None:  # Generate a new default Dataset
                dset = self._get_new_arc_atts(new_comp_id)
            else:  # Update the component id of an existing Dataset
                dset.coords['comp_id'] = [new_comp_id for _ in dset.coords['comp_id']]
            self._arcs = xr.concat([self.arcs, dset], 'comp_id')

            # changed to deal with strings with variable length
            df = self.arcs.to_dataframe()
            self.arcs = df.to_xarray()
            return new_comp_id
        except Exception:
            return UNINITIALIZED_COMP_ID

    @staticmethod
    def _get_new_arc_atts(comp_id):
        """Get a new dataset with default attributes for a structure arc.

        Args:
            comp_id (:obj:`int`): The XMS component id to use as the single ``comp_id`` coordinate value.

        Returns:
            (:obj:`xarray.Dataset`): A new one-row default dataset for a structure arc. Can later be concatenated to
                the persistent dataset.
        """
        arc_table = {
            name: ('comp_id', np.array([default], dtype=dtype)) for name, dtype, default in StructData._ARC_SCHEMA
        }
        coords = {
            'comp_id': [comp_id]
        }
        return xr.Dataset(data_vars=arc_table, coords=coords)

    def dredging_table_from_id(self, table_id, create_default=True):
        """Gets the dredging table from the table id.

        Args:
            table_id (:obj:`int`): table id
            create_default (:obj:`bool`): True if a default table should be created if none is found.

        Returns:
            (:obj:`xarray.Dataset`): The dredging table dataset
        """
        return self._curve_from_dictionary(table_id, self._dredging_tables, self.default_dredging_table,
                                           self.dredging_table_group_name, create_default)

    def bypass_table_from_id(self, table_id, create_default=True):
        """Gets the bypass table from the table id.

        Args:
            table_id (:obj:`int`): table id
            create_default (:obj:`bool`): True if a default table should be created if none is found.

        Returns:
            (:obj:`xarray.Dataset`): The bypass table dataset
        """
        return self._curve_from_dictionary(table_id, self._bypass_tables, self.default_bypass_table,
                                           self.bypass_table_group_name, create_default)

    def beach_fill_table_from_id(self, table_id, create_default=True):
        """Gets the beach fill table from the table id.

        Args:
            table_id (:obj:`int`): table id
            create_default (:obj:`bool`): True if a default table should be created if none is found.

        Returns:
            (:obj:`xarray.Dataset`): The beach fill table dataset
        """
        return self._curve_from_dictionary(table_id, self._beach_fill_tables, self.default_beach_fill_table,
                                           self.beach_fill_table_group_name, create_default)

    def set_dredging_table(self, table_id, table):
        """Set the dredging table.

        Args:
            table_id (:obj:`int`): The id of the table.
            table (:obj:`xarray.Dataset`): The dataset representing the table.
        """
        self._dredging_tables[table_id] = table

    def set_bypass_table(self, table_id, table):
        """Set the bypass table.

        Args:
            table_id (:obj:`int`): The id of the table.
            table (:obj:`xarray.Dataset`): The dataset representing the table.
        """
        self._bypass_tables[table_id] = table

    def set_beach_fill_table(self, table_id, table):
        """Set the beach fill table.

        Args:
            table_id (:obj:`int`): The id of the table.
            table (:obj:`xarray.Dataset`): The dataset representing the table.
        """
        self._beach_fill_tables[table_id] = table

    def _curve_from_dictionary(self, curve_id, curve_dict, default_curve_method, curve_group_name_method,
                               create_default=True):
        """Gets the curve from the curve id, loading it from the file on first access.

        Args:
            curve_id (:obj:`int`): curve id
            curve_dict (:obj:`dict`): A dictionary of curve ids to xr.Dataset; updated in place as a cache.
            default_curve_method (:obj:`method`): The method to call to get a default curve.
            curve_group_name_method (:obj:`method`): The method to call to get the name of the curve group.
            create_default (:obj:`bool`): True if a default curve should be created if none is found.

        Returns:
            (:obj:`xarray.Dataset`): The curve dataset, or None if none exists and create_default is False.
        """
        curve_id = int(curve_id)
        if curve_id not in curve_dict:
            # load from the file if the curve exists
            dset = self.get_dataset(curve_group_name_method(curve_id), False)
            if dset is not None:
                dset = self._restore_dataset_keys(dset)
            elif create_default:
                # create a default curve
                dset = default_curve_method()
            # None is cached too, so a missing table is not looked up in the file again.
            curve_dict[curve_id] = dset
        return curve_dict[curve_id]

    @staticmethod
    def dredging_table_group_name(table_id):
        """Gets the h5 group where the table is stored.

        Args:
            table_id (:obj:`int`): table id

        Returns:
            (:obj:`str`): The h5 group name
        """
        return f'dredging_tables/{table_id}'

    @staticmethod
    def default_dredging_table():
        """Creates a xarray.Dataset for a dredging table.

        Returns:
            (:obj:`xarray.Dataset`): A one-row table with begin/end dates, shoal to be mined, and volume.
        """
        default_data = {
            'begin_date': [datetime.date(2000, 1, 1)],
            'end_date': [datetime.date(2000, 1, 1)],
            'shoal_to_be_mined': ['(none selected)'],
            'volume': [0.0],
        }
        return pd.DataFrame(default_data).to_xarray()

    @staticmethod
    def bypass_table_group_name(table_id):
        """Gets the h5 group where the table is stored.

        Args:
            table_id (:obj:`int`): table id

        Returns:
            (:obj:`str`): The h5 group name
        """
        return f'bypass_tables/{table_id}'

    @staticmethod
    def default_bypass_table():
        """Creates a xarray.Dataset for a bypass table.

        Returns:
            (:obj:`xarray.Dataset`): A one-row table with start/end dates and bypass rate.
        """
        default_data = {
            'start_date': [datetime.date(2000, 1, 1)],
            'end_date': [datetime.date(2000, 1, 1)],
            'bypass_rate': [0.0],
        }
        return pd.DataFrame(default_data).to_xarray()

    @staticmethod
    def beach_fill_table_group_name(table_id):
        """Gets the h5 group where the table is stored.

        Args:
            table_id (:obj:`int`): table id

        Returns:
            (:obj:`str`): The h5 group name
        """
        return f'beach_fill_tables/{table_id}'

    @staticmethod
    def default_beach_fill_table():
        """Creates a xarray.Dataset for a beach fill table.

        Returns:
            (:obj:`xarray.Dataset`): A one-row table with start/end dates and added berm width.
        """
        default_data = {
            'start_date': [datetime.date(2000, 1, 1)],
            'end_date': [datetime.date(2000, 1, 1)],
            'added_berm_width': [0.0],
        }
        return pd.DataFrame(default_data).to_xarray()

    # External file reference stuff
    def update_file_paths(self):
        """Called before resaving an existing project.

        All referenced filepaths should be converted to relative from the project directory. Should already be stored
        in the component main file since this is a resave operation.

        Returns:
            (:obj:`str`): Message on failure, empty string on success
        """
        proj_dir = self.info.attrs['proj_dir']
        if not os.path.exists(proj_dir):
            return 'Unable to update selected file paths to relative from the project directory.'

        # TODO - This is left over from copying from CMS-Flow BC data. Is something similar needed?
        # variables = ['parent_adcirc_14', 'parent_adcirc_63', 'parent_adcirc_64', 'parent_cmsflow']
        # NOTE(review): the list is empty, so the loop below never runs; _update_table_files is presumably
        # provided by XarrayBase - confirm before populating it.
        variables = []
        flush_data = False
        for variable in variables:
            if self._update_table_files(self._arcs[variable], proj_dir, ''):
                flush_data = True
        if flush_data:
            self.commit()
        return ''
