"""This module defines data for the rubble mound structures hidden component."""

__copyright__ = "(C) Copyright Aquaveo 2024"
__license__ = "All rights reserved"

# 1. Standard Python modules
import os

# 2. Third party modules
import numpy as np
import xarray as xr

# 3. Aquaveo modules
from xms.components.bases.xarray_base import XarrayBase

# 4. Local modules
from xms.cmsflow.components.id_files import UNINITIALIZED_COMP_ID

# Choices for the rubble mound polygon 'calculation_method' attribute. The first entry is the default used for
# new polygons.
CALCULATION_METHODS = [
    'Sidiropoulou et al. (2007)',
    'Kadlec and Knight (1996)',
    'Ward (1964)',
]
# How a rubble mound parameter value is specified: a single constant or a spatially varying dataset.
PARAMETER_TYPES = ['Constant', 'Dataset']
# Indices into PARAMETER_TYPES, derived from the list so they cannot drift out of sync with it.
PARAMETER_TYPE_CONSTANT_IDX = PARAMETER_TYPES.index('Constant')  # 0
PARAMETER_TYPE_DATASET_IDX = PARAMETER_TYPES.index('Dataset')  # 1


class RMStructuresData(XarrayBase):
    """Manages data file for the hidden rubble mound polygons component."""

    # Per-polygon variables stored in the 'polygons' group, in the order update_rm_polygon copies them.
    # The spatially varying datasets variable names have to be all caps due to the way we import datasets.
    _POLYGON_VARIABLES = (
        'name',
        'rock_diameter',
        'rock_diameter_type',
        'ROCK_DIAMETER_DATASET',
        'porosity',
        'porosity_type',
        'STRUCTURE_POROSITY_DATASET',
        'base_depth',
        'base_depth_type',
        'STRUCTURE_BASE_DEPTH_DATASET',
        'calculation_method',
    )

    def __init__(self, data_file):
        """Initializes the data class.

        Args:
            data_file (str): The netcdf file (with path) associated with this instance data. Probably the owning
                component's main file.

        """
        self._filename = data_file
        self._info = None
        self._polygons = None
        # Set this to the set of deleted component ids in the coverage before vacuuming.
        self.deleted_comp_ids = set()
        # Create the default file before calling super because we have our own attributes to write.
        created = self._get_default_datasets(data_file)
        super().__init__(data_file)
        self.info.attrs.pop('proj_dir', None)
        if created:
            # Seed a brand new file with one row of default polygon attributes and persist it.
            self.add_rm_atts()
            self.commit()

    def _get_default_datasets(self, data_file):
        """Create default datasets if needed.

        Args:
            data_file (str): Name of the data file. If it doesn't exist, it will be created.

        Returns:
            bool: True if the default datasets were created (and committed), False if the file already existed.
        """
        # os.path.isfile() is False for a missing path as well, so this single check covers both cases.
        if os.path.isfile(data_file):
            return False

        info = {
            'FILE_TYPE': 'CMSFLOW_RM_STRUCTURES',
            'cov_uuid': '',
            'polygon_display_uuid': '',
            'next_comp_id': 0,
        }
        self._info = xr.Dataset(attrs=info)

        poly_table = {
            'name': ('comp_id', np.array([], dtype=object)),
            'rock_diameter': ('comp_id', np.array([], dtype=float)),
            'rock_diameter_type': ('comp_id', np.array([], dtype=int)),
            # The spatially varying datasets variable names have to be all caps due to the way we import datasets.
            'ROCK_DIAMETER_DATASET': ('comp_id', np.array([], dtype=object)),
            'porosity': ('comp_id', np.array([], dtype=float)),
            'porosity_type': ('comp_id', np.array([], dtype=int)),
            'STRUCTURE_POROSITY_DATASET': ('comp_id', np.array([], dtype=object)),
            'base_depth': ('comp_id', np.array([], dtype=float)),
            'base_depth_type': ('comp_id', np.array([], dtype=int)),
            'STRUCTURE_BASE_DEPTH_DATASET': ('comp_id', np.array([], dtype=object)),
            'calculation_method': ('comp_id', np.array([], dtype=object)),
        }
        coords = {'comp_id': np.array([], dtype=int)}
        self._polygons = xr.Dataset(data_vars=poly_table, coords=coords)
        self.commit()
        return True

    def commit(self):
        """Save current in-memory component parameters to data file.

        The base class writes its own data first. The 'polygons' group is then written only if it has been loaded
        into memory; an unloaded group is left untouched on disk.
        """
        super().commit()  # Recreates the NetCDF file if vacuuming
        if self._polygons is not None:
            # Release anything the dataset may hold open on the file before rewriting its group.
            self._polygons.close()
            self._drop_h5_groups(['polygons'])
            self._polygons.to_netcdf(self._filename, group='polygons', mode='a')

    def vacuum(self):
        """Rewrite all rubble mound data to a new/wiped file to reclaim disk space.

        All rubble mound datasets that need to be written to the file must be loaded into memory before calling
        this method.

        """
        if self._info is None:
            self._info = self.get_dataset('info', False)
        if self._polygons is None:
            self._polygons = self.get_dataset('polygons', False)
        try:
            os.remove(self._filename)
        except OSError:
            pass  # Best effort: the datasets are rewritten by commit() below either way.
        self.commit()  # Rewrite all datasets

    @property
    def polygons(self):
        """Load the polygons dataset from disk.

        The dataset is read on first access and cached in memory afterwards.

        Returns:
            xarray.Dataset: Dataset interface to the polygons datasets in the main file

        """
        if self._polygons is None:
            self._polygons = self.get_dataset('polygons', False)
        return self._polygons

    @polygons.setter
    def polygons(self, dset):
        """Setter for the polygons attribute.

        Falsy values (e.g. None) are ignored and the current dataset is kept.
        """
        if dset:
            self._polygons = dset

    def update_rm_polygon(self, comp_id, new_atts):
        """Update the rubble mound attributes of a rubble mound polygon.

        Args:
            comp_id (int): Component id of the rubble mound polygon to update
            new_atts (xarray.Dataset): The new attributes for the rubble mound. Must contain every polygon variable
                ('name', 'rock_diameter', ..., 'calculation_method').

        """
        polygons = self.polygons
        for variable in self._POLYGON_VARIABLES:
            polygons[variable].loc[dict(comp_id=[comp_id])] = new_atts[variable]

    def add_rm_atts(self, dset=None):
        """Add the rubble mound polygon attribute dataset for a polygon.

        Args:
            dset (xarray.Dataset): The attribute dataset to concatenate. If not provided, a new Dataset of default
                attributes will be generated. If provided, its 'comp_id' coordinate is overwritten in place and every
                row receives the same new id, so it is expected to describe a single polygon.

        Returns:
            int: The newly generated component id, or UNINITIALIZED_COMP_ID if the attributes could not be added.

        """
        try:
            # int() accepts both a plain Python int (as seeded in memory by _get_default_datasets) and a numpy
            # scalar (as attributes typically come back from NetCDF); numpy's .item() fails on the former.
            new_comp_id = int(self.info.attrs['next_comp_id'])
            self.info.attrs['next_comp_id'] += 1  # Increment the unique XMS component id.
            if dset is None:  # Generate a new default Dataset
                dset = self._get_new_polygon_atts(new_comp_id)
            else:  # Update the component id of an existing Dataset
                dset.coords['comp_id'] = [new_comp_id for _ in dset.coords['comp_id']]
            self._polygons = xr.concat([self.polygons, dset], 'comp_id')
            return new_comp_id
        except Exception:
            # Deliberate best effort: callers check for UNINITIALIZED_COMP_ID instead of handling an exception.
            return UNINITIALIZED_COMP_ID

    def concat_rm_atts(self, rm_data):
        """Adds the rubble mound polygon attributes from rm_data to this instance of RMStructuresData.

        The incoming polygons are assigned consecutive component ids starting at this instance's 'next_comp_id'.

        NOTE: rm_data.polygons is renumbered in place, so rm_data reflects the new ids after this call.

        Args:
            rm_data (RMStructuresData): another RMStructuresData instance

        Returns:
            dict: The old ids of the rm_data as key and the new ids as the data. Empty if rm_data has no polygons.

        """
        next_comp_id = self.info.attrs['next_comp_id']
        new_polys = rm_data.polygons
        num_concat_polys = new_polys.sizes['comp_id']
        if not num_concat_polys:
            return {}

        # Reassign component id coordinates.
        old_comp_ids = new_polys.coords['comp_id'].data.astype('i4').tolist()
        new_polys.coords['comp_id'] = [next_comp_id + idx for idx in range(num_concat_polys)]
        self.info.attrs['next_comp_id'] = next_comp_id + num_concat_polys
        self._polygons = xr.concat([self.polygons, new_polys], 'comp_id')
        new_comp_ids = new_polys.coords['comp_id'].data.astype('i4').tolist()
        return dict(zip(old_comp_ids, new_comp_ids))

    @staticmethod
    def _get_new_polygon_atts(comp_id):
        """Get a new dataset with default attributes for a rubble mound polygon.

        Defaults: no name selected, constant parameters (rock diameter 0.0, porosity 0.4, base depth 0.0) with no
        datasets assigned, and the first entry of CALCULATION_METHODS.

        Args:
            comp_id (int): The unique XMS component id to assign to the new rubble mound polygon. It is used as-is.

        Returns:
            (xarray.Dataset): A new single-row default dataset for a rubble mound polygon. Can later be concatenated
            to persistent dataset.

        """
        poly_table = {
            'name': ('comp_id', np.array(['(none selected)'], dtype=object)),
            'rock_diameter': ('comp_id', np.array([0.0], dtype=float)),
            'rock_diameter_type': ('comp_id', np.array([PARAMETER_TYPE_CONSTANT_IDX], dtype=int)),
            'ROCK_DIAMETER_DATASET': ('comp_id', np.array([''], dtype=object)),
            'porosity': ('comp_id', np.array([0.4], dtype=float)),
            'porosity_type': ('comp_id', np.array([PARAMETER_TYPE_CONSTANT_IDX], dtype=int)),
            'STRUCTURE_POROSITY_DATASET': ('comp_id', np.array([''], dtype=object)),
            'base_depth': ('comp_id', np.array([0.0], dtype=float)),
            'base_depth_type': ('comp_id', np.array([PARAMETER_TYPE_CONSTANT_IDX], dtype=int)),
            'STRUCTURE_BASE_DEPTH_DATASET': ('comp_id', np.array([''], dtype=object)),
            'calculation_method': ('comp_id', np.array([CALCULATION_METHODS[0]], dtype=object)),
        }
        coords = {'comp_id': [comp_id]}
        return xr.Dataset(data_vars=poly_table, coords=coords)
