"""This module is used for accessing sediment material data."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
import copy
import importlib.metadata
import os
from pathlib import Path

# 2. Third party modules
from adhparam.material_transport_properties import MaterialTransportProperties
import h5py
from netCDF4 import Dataset
import pandas as pd
from PySide2.QtGui import QColor
import xarray as xr

# 3. Aquaveo modules
from xms.components.bases.xarray_base import XarrayBase
from xms.guipy.data.polygon_texture import PolygonOptions

# 4. Local modules
from xms.adh.data.sediment_material import SedimentMaterial


class SedimentMaterialsIO(XarrayBase):
    """A class that handles saving and loading sediment material data.

    The file-level ``info`` dataset (managed by ``XarrayBase``) holds global settings. Each material is stored in a
    top-level group named by its integer id, with ``info``, ``constituents``, ``bed_layers``, ``consolidation`` and
    (optionally) ``sediment_mat_prop`` subgroups.
    """

    UNASSIGNED_MAT = 0  # The unassigned material id (and row in material list)
    # Make the lists of combo box options. default_data() uses the first entry of each method list as its default.
    CSV_OPTIONS = ['Free settling', 'Hwang and Mehta']
    WWS_OPTIONS = ['No applied wind-wave stress', 'Grant and Madsen', 'Teeter']
    NSE_OPTIONS = ['No suspended load (bedload only)', 'Garcia-Parker', 'Wright-Parker', 'vanRijn', 'Yang']
    NBE_OPTIONS = [
        'No bedload (suspended load only)', 'VanRijn', 'Meyer Peter Mueller',
        'Meyer Peter Mueller with Wong-Parker correction', 'Wilcock'
    ]
    HID_OPTIONS = ['Karim Holly Yang', 'Egiazaroff', 'Wu Wang Jia', 'Parker and Klingeman']
    MP_NBL_OPTIONS = ['Bed layers defined with thickness', 'Bed layers defined with strata elevation horizon']

    # Boolean material flags, stored in each material's ``info`` attrs as 0/1 integers.
    _MATERIAL_FLAGS = (
        'bed_layer_override', 'bed_layer_cohesive_override', 'consolidation_override', 'displacement_off',
        'local_scour', 'use_bedload_diffusion'
    )
    # Per-material table subgroups that commit() deletes and rewrites from scratch.
    _MATERIAL_TABLE_GROUPS = ('constituents', 'bed_layers', 'consolidation', 'sediment_mat_prop')

    def __init__(self, main_file: str | Path):
        """Initializes the data class.

        Args:
            main_file: The main file associated with this component. If it is not an existing file, the global
                settings start from default_data(); otherwise the materials stored in it are read.
        """
        super().__init__(str(main_file))
        self.main_file = str(main_file)
        defaults = self.default_data()
        attrs = self.info.attrs
        if not os.path.isfile(self.main_file):
            self._info = xr.Dataset(attrs=defaults)
            # Re-fetch: the attrs read above belong to the dataset that was just replaced, so writing to them
            # would never reach the new info (FILE_TYPE below would be lost for a brand-new file).
            attrs = self.info.attrs
        # Fill in any settings missing from older files without touching values already present.
        for key, value in defaults.items():
            attrs.setdefault(key, value)
        attrs['FILE_TYPE'] = 'ADH_SED_MAT_DATA'

        unassigned = SedimentMaterial(SedimentMaterialsIO.UNASSIGNED_MAT)
        unassigned.name = 'unassigned'
        self.materials = {SedimentMaterialsIO.UNASSIGNED_MAT: unassigned}
        if os.path.isfile(self.main_file):
            self.read_from_h5_file(self.main_file, 'ADH_SED_MAT_DATA')

    @staticmethod
    def get_material_ids(filename):
        """Reads the file and gets all of the top level groups.

        Args:
            filename (str): The name of the file to read.

        Returns:
            A list of strings representing the group names found in the file. This includes material ids and
            non-material groups such as ``info``.
        """
        with h5py.File(filename, 'r') as f:
            return list(f.keys())

    def read_from_h5_file(self, filename, file_type):
        """Reads every material stored in an h5 file into ``self.materials``.

        Args:
            filename (str): file name
            file_type (str): A description of what type of file we expect this to be. Not currently checked; kept
                for interface compatibility.
        """
        for id_key in self.get_material_ids(filename):
            if id_key == 'info':  # File-level info dataset, not a material.
                continue
            self.read_sediment_material_from_h5(filename, int(id_key))

    def read_sediment_material_from_h5(self, filename, mat_id):
        """Reads an individual material from h5.

        Creates the material in ``self.materials`` if it is not already there. Then fills in its display options,
        flags and tables, plus its sediment transport properties when that group exists.

        Args:
            filename (str): The name of the netCDF file.
            mat_id (int): The id of the material to read.
        """
        info = xr.load_dataset(filename, group=f'/{mat_id}/info')
        attrs = info.attrs
        if attrs is None:
            return
        if mat_id not in self.materials:
            self.materials[mat_id] = SedimentMaterial(mat_id)
        material = self.materials[mat_id]
        material.name = attrs['material_name']

        display = PolygonOptions()
        display.color = QColor(attrs['red'], attrs['green'], attrs['blue'], attrs['alpha'])
        display.texture = attrs['texture']
        material.display = display

        # Flags are written as 0/1 integers by commit(); convert them back to booleans.
        for flag in SedimentMaterialsIO._MATERIAL_FLAGS:
            setattr(material, flag, attrs[flag] != 0)
        material.bedload_diffusion = attrs['bedload_diffusion']

        constituents = xr.load_dataset(filename, group=f'/{mat_id}/constituents').to_dataframe()
        material.constituents = self._convert_percent_to_fraction(constituents)
        material.bed_layers = xr.load_dataset(filename, group=f'/{mat_id}/bed_layers').to_dataframe()
        material.consolidation = xr.load_dataset(filename, group=f'/{mat_id}/consolidation').to_dataframe()

        # The sediment properties group may be absent (presumably in files written before it existed).
        with Dataset(filename) as netcdf_file:
            has_sed_props = 'sediment_mat_prop' in netcdf_file.groups[str(mat_id)].groups

        if has_sed_props:
            sed_df = xr.load_dataset(filename, group=f'/{mat_id}/sediment_mat_prop').to_dataframe()
            transport = {}
            # One column per property; each column holds the rows 'id', 'refinement_tolerance' and
            # 'turbulent_diffusion_rate' (see _write_material). The column label itself is not needed here.
            for column in sed_df.to_dict().values():
                new_prop = MaterialTransportProperties()
                new_prop.refinement_tolerance = column['refinement_tolerance']
                new_prop.turbulent_diffusion_rate = column['turbulent_diffusion_rate']
                transport[column['id']] = new_prop
            material.sediment_material_properties = transport

    @staticmethod
    def default_data():
        """Gets the default data for this class.

        Returns:
            A dictionary of default values that will go into the info dataset attrs.
        """
        return {
            # The constructor overwrites this with 'ADH_SED_MAT_DATA'.
            'FILE_TYPE': 'ADH_MAT_DATA',
            'VERSION': importlib.metadata.version("xmsadh"),
            'display_uuid': '',
            'cov_uuid': '',
            'next_comp_id': 0,
            'next_bed_layer_id': 1,
            'next_time_id': 1,
            'sediment_transport_uuid': '',
            'bed_layer_assignment_protocol': 0,
            'use_cohesive_bed_layers': 0,
            'use_consolidation': 0,
            # Method defaults are taken from the combo box option lists so they always match an entry.
            'cohesive_settling_velocity_method': SedimentMaterialsIO.CSV_OPTIONS[0],
            'wind_wave_shear_method': SedimentMaterialsIO.WWS_OPTIONS[0],
            'noncohesive_suspended_method': SedimentMaterialsIO.NSE_OPTIONS[0],
            'noncohesive_bedload_method': SedimentMaterialsIO.NBE_OPTIONS[0],
            'critical_shear_sand': 0.0,
            'critical_shear_clay': 0.0,
            'noncohesive_hiding_method': SedimentMaterialsIO.HID_OPTIONS[0],
            'hiding_factor': 0.5,
            'use_sediment_infiltration_factor': 0,
            'sediment_infiltration_factor': 1.0,
            'a_csv': 0.0,
            'b_csv': 0.0,
            'm_csv': 0.0,
            'n_csv': 0.0
        }

    @property
    def display_uuid(self) -> str:
        """Display UUID getter.

        Returns:
            The display UUID.
        """
        return self.info.attrs["display_uuid"]

    @display_uuid.setter
    def display_uuid(self, value: str):
        """Display UUID setter.

        Args:
            value (str): The new UUID value to be set for the `display_uuid`.
        """
        self.info.attrs["display_uuid"] = value

    @property
    def cov_uuid(self) -> str:
        """Coverage UUID getter.

        Returns:
            The coverage UUID.
        """
        return self.info.attrs["cov_uuid"]

    @cov_uuid.setter
    def cov_uuid(self, value: str):
        """Coverage UUID setter.

        Args:
            value (str): The new UUID value to be set for the `cov_uuid`.
        """
        self.info.attrs["cov_uuid"] = value

    @property
    def sediment_transport_uuid(self) -> str:
        """Sediment transport UUID getter.

        Returns:
            The sediment transport UUID.
        """
        return self.info.attrs["sediment_transport_uuid"]

    @sediment_transport_uuid.setter
    def sediment_transport_uuid(self, value: str):
        """Sediment transport UUID setter.

        Args:
            value: The sediment transport UUID.
        """
        self.info.attrs["sediment_transport_uuid"] = value

    def delete_materials_for_commit(self, deleted_ids):
        """Deletes materials from the main file before committing.

        Deletion is best-effort. Ids with no group in the file are skipped, and the remaining ids are still
        processed.

        Args:
            deleted_ids (list): A list of integer material ids to delete.
        """
        with h5py.File(self.main_file, 'a') as file:
            for mat_id in deleted_ids:
                try:
                    del file[f'/{mat_id}/']
                except KeyError:
                    pass  # Group not in the file; keep cleaning up the other groups.

    def add_material(self):
        """Adds a new material, copied from the unassigned material.

        Returns:
            The new material id (one more than the largest existing id).
        """
        number = max(self.materials, key=int) + 1
        material = copy.deepcopy(self.materials[SedimentMaterialsIO.UNASSIGNED_MAT])
        # The attrs value is a 0/1 integer; store a real bool, matching what the reader produces.
        material.bed_layer_cohesive_override = bool(self.info.attrs['use_cohesive_bed_layers'])
        material.name = f'Sediment Material {number}'
        self.materials[number] = material
        return number

    def commit(self):
        """Save the data to a netCDF file.

        Writes the file-level info through the base class, then rewrites every material's groups.
        """
        self.info.attrs['VERSION'] = importlib.metadata.version("xmsadh")
        super().commit()

        # Remove previously written table groups; they are rewritten from scratch below.
        with h5py.File(self.main_file, 'a') as f:
            for mat_id in self.materials:
                for grp in SedimentMaterialsIO._MATERIAL_TABLE_GROUPS:
                    try:
                        del f[f'/{mat_id}/{grp}']
                    except KeyError:
                        pass  # Not written yet (e.g. a newly added material); clean up the other groups.

        for mat_id, material in self.materials.items():
            self._write_material(mat_id, material)

    def _write_material(self, mat_id, material):
        """Appends one material's info attributes and table groups to the main file.

        Args:
            mat_id (int): The id of the material; used as the top-level group name.
            material (SedimentMaterial): The material to write.
        """
        material_info = xr.Dataset()
        attrs = material_info.attrs
        attrs['material_name'] = material.name
        color = material.display.color
        attrs['red'] = color.red()
        attrs['green'] = color.green()
        attrs['blue'] = color.blue()
        attrs['alpha'] = color.alpha()
        attrs['texture'] = material.display.texture
        # netCDF attributes do not support bool, so flags are written as 0/1 integers.
        for flag in SedimentMaterialsIO._MATERIAL_FLAGS:
            attrs[flag] = 1 if getattr(material, flag) else 0
        attrs['bedload_diffusion'] = material.bedload_diffusion
        material_info.to_netcdf(self.main_file, group=f'/{mat_id}/info', mode='a')

        constituents = material.constituents.astype({'layer_id': 'int64', 'constituent_id': 'int64'})
        constituents.to_xarray().to_netcdf(self.main_file, group=f'/{mat_id}/constituents', mode='a')
        bed_layers = material.bed_layers.astype({'layer_id': 'int64'})
        bed_layers.to_xarray().to_netcdf(self.main_file, group=f'/{mat_id}/bed_layers', mode='a')
        consolidation = material.consolidation.astype({'time_id': 'int64'})
        consolidation.to_xarray().to_netcdf(self.main_file, group=f'/{mat_id}/consolidation', mode='a')

        # One column per transport property, labelled by the property's name; rows hold its id and values.
        # NOTE(review): two properties sharing a name would overwrite each other here. The reader only uses
        # the 'id' row, so the label could be made unique without breaking old files.
        sed_props = {
            prop.name: {
                'id': transport_id,
                'refinement_tolerance': prop.refinement_tolerance,
                'turbulent_diffusion_rate': prop.turbulent_diffusion_rate
            }
            for transport_id, prop in material.sediment_material_properties.items()
        }
        sed_df = pd.DataFrame(sed_props)
        sed_df.to_xarray().to_netcdf(self.main_file, group=f'/{mat_id}/sediment_mat_prop', mode='a')

    def update_from_global(self):
        """Updates the materials to have the same number of bed layers and consolidation times.

        The unassigned material (id 0) serves as the template; every other material is updated from it.
        """
        global_material = self.materials[SedimentMaterialsIO.UNASSIGNED_MAT]
        layer_ids = global_material.bed_layers['layer_id'].values.tolist()
        time_ids = global_material.consolidation['time_id'].values.tolist()
        for mat_id, material in self.materials.items():
            if mat_id == SedimentMaterialsIO.UNASSIGNED_MAT:
                continue
            material.update_consolidation_times(time_ids, global_material.consolidation)
            material.update_bed_layers(layer_ids, global_material.bed_layers)

    def _convert_percent_to_fraction(self, data_frame: pd.DataFrame):
        """Replace a 'percent' column in a constituents table with a 'fraction' column.

        Values are copied unchanged and an existing 'fraction' column is overwritten.
        NOTE(review): no scaling by 100 is applied. Confirm that stored 'percent' values are already fractions.

        Args:
            data_frame: The table read from file; modified in place.

        Returns:
            The same DataFrame.
        """
        if 'percent' in data_frame.columns:
            data_frame['fraction'] = data_frame['percent']
            data_frame.drop(columns=['percent'], inplace=True)
        return data_frame
