"""This handles saving and loading sediment transport constituents."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
import os

# 2. Third party modules
from adhparam.data_frame_utils import create_default_dataframe
from adhparam.sediment_constituent_properties import CLAY_COLUMN_TYPES, SAND_COLUMN_TYPES, SedimentConstituentProperties
import h5py
import xarray as xr

# 3. Aquaveo modules
from xms.components.bases.xarray_base import XarrayBase
from xms.core.filesystem.filesystem import removefile

# 4. Local modules
from xms.adh import __version__ as xms_adh_version
from xms.adh.data import param_h5_io
from xms.adh.data.version import needs_update

# Column dtypes of the sand table as stored in files written before version 1.5.0.
# These tables carry no NAME column; names were kept in a separate table keyed by ID
# (see NAME_COLUMN_TYPES_PRE_1_5_0).
SAND_COLUMN_TYPES_PRE_1_5_0 = dict(
    ID='int64',
    CONCENTRATION='float',
    GRAIN_DIAMETER='float',
    SPECIFIC_GRAVITY='float',
    POROSITY='float',
)

# Pre-1.5.0 clay tables share every sand column and append the cohesive-sediment parameters.
CLAY_COLUMN_TYPES_PRE_1_5_0 = {
    **SAND_COLUMN_TYPES_PRE_1_5_0,
    'CRITICAL_SHEAR_EROSION': 'float',
    'EROSION_RATE': 'float',
    'CRITICAL_SHEAR_DEPOSITION': 'float',
    'FREE_SETTLING_VELOCITY': 'float',
}

# Column dtypes of the separate pre-1.5.0 constituent-name table.
NAME_COLUMN_TYPES_PRE_1_5_0 = dict(ID='int64', NAME='str')


class SedimentConstituentsIO(XarrayBase):
    """
    A class that handles saving and loading sediment transport constituent data.

    Attributes:
        main_file (str): The main file associated with this component.
        param_control (SedimentConstituentProperties): The sediment constituent properties.
    """
    def __init__(self, main_file):
        """Initializes the data class.

        Any attribute missing from the ``info`` dataset is filled from :meth:`default_data`, FILE_TYPE is
        always set to ``'ADH_SEDIMENT_CONSTITUENTS'``, and if ``main_file`` already exists its sand and clay
        tables are loaded into :attr:`param_control`.

        Args:
            main_file: The main file associated with this component (str or path-like).
        """
        super().__init__(str(main_file))
        self.main_file = str(main_file)
        attrs = self.info.attrs
        for key, value in self.default_data().items():
            if key not in attrs:
                attrs[key] = value
        attrs['FILE_TYPE'] = 'ADH_SEDIMENT_CONSTITUENTS'
        self.param_control = SedimentConstituentProperties()
        if os.path.exists(self.main_file):
            self.read_from_h5_file(self.main_file, 'ADH_SEDIMENT_CONSTITUENTS')

    def read_from_h5_file(self, filename, file_type):
        """Reads this class from an H5 file.

        Files whose VERSION predates 1.5.0.dev1 are converted on read: constituent names from the separate
        ``/names`` group are merged back in, ``BULK_DENSITY`` columns are renamed to ``POROSITY``, and for
        files older than 1.3.0 ``GRAIN_DIAMETER`` is divided by 1000.

        Args:
            filename (str): file name
            file_type (str): The type of file expected. This will be compared to FILE_TYPE in the netCDF file.

        Raises:
            IOError: If the ``info`` group has no FILE_TYPE attribute or it differs from ``file_type``.
        """
        info = xr.load_dataset(filename, group='info')
        if 'FILE_TYPE' not in info.attrs:
            raise IOError(f'File attributes do not include file_type in file : {filename}')
        ftype = info.attrs['FILE_TYPE']
        if ftype != file_type:
            raise IOError(f'Unrecognized file_type "{ftype}" attribute in file: {filename}')

        version = info.attrs['VERSION']
        if needs_update(version, '1.5.0.dev1'):
            self._read_pre_1_5_0(filename, version)
        else:
            # The cast result must be what gets stored; previously it was computed and then discarded.
            self.param_control.sand = self._read_table(filename, '/sand/', SAND_COLUMN_TYPES)
            self.param_control.clay = self._read_table(filename, '/clay/', CLAY_COLUMN_TYPES)

    def _read_pre_1_5_0(self, filename, version):
        """Loads sand and clay tables from a file written before version 1.5.0.dev1.

        A table that cannot be read leaves the corresponding ``param_control`` table unchanged.

        Args:
            filename (str): file name
            version (str): The VERSION attribute read from the file's ``info`` group.
        """
        names = xr.load_dataset(filename, group='/names').to_dataframe()
        clay = self._read_legacy_table(filename, '/clay/', names, CLAY_COLUMN_TYPES)
        if clay is not None:
            self.param_control.clay = clay
        sand = self._read_legacy_table(filename, '/sand/', names, SAND_COLUMN_TYPES)
        if sand is not None:
            self.param_control.sand = sand
        if needs_update(version, '1.3.0'):
            # Files older than 1.3.0 stored GRAIN_DIAMETER scaled 1000x relative to the current convention.
            self.param_control.sand['GRAIN_DIAMETER'] = self.param_control.sand['GRAIN_DIAMETER'] / 1000
            self.param_control.clay['GRAIN_DIAMETER'] = self.param_control.clay['GRAIN_DIAMETER'] / 1000

    @staticmethod
    def _read_legacy_table(filename, group_path, names, column_types):
        """Reads one pre-1.5.0 table and merges in its constituent names.

        Args:
            filename (str): file name
            group_path (str): The h5 group holding the table (e.g. ``'/sand/'``).
            names: The names table; merged with an inner join on its ``ID`` column.
            column_types (dict): Target column -> dtype mapping; selects and orders the output columns.

        Returns:
            The converted table, or None if ``param_h5_io.read_data_frame`` returned None.
        """
        table = param_h5_io.read_data_frame(filename, group_path=group_path)
        if table is None:
            return None
        if 'BULK_DENSITY' in table.columns:
            table = table.rename(columns={'BULK_DENSITY': 'POROSITY'})
        merged = table.merge(names, on='ID', how='inner')
        return merged[list(column_types)].astype(column_types)

    @staticmethod
    def _read_table(filename, group_path, column_types):
        """Reads one current-format table cast to ``column_types``.

        Args:
            filename (str): file name
            group_path (str): The h5 group holding the table (e.g. ``'/sand/'``).
            column_types (dict): Column -> dtype mapping applied to the table.

        Returns:
            The cast table, or ``create_default_dataframe(column_types)`` if the table could not be read.
        """
        table = param_h5_io.read_data_frame(filename, group_path=group_path)
        if table is None:
            return create_default_dataframe(column_types)
        return table.astype(column_types)

    @staticmethod
    def default_data() -> dict[str, object]:
        """Gets the default data for this class.

        Returns:
            A dictionary of default values that will go into the info dataset attrs.
        """
        return {
            'FILE_TYPE': 'ADH_SEDIMENT_CONSTITUENTS',
            'VERSION': xms_adh_version,
            'next_constituent_id': 1,
        }

    def clean(self, group_name):
        """Removes the group from the file if it exists.

        Args:
            group_name (str): The h5 path to the group to be removed.
        """
        with h5py.File(self.main_file, 'a') as f:
            # Check membership explicitly instead of swallowing every exception from the delete.
            if group_name in f:
                del f[group_name]

    def commit(self):
        """Saves the component data to the main file.

        The main file is deleted and rewritten. The base-class commit presumably writes the ``info``
        dataset (with VERSION stamped to the running xms.adh version). The sand and clay parameter
        tables are then written with ``param_h5_io.write_params_recursive``.
        """
        # NOTE(review): closing first presumably releases any handle ``info`` holds on main_file before the
        # file is removed.
        self.info.close()
        self.info.attrs['VERSION'] = xms_adh_version
        removefile(self.main_file)
        super().commit()
        param_h5_io.write_params_recursive(self.main_file, '/', self.param_control)
