"""MaterialData class."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules

# 2. Third party modules
import pandas as pd
import pkg_resources

# 3. Aquaveo modules
from xms.components.bases.xarray_base import XarrayBase
from xms.core.filesystem import filesystem

# 4. Local modules

# Default material colors as (red, green, blue) tuples with components in the range 0-255.
# Built from a name -> color mapping; dicts keep insertion order, so the resulting list
# order matches the order of the entries below.
DEFAULT_MATERIAL_COLORS = list({
    'red': (255, 0, 0),
    'blue': (0, 0, 255),
    'green': (0, 255, 51),
    'yellow': (255, 204, 0),
    'cyan': (0, 204, 255),
    'magenta': (255, 0, 204),
    'purple': (153, 0, 255),
    'orange': (255, 102, 0),
    'dark green': (51, 153, 51),
    'brown': (153, 102, 0),
    'grey blue': (102, 153, 153),
    'dark red': (153, 51, 51),
    'light magenta': (255, 153, 255),
}.values())


class MaterialData(XarrayBase):
    """Class for storing the SRH material properties (Manning's N and sediment).

    The data is held in several xarray Datasets:

    - the material list (one row per material; columns indexed by the ``COL_*`` constants),
    - a depth-varied Manning's N curve per material id,
    - a sediment property table per material id (columns indexed by the ``COL_PROP_*`` constants),
    - a sediment gradation curve per (material id, sediment property id) pair.

    The constructor reads the data into memory and closes the file; ``commit`` deletes and rewrites the file.
    """
    UNASSIGNED_MAT = 0  # The unassigned material id (and row in material list)
    # constants for columns in the materials data set
    COL_ID = 0
    COL_COLOR = 1
    COL_NAME = 2
    COL_N = 3
    COL_DEPTH_TOG = 4
    COL_DEPTH_CURVE = 5

    # constants for columns in a material's sediment property data set
    COL_PROP_ID = 0
    COL_PROP_THICKNESS = 1
    COL_PROP_UNITS = 2
    COL_PROP_DENSITY = 3
    COL_PROP_GRADATION_CURVE = 4

    def __init__(self, filename):
        """Constructor.

        Args:
            filename (:obj:`str`): file name
        """
        super().__init__(filename)
        self.info.attrs['FILE_TYPE'] = 'SRH_MATERIAL_DATA'
        if 'cov_uuid' not in self.info.attrs:
            self.info.attrs['cov_uuid'] = ''  # gets set later
        if 'display_uuid' not in self.info.attrs:
            self.info.attrs['display_uuid'] = ''
        self._materials = None  # materials data set, loaded on first access of the materials property
        self._depth_curves = dict()  # depth dependent manning's n curves, dict[mat_id]
        self._sediment_properties = dict()  # sediment atts for each material, dict[mat_id]
        self._gradation_curves = dict()  # sediment gradation curves, dict['mat_id att_idx']
        self._load_all_from_file()

    @property
    def materials(self):
        """Get the material component data parameters.

        Loaded from the 'materials' group on first access; a default list containing only the
        unassigned material is created when the group does not exist.

        Returns:
            (:obj:`xarray.Dataset`): The material list dataset
        """
        if self._materials is None:
            self._materials = self.get_dataset('materials', False)
            if self._materials is None:
                self._materials = self._default_materials()
            else:  # make sure the id column is an int
                self._materials.id.values = self._materials.id.astype(int).values
        return self._materials

    def set_materials(self, material_dataset):
        """Sets the material list data set.

        Args:
            material_dataset (:obj:`xarray.Dataset`): the material list dataset
        """
        self._materials = material_dataset

    def _default_materials(self):
        """Creates a default material data set containing only the unassigned material.

        Returns:
            (:obj:`xarray.Dataset`): The material list dataset
        """
        default_data = {
            'id': [MaterialData.UNASSIGNED_MAT],
            'Color and Texture': ['0 0 0 1'],
            'Name': ['unassigned'],
            "Manning's N": [0.02],
            'Depth Varied Curve': [0],
            'Curve': [''],
        }
        return pd.DataFrame(default_data).to_xarray()

    def depth_curve_from_mat_id(self, mat_id):
        """Gets the depth varied Manning's N curve for a material.

        The curve is read from the file the first time it is requested; a default curve is created
        if the file does not contain one.

        Args:
            mat_id (:obj:`int`): material id

        Returns:
            (:obj:`xarray.Dataset`): The depth vs. Manning's N curve dataset
        """
        if mat_id not in self._depth_curves:
            # load from the file if the curve exists
            grp = self._depth_curve_group_name(mat_id)
            self._depth_curves[mat_id] = self.get_dataset(grp, False)
            if self._depth_curves[mat_id] is None:
                # create a default curve
                self._depth_curves[mat_id] = self._default_depth_curve()
        return self._depth_curves[mat_id]

    def set_depth_curve(self, mat_id, curve):
        """Sets the depth curve by material id.

        Args:
            mat_id (:obj:`int`): material id
            curve (:obj:`xarray.Dataset`): the depth v roughness curve
        """
        self._depth_curves[mat_id] = curve

    def _depth_curve_group_name(self, mat_id):
        """Gets the h5 group where the curve is stored.

        Args:
            mat_id (:obj:`int`): material id

        Returns:
            (:obj:`str`): The h5 group name
        """
        return f'depth_curves/{mat_id}'

    def _default_depth_curve(self):
        """Creates a xarray.Dataset for a depth varied manning's n curve.

        Returns:
            (:obj:`xarray.Dataset`): A single-row curve with depth 0.0 and Manning's N 0.0
        """
        default_data = {
            'Depth': [0.0],
            "Manning's N": [0.0],
        }
        return pd.DataFrame(default_data).to_xarray()

    def sediment_properties_from_mat_id(self, mat_id):
        """Gets the material sediment properties.

        The table is read from the file the first time it is requested; an empty table is created
        if the file does not contain one.

        Args:
            mat_id (:obj:`int`): material id

        Returns:
            (:obj:`xarray.Dataset`): The material sediment dataset
        """
        if mat_id not in self._sediment_properties:
            # load from the file if the table exists
            grp = self._sediment_property_group_name(mat_id)
            self._sediment_properties[mat_id] = self.get_dataset(grp, False)
            if self._sediment_properties[mat_id] is None:
                # create a default (empty) property table
                self._sediment_properties[mat_id] = self._default_sediment_properties()
        return self._sediment_properties[mat_id]

    def set_sediment_properties(self, mat_id, properties):
        """Sets the material sediment properties.

        Args:
            mat_id (:obj:`int`): material id
            properties (:obj:`xarray.Dataset`): sediment properties
        """
        self._sediment_properties[mat_id] = properties

    def _default_sediment_properties(self):
        """Creates an empty xarray.Dataset for sediment properties.

        Returns:
            (:obj:`xarray.Dataset`): A sediment property table with no rows
        """
        default_data = {
            'att_id': [],
            'Thickness': [],
            'Units': [],
            'Density': [],
            'Gradation Curve': [],
        }
        return pd.DataFrame(default_data).to_xarray()

    def _sediment_property_group_name(self, mat_id):
        """Returns a string for the H5 group name for sediment properties.

        Args:
            mat_id (:obj:`int`): material id

        Returns:
            (:obj:`str`): The group name
        """
        return f'sediment_property/{mat_id}'

    @staticmethod
    def _gradation_key(mat_id, prop_id):
        """Gets the key of a gradation curve in the in-memory gradation curve dict.

        Args:
            mat_id (:obj:`int`): material id
            prop_id (:obj:`int`): material sediment property id

        Returns:
            (:obj:`str`): 'mat_id prop_id' (split apart again by commit)
        """
        return f'{mat_id} {prop_id}'

    def gradation_curve_from_mat_id_prop_id(self, mat_id, prop_id):
        """Get the sediment gradation curve.

        The curve is read from the file the first time it is requested; a default curve is created
        if the file does not contain one.

        Args:
            mat_id (:obj:`int`): material id
            prop_id (:obj:`int`): material sediment property id

        Returns:
            (:obj:`xarray.Dataset`): The sediment curve dataset
        """
        key = self._gradation_key(mat_id, prop_id)
        if key not in self._gradation_curves:
            # load from the file if the curve exists
            grp = self._gradation_group_name(mat_id, prop_id)
            self._gradation_curves[key] = self.get_dataset(grp, False)
            if self._gradation_curves[key] is None:
                # create a default curve
                self._gradation_curves[key] = self._default_gradation_curve()
        return self._gradation_curves[key]

    def _default_gradation_curve(self):
        """Default gradation curve.

        Returns:
            (:obj:`xarray.Dataset`): A single-row curve with diameter 0.0 and percent finer 0.0
        """
        default_data = {
            'Particle diameter (mm)': [0.0],
            'Percent finer': [0.0],
        }
        return pd.DataFrame(default_data).to_xarray()

    def _gradation_group_name(self, mat_id, prop_id):
        """Gets the H5 group name string.

        Args:
            mat_id (:obj:`int`): material id
            prop_id (:obj:`int`): material sediment property id

        Returns:
            (:obj:`str`): The group name
        """
        return f'sediment_gradation_curves/mat_id_{mat_id}/prop_id_{prop_id}'

    def set_gradation_curve(self, mat_id, prop_id, curve):
        """Sets the sediment gradation curve of a material sediment property.

        Args:
            mat_id (:obj:`int`): material id
            prop_id (:obj:`int`): material sediment property id
            curve (:obj:`xarray.Dataset`): gradation curve
        """
        self._gradation_curves[self._gradation_key(mat_id, prop_id)] = curve

    @staticmethod
    def _unique_name(name, used_names):
        """Returns a name that is not in used_names and records it as used.

        Args:
            name (:obj:`str`): the desired name
            used_names (:obj:`set`): names already taken; the returned name is added to it

        Returns:
            (:obj:`str`): name itself if free, otherwise 'name (n)' with the smallest free n >= 1
        """
        unique = name
        cnt = 1
        while unique in used_names:
            # always suffix the base name so suffixes do not accumulate ('A (1) (2)')
            unique = f'{name} ({cnt})'
            cnt += 1
        used_names.add(unique)
        return unique

    def add_materials(self, mat_data):
        """Adds the materials from mat_data to this instance of MaterialData.

        Every material of mat_data except the first row (the unassigned material) is appended with a
        new id. Its depth curve, sediment properties, and gradation curves are copied under that new id.
        Names that clash with existing or previously added names get a ' (n)' suffix.

        Args:
            mat_data (:obj:`MaterialData`): another Material Data instance

        Returns:
            (:obj:`dict`): The old ids of the mat_data as key and the new ids as the data
        """
        mat_df = self.materials.to_dataframe()
        used_names = set(mat_df['Name'].to_list())
        mat_data_df = mat_data.materials.to_dataframe()
        # UNASSIGNED_MAT must always be in the dict
        old_to_new_ids = {MaterialData.UNASSIGNED_MAT: MaterialData.UNASSIGNED_MAT}
        next_mat_id = max(mat_df['id']) + 1
        next_idx = len(mat_df) + 1
        for i in range(1, len(mat_data_df)):  # start at 1 to skip UNASSIGNED_MAT
            mat_id = mat_data_df.iloc[i, MaterialData.COL_ID]
            new_mat_id = next_mat_id
            next_mat_id += 1
            old_to_new_ids[mat_id] = new_mat_id
            mat_data_df.iloc[i, MaterialData.COL_ID] = new_mat_id
            mat_name = mat_data_df.iloc[i, MaterialData.COL_NAME]
            mat_data_df.iloc[i, MaterialData.COL_NAME] = self._unique_name(mat_name, used_names)
            if mat_id in mat_data._depth_curves:
                self._depth_curves[new_mat_id] = mat_data._depth_curves[mat_id]
            if mat_id in mat_data._sediment_properties:
                sed_props = mat_data._sediment_properties[mat_id]
                self._sediment_properties[new_mat_id] = sed_props
                if sed_props is not None:
                    sed_df = sed_props.to_dataframe()
                    for j in range(len(sed_df)):
                        # each sediment property row (j) owns one gradation curve
                        prop_id = sed_df.iloc[j, MaterialData.COL_PROP_ID]
                        curve = mat_data.gradation_curve_from_mat_id_prop_id(mat_id, prop_id)
                        self.set_gradation_curve(new_mat_id, prop_id, curve)
            mat_df.loc[next_idx] = mat_data_df.values[i]
            next_idx += 1
        self._materials = mat_df.to_xarray()
        return old_to_new_ids

    def _load_all_from_file(self):
        """Loads the material data from the file into memory, then closes the file.

        Depth curves are only loaded for materials whose depth-varied toggle is non-zero. Sediment
        property tables, and the gradation curve of each property row, are loaded for every material.
        """
        mats = self.materials.to_dataframe()
        for i in range(len(mats)):
            mat_id = int(mats.iloc[i, MaterialData.COL_ID])
            if mats.iloc[i, MaterialData.COL_DEPTH_TOG] != 0:  # using a depth curve
                self.depth_curve_from_mat_id(mat_id)
            sed = self.sediment_properties_from_mat_id(mat_id).to_dataframe()
            for j in range(len(sed)):
                prop_id = sed.iloc[j, MaterialData.COL_PROP_ID]
                self.gradation_curve_from_mat_id_prop_id(mat_id, prop_id)
        self.close()

    def commit(self):
        """Save in memory datasets to the NetCDF file.

        The existing file is deleted first, so only the data held in memory ends up in the file.
        """
        self.info.attrs['VERSION'] = pkg_resources.get_distribution('xmssrh').version
        filesystem.removefile(self._filename)
        super().commit()

        # write the material list
        self.materials.to_netcdf(self._filename, group='materials', mode='a')
        # write the depth curves
        for mat_id, data in self._depth_curves.items():
            if data is not None:
                grp = self._depth_curve_group_name(mat_id)
                data.to_netcdf(self._filename, group=grp, mode='a')
        # write the sediment properties
        for mat_id, data in self._sediment_properties.items():
            if data is not None:
                grp = self._sediment_property_group_name(mat_id)
                data.to_netcdf(self._filename, group=grp, mode='a')
        # write the sediment gradation curves; keys are 'mat_id prop_id' (see _gradation_key)
        for key, data in self._gradation_curves.items():
            if data is not None:
                mat_id, prop_id = key.split()[:2]
                grp = self._gradation_group_name(mat_id, prop_id)
                data.to_netcdf(self._filename, group=grp, mode='a')

    def close(self):
        """Closes the H5 file and does not write any data that is in memory."""
        super().close()
        if self._materials is not None:
            self._materials.close()
        self._close_dict_data(self._depth_curves)
        self._close_dict_data(self._sediment_properties)
        self._close_dict_data(self._gradation_curves)

    def _close_dict_data(self, data_dict):
        """Close xarray.Dataset items in a dict.

        Args:
            data_dict (:obj:`dict`): dictionary with xarray.Dataset values (entries may be None)
        """
        for data in data_dict.values():
            # explicit None check: Dataset truthiness is based on its data variables, not existence
            if data is not None:
                data.close()
