"""This handles changes in the hydrodynamic materials."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
import importlib.metadata
import os

# 2. Third party modules
from adhparam.material_properties import MaterialProperties
from adhparam.material_transport_properties import MaterialTransportProperties
from adhparam.time_series import TimeSeries
import h5py
from PySide2.QtGui import QColor
import xarray as xr

# 3. Aquaveo modules
from xms.components.bases.xarray_base import XarrayBase
from xms.guipy.data.polygon_texture import PolygonOptions

# 4. Local modules
from xms.adh.data.json_export import filtered_info_attrs, make_json_serializable
from xms.adh.data.materials import Materials
from xms.adh.data.param_h5_io import read_params_recursive, write_params_recursive


class MaterialsIO(XarrayBase):
    """A class that handles saving and loading material data.

    Materials are stored in the main h5 file with one group per integer material id, next to the non-material
    top-level groups 'info', 'material_friction' and 'series'.
    """

    UNASSIGNED_MAT = 0  # The OFF material id (and row in material list)

    # Top-level groups in the main file that do not hold a material.
    _NON_MATERIAL_GROUPS = frozenset({'info', 'material_friction', 'series'})

    def __init__(self, main_file):
        """Initializes the data class.

        Args:
            main_file: The main file associated with this component.
        """
        super().__init__(main_file)
        self.main_file = main_file
        attrs = self.info.attrs
        defaults = self.default_data()
        file_exists = os.path.isfile(main_file)
        if not file_exists:
            # No existing file: start from a fresh info dataset holding the defaults, and keep working on
            # that dataset's attrs (not the stale ones fetched above).
            self._info = xr.Dataset(attrs=defaults)
            attrs = self._info.attrs
        for key, value in defaults.items():
            if key not in attrs:
                attrs[key] = value
        attrs['FILE_TYPE'] = 'ADH_MAT_DATA'
        self.materials = Materials()
        if file_exists:
            self.read_from_h5_file(main_file, 'ADH_MAT_DATA')

    def get_material_ids(self, filename):
        """Gets the names of the top-level groups in an h5 file.

        Material groups are named by their integer id, but the returned list also contains the non-material
        groups (such as 'info', 'material_friction' and 'series') that callers must filter out.

        Args:
            filename (str): The h5 file to inspect.

        Returns:
            list[str]: The top-level group names.
        """
        with h5py.File(filename, 'r') as f:
            return list(f.keys())

    def read_from_h5_file(self, filename, file_type):
        """Reads this class from an h5 file.

        Args:
            filename (str): file name
            file_type (str): A description of what type of file we expect this to be. Should match text in the file.

        Raises:
            IOError: If the 'info' group has no FILE_TYPE attribute or it does not match ``file_type``.
        """
        # Validate the file type before reading any material data.
        info = xr.load_dataset(filename, group='info')
        if 'FILE_TYPE' not in info.attrs:
            raise IOError(f'File attributes do not include file_type in file : {filename}')
        ftype = info.attrs['FILE_TYPE']
        if ftype != file_type:
            raise IOError(f'Unrecognized file_type "{ftype}" attribute in file: {filename}')

        # load_dataset reads into memory and closes the file; open_dataset would keep a handle on the main file
        # open, which may interfere with the later appends done by commit().
        self.materials.friction = xr.load_dataset(filename, group='/material_friction').to_dataframe()

        for id_key in self.get_material_ids(filename):
            if id_key in self._NON_MATERIAL_GROUPS:
                continue
            mat_id = int(id_key)
            material = MaterialProperties()
            self.materials.material_properties[mat_id] = material
            read_params_recursive(filename, group_name=f'/{id_key}/', param_class=material)
            self.read_name_and_display_from_h5(filename, mat_id)

        self._read_transport_constituents(filename)

        with h5py.File(filename, 'r') as file:
            series_ids = [int(group) for group in file['/series'].keys()] if '/series' in file else []
        for series_id in series_ids:
            series = TimeSeries()
            self.materials.time_series[series_id] = series
            read_params_recursive(filename, group_name=f'/series/{series_id}/', param_class=series)

    def _read_transport_constituents(self, filename):
        """Reads transport constituents for materials from an H5 file.

        Material ids should be read before this is called.

        Args:
            filename (str): file name
        """
        with h5py.File(filename, 'r') as file:
            for mat_id, material in self.materials.material_properties.items():
                trans_grp_name = f'/{mat_id}/transport_properties/'
                if trans_grp_name not in file:
                    continue
                # Use each material's own constituent list instead of assuming every material matches the first.
                constituent_ids = [int(key) for key in file[trans_grp_name].keys()]
                for constituent_id in constituent_ids:
                    con_info = xr.load_dataset(filename, group=f'{trans_grp_name}{constituent_id}/info')
                    transport = MaterialTransportProperties()
                    transport.refinement_tolerance = con_info.attrs['refinement_tolerance']
                    transport.turbulent_diffusion_rate = con_info.attrs['turbulent_diffusion_rate']
                    transport.use_diffusion_coefficient = bool(con_info.attrs.get('use_diffusion_coefficient', False))
                    material.transport_properties[constituent_id] = transport

    def read_name_and_display_from_h5(self, filename, mat_id):
        """Reads material names and display attributes.

        Args:
            filename (str): The main file.
            mat_id (int): The material id.
        """
        attrs = xr.load_dataset(filename, group=f'/{mat_id}/info').attrs
        polygon_display = PolygonOptions()
        polygon_display.color = QColor(attrs['red'], attrs['green'], attrs['blue'], attrs['alpha'])
        polygon_display.texture = attrs['texture']
        self.materials.material_display[mat_id] = polygon_display

        # commit() writes the flag as 'friction_use_seasonal'; 'use_seasonal' is still honored in case a file
        # stores it under that name (presumably files from an earlier version - TODO confirm).
        use_seasonal = attrs.get('friction_use_seasonal', attrs.get('use_seasonal'))
        if use_seasonal is not None:
            self.materials.friction_use_seasonal[mat_id] = bool(use_seasonal != 0)
            self.materials.friction_seasonal_curve[mat_id] = int(attrs.get('seasonal_curve', 0))
        self.materials.material_use_meteorological[mat_id] = bool(attrs['use_meteorological'] != 0)
        self.materials.material_meteorological_curve[mat_id] = int(attrs['meterological_curve'])

    def default_data(self):
        """Gets the default data for this class.

        Returns:
            A dictionary of default values that will go into the info dataset attrs.
        """
        version = importlib.metadata.version('xmsadh')
        return {
            'FILE_TYPE': 'ADH_MAT_DATA',
            'VERSION': version,
            'display_uuid': '',
            'cov_uuid': '',
            'next_comp_id': 0,
            'transport_uuid': '',
            'use_transport': 0
        }

    @property
    def display_uuid(self) -> str:
        """Display UUID getter.

        Returns:
            The display UUID.
        """
        return self.info.attrs["display_uuid"]

    @display_uuid.setter
    def display_uuid(self, value: str):
        """Display UUID setter.

        Args:
            value (str): The new UUID value to be set for the `display_uuid`.
        """
        self.info.attrs["display_uuid"] = value

    @property
    def cov_uuid(self) -> str:
        """Coverage UUID getter.

        Returns:
            The coverage UUID.
        """
        return self.info.attrs["cov_uuid"]

    @cov_uuid.setter
    def cov_uuid(self, value: str):
        """Coverage UUID setter.

        Args:
            value (str): The new UUID value to be set for the `cov_uuid`.
        """
        self.info.attrs["cov_uuid"] = value

    @property
    def transport_uuid(self) -> str:
        """Transport UUID getter.

        Returns:
            The transport UUID.
        """
        return self.info.attrs["transport_uuid"]

    @transport_uuid.setter
    def transport_uuid(self, value: str):
        """Transport UUID setter.

        Args:
            value: The transport UUID.
        """
        self.info.attrs["transport_uuid"] = value

    @property
    def use_transport(self) -> int:
        """Use transport getter.

        Returns:
            If transport is used.
        """
        return self.info.attrs["use_transport"]

    @use_transport.setter
    def use_transport(self, value: int):
        """Use transport setter.

        Args:
            value: If transport is used (0 or 1).
        """
        self.info.attrs["use_transport"] = value

    def delete_materials_for_commit(self, deleted_ids):
        """Deletes materials from the main file prior to a commit.

        Deletion is best-effort per id: an id whose group is absent is skipped and the remaining ids are
        still deleted.

        Args:
            deleted_ids (list): A list of int material ids to delete from the file.
        """
        with h5py.File(self.main_file, 'a') as file:
            for mat_id in deleted_ids:
                try:
                    del file[f'/{mat_id}/']
                except KeyError:
                    pass  # h5py raises KeyError for a missing group; keep cleaning up the other groups

    def _delete_h5_group(self, group_name):
        """Removes a group from the main file if it exists (best-effort).

        Args:
            group_name (str): Absolute h5 path of the group to remove.
        """
        with h5py.File(self.main_file, 'a') as f:
            try:
                del f[group_name]
            except KeyError:
                pass  # Nothing to remove yet

    def _material_info_dataset(self, mat_id):
        """Builds the attribute-only dataset written to a material's 'info' group.

        Args:
            mat_id (int): The material id.

        Returns:
            xr.Dataset: Dataset whose attrs hold the display color/texture and the seasonal/meteorological settings.
        """
        display = self.materials.material_display[mat_id]
        color = display.color
        mat_info = xr.Dataset()
        mat_info.attrs['red'] = color.red()
        mat_info.attrs['green'] = color.green()
        mat_info.attrs['blue'] = color.blue()
        mat_info.attrs['alpha'] = color.alpha()
        mat_info.attrs['texture'] = int(display.texture)
        mat_info.attrs['friction_use_seasonal'] = int(self.materials.friction_use_seasonal.get(mat_id, 0))
        mat_info.attrs['seasonal_curve'] = int(self.materials.friction_seasonal_curve.get(mat_id, 0))
        mat_info.attrs['use_meteorological'] = 1 if self.materials.material_use_meteorological[mat_id] else 0
        mat_info.attrs['meterological_curve'] = int(self.materials.material_meteorological_curve[mat_id])
        return mat_info

    def commit(self):
        """Stores material data in the main file."""
        self.info.attrs['VERSION'] = importlib.metadata.version('xmsadh')
        super().commit()
        for mat_id, material in self.materials.material_properties.items():
            write_params_recursive(self.main_file, f'/{mat_id}/', material)
            # Drop the whole transport group and rewrite it, so constituents no longer on the material are removed.
            self._delete_h5_group(f'/{mat_id}/transport_properties')
            for constituent_id, transport in material.transport_properties.items():
                write_params_recursive(self.main_file, f'/{mat_id}/transport_properties/{constituent_id}/', transport)
            self._material_info_dataset(mat_id).to_netcdf(self.main_file, group=f'/{mat_id}/info', mode='a')

        mat_friction_xarray = self.materials.friction.to_xarray()
        self._delete_h5_group('/material_friction')
        # Replace the 'index' coordinate with a plain 0..n-1 range before writing.
        idx_list = list(range(len(mat_friction_xarray.index)))
        mat_friction_xarray = mat_friction_xarray.assign_coords({'index': idx_list})
        mat_friction_xarray.to_netcdf(self.main_file, group='/material_friction', mode='a')

        for series_id, series in self.materials.time_series.items():
            write_params_recursive(self.main_file, f'/series/{series_id}/', series)

    def as_dict(self) -> dict:
        """Converts the material data to a dictionary.

        Returns:
            A dictionary representation of the material data.
        """
        result = {
            'info': filtered_info_attrs(self.info.attrs),
            'materials': self.materials.as_dict()
        }
        return make_json_serializable(result)
