"""This module provides ways to read and write param values to h5."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
from dataclasses import fields, is_dataclass
from pathlib import Path
from typing import Any, get_type_hints

# 2. Third party modules
import h5py
import pandas as pd
import param
import xarray as xr

# 3. Aquaveo modules

# 4. Local modules
from xms.adh.data import param_util


def read_from_h5_file(filename, param_cls, file_type):
    """Populate a param class from an H5 file after checking the file's type.

    Args:
        filename (str): file name
        param_cls (param.Parameterized): param class to populate
        file_type (str): The expected ``FILE_TYPE`` attribute of the file's ``info`` group.

    Raises:
        IOError: If the ``info`` group lacks ``FILE_TYPE`` or its value differs from ``file_type``.
    """
    header_attrs = xr.load_dataset(filename, group='info').attrs
    if 'FILE_TYPE' not in header_attrs:
        raise IOError(f'File attributes do not include file_type in file : {filename}')
    found_type = header_attrs['FILE_TYPE']
    if found_type != file_type:
        raise IOError(f'Unrecognized file_type "{found_type}" attribute in file: {filename}')
    read_params_recursive(filename, group_name='/', param_class=param_cls)


def write_params_recursive(filename, group_name, param_class):
    """Recursively writes param classes to an H5 file.

    The function writes each kind of param as follows:
    - Scalar params are stored as attributes of the ``<group_name>info`` group.
    - DataFrame values are written to ``<group_name><param_name>`` groups.
    - ClassSelector params whose value is a Parameterized object are written recursively under
      ``<group_name><param_name>/``.

    The ``name`` param and params whose value is None are skipped.

    Args:
        filename (str): file name
        group_name (str): h5 group prefix, ending in '/' (e.g. '/')
        param_class (param.Parameterized): param class

    Raises:
        NotImplementedError: If a param has a type this writer does not handle.
    """
    info = xr.Dataset()  # scalar params are collected here as attributes

    param_names, params = param_util.names_and_params_from_class(param_class)
    supported_types = (
        param.FileSelector, param.Integer, param.Number, param.ObjectSelector, param.parameterized.String
    )
    for param_name, param_obj in zip(param_names, params, strict=True):
        if param_name == 'name':
            continue
        value = getattr(param_class, param_name)
        if value is None:
            continue
        # The DataFrame check uses the value rather than the param object. The original author noted
        # that calling getattr for a param.DataFrame messes up the data frame.
        if type(value) is pd.DataFrame:
            write_dataframe_to_h5(filename, f'{group_name}{param_name}', value)
            continue

        param_type = type(param_obj)
        if param_type in supported_types:
            info.attrs[param_name] = value
        elif param_type is param.Boolean:
            # Stored as 1/0 rather than a bool attribute.
            info.attrs[param_name] = 1 if value else 0
        elif param_type is param.ClassSelector:
            # Only Parameterized values are written (as a nested group); any other value is skipped.
            if issubclass(type(value), param.Parameterized):
                write_params_recursive(filename, f'{group_name}{param_name}/', value)
        else:
            raise NotImplementedError(f'Unsupported param type {param_type.__name__} for "{param_name}"')

    info.to_netcdf(filename, group=f'{group_name}info', mode='a')


def write_dataframe_to_h5(filename: str, group_path: str, value: pd.DataFrame):
    """
    Writes a pandas DataFrame to an HDF5 file as a NetCDF dataset inside a specified group.

    Any object already stored at ``group_path`` is deleted first, so the new frame replaces it. The
    frame's own index (values and name) is discarded. The stored dataset always uses a dimension named
    ``index`` with coordinates 0..n-1.

    Args:
        filename: The path to the HDF5 file to write to.
        group_path: The path of the group in the HDF5 file to store the data.
        value: The pandas DataFrame to write to the HDF5 file.
    """
    with h5py.File(filename, 'a') as f:
        # Only delete what exists. A blanket `except Exception: pass` would also hide real failures.
        if group_path in f:
            del f[group_path]
    # to_xarray() names the dimension after the index. Dropping the index guarantees the 'index'
    # dimension with 0..n-1 coordinates even when the frame has a named index or a MultiIndex.
    data_set = value.reset_index(drop=True).to_xarray()
    data_set.to_netcdf(filename, group=group_path, mode='a')


def read_params_recursive(filename: str | Path, group_name: str, param_class: param.Parameterized):
    """Populate a param class, and any nested param classes, from an H5 file.

    The function reads each kind of param as follows:
    - Scalar params come from the attributes of the ``<group_name>info`` group.
    - Params whose current value is a DataFrame are loaded from the ``<group_name><param_name>`` group.
    - ClassSelector params whose value is a Parameterized object are read recursively from
      ``<group_name><param_name>/``.

    Args:
        filename: file name
        group_name: h5 group prefix, ending in '/' (e.g. '/')
        param_class: param class to populate in place

    Raises:
        NotImplementedError: If an attribute is stored for a param whose type is not handled.
    """
    attrs = xr.load_dataset(filename, group=f'{group_name}info').attrs
    names, param_objs = param_util.names_and_params_from_class(param_class)

    if attrs is not None:
        plain_types = [param.FileSelector, param.Number, param.ObjectSelector, param.parameterized.String]
        # Assign everything twice. A watcher fired during the first pass may change another param. On
        # the second pass the values are already set, so watchers do not fire again and the stored
        # values are kept.
        for _pass in range(2):
            for idx, pname in enumerate(names):
                if pname == 'name':
                    continue
                current = getattr(param_class, pname)
                # for some reason, call getattr for a param.DataFrame messes up the data frame
                if type(current) is pd.DataFrame:
                    frame = read_data_frame(filename, f'{group_name}{pname}')
                    if frame is not None:
                        setattr(param_class, pname, frame)
                    continue
                if pname not in attrs:
                    continue

                stored = attrs[pname]
                kind = type(param_objs[idx])
                if kind in plain_types:
                    new_value = stored
                elif kind == param.Integer:
                    new_value = int(stored)
                elif kind == param.Boolean:
                    # Booleans are stored numerically; anything non-zero is True.
                    new_value = bool(stored != 0)
                else:
                    raise NotImplementedError
                setattr(param_class, pname, new_value)

    # Nested param classes live in their own sub-groups.
    for idx, p_obj in enumerate(param_objs):
        pname = names[idx]
        if type(p_obj) is not param.ClassSelector:
            continue
        child = getattr(param_class, pname)
        if issubclass(type(child), param.Parameterized):
            read_params_recursive(filename, f'{group_name}{pname}/', child)


def read_data_frame(filename: str | Path, group_path: str) -> pd.DataFrame | None:
    """
    Load the dataset stored in a group of an H5 file as a pandas DataFrame.

    Args:
        filename (str | Path): The path to the file containing the dataset.
        group_path (str): The group in the file that holds the dataset.

    Returns:
        pd.DataFrame | None: The dataset converted with ``to_dataframe()``, or None when loading raises
        OSError (for example when the file or group does not exist).
    """
    try:
        dataset = xr.load_dataset(str(filename), group=group_path)
        if dataset is None:
            return None
        return dataset.to_dataframe()
    except OSError:
        return None


def _resolve_field_types(dataclass_obj: Any) -> dict:
    """Return evaluated type hints for a dataclass type or instance, or {} if they cannot be resolved."""
    cls = dataclass_obj if isinstance(dataclass_obj, type) else type(dataclass_obj)
    try:
        return get_type_hints(cls)
    except (NameError, TypeError):
        return {}


def read_dataclass_recursive(filename: str, group_name: str, dataclass_obj: Any):
    """
    Read a dataclass values from an HDF5 file.

    This reads the layout produced by ``write_dataclass_recursive``:
    - Scalar fields (int, float, str, bool) come from the attributes of the ``<group_name>info`` group.
    - Fields whose current value is a DataFrame come from the ``<group_name><field_name>`` group.
    - Nested dataclass fields are read recursively from ``<group_name><field_name>/``.

    Fields with no stored data keep their current value.

    Args:
        filename: Path to the HDF5 file.
        group_name: The group name in the HDF5 file to read data from (ending in '/', e.g. '/').
        dataclass_obj: An instance of a dataclass to populate.

    Returns:
        The populated dataclass object.

    Raises:
        NotImplementedError: If an attribute is stored for a field whose type is not supported.
    """
    info = xr.load_dataset(filename, group=f"{group_name}info")
    resolved_types = _resolve_field_types(dataclass_obj)

    for field in fields(dataclass_obj):
        field_name = field.name
        # NOTE(review): carried over from the param reader. A dataclass field called 'name' is never read,
        # although write_dataclass_recursive writes it. Confirm this is intended.
        if field_name == 'name':
            continue

        field_type = field.type
        if isinstance(field_type, str):
            # Annotations are plain strings under `from __future__ import annotations`; evaluate them.
            field_type = resolved_types.get(field_name, field_type)

        value = getattr(dataclass_obj, field_name)

        # DataFrames are stored in their own group, not as attributes.
        if type(value) is pd.DataFrame:
            data_frame = read_data_frame(filename, f'{group_name}{field_name}')
            if data_frame is not None:
                setattr(dataclass_obj, field_name, data_frame)
            continue

        # Nested dataclasses are stored in their own sub-group and never appear in info.attrs. They must
        # therefore be handled before the attribute lookup below.
        if isinstance(field_type, type) and is_dataclass(field_type):
            nested_cls = field_type
        elif is_dataclass(value) and not isinstance(value, type):
            nested_cls = type(value)
        else:
            nested_cls = None
        if nested_cls is not None:
            nested_obj = nested_cls()
            try:
                read_dataclass_recursive(filename, f"{group_name}{field_name}/", nested_obj)
            except OSError:
                # The nested group is absent (e.g. the field was None when written); keep the current value.
                continue
            setattr(dataclass_obj, field_name, nested_obj)
            continue

        if field_name not in info.attrs:
            continue  # nothing stored for this field
        stored = info.attrs[field_name]

        if field_type == bool:
            # Accept both 1/0 and True/False stored values; anything non-zero is True.
            setattr(dataclass_obj, field_name, bool(stored != 0))
        elif field_type == int:
            setattr(dataclass_obj, field_name, int(stored))
        elif field_type == float:
            setattr(dataclass_obj, field_name, float(stored))
        elif field_type == str:
            setattr(dataclass_obj, field_name, str(stored))
        else:
            raise NotImplementedError(f"Unsupported type: {field_type}")

    return dataclass_obj


def read_dataclass_from_h5_file(filename: str, file_type: str, dataclass_obj: Any):
    """
    Read dataclass fields from an HDF5 file after validating the file's type.

    Args:
        filename: Path to the HDF5 file.
        file_type (str): The expected ``FILE_TYPE`` attribute of the file's ``info`` group.
        dataclass_obj: The dataclass object to populate.

    Returns:
        The populated dataclass object.

    Raises:
        IOError: If the ``info`` group lacks ``FILE_TYPE`` or its value differs from ``file_type``.
    """
    # check file type and version
    info = xr.load_dataset(filename, group='info')
    if 'FILE_TYPE' not in info.attrs:
        raise IOError(f'File attributes do not include file_type in file : {filename}')
    ftype = info.attrs['FILE_TYPE']
    if ftype != file_type:
        raise IOError(f'Unrecognized file_type "{ftype}" attribute in file: {filename}')

    # Start at the file root; the top-level fields live in the '/info' group.
    # The keyword must be `group_name` (the reader's parameter name); `group_path` raised TypeError.
    return read_dataclass_recursive(filename, group_name='/', dataclass_obj=dataclass_obj)


def write_dataclass_recursive(filename: str, group_name: str, dataclass_obj):
    """
    Recursively writes dataclass fields to an HDF5 file.

    The function writes each kind of field as follows:
    - Scalar fields (int, float, str, bool) become attributes of the ``<group_name>info`` group. Bools
      are stored as 1/0.
    - DataFrame fields are written to ``<group_name><field_name>`` groups.
    - Nested dataclasses are written recursively under ``<group_name><field_name>/``.

    Fields whose value is None are skipped.

    Args:
        filename (str): Path to the HDF5 file.
        group_name (str): The group name within the HDF5 file for this dataclass (ending in '/', e.g. '/').
        dataclass_obj: The dataclass object to write.

    Raises:
        ValueError: If ``dataclass_obj`` is not a dataclass.
        NotImplementedError: For unsupported field types.
    """
    if not is_dataclass(dataclass_obj):
        raise ValueError(f"The object provided is not a dataclass: {dataclass_obj}")

    info = xr.Dataset()  # holds the scalar fields as attributes

    for field in fields(dataclass_obj):
        field_name = field.name
        value = getattr(dataclass_obj, field_name)

        if value is None:
            continue  # nothing is written for None values

        if isinstance(value, bool):
            # netCDF has no boolean attribute type. Store 1/0, as write_params_recursive does for
            # param.Boolean. This check must come before the int check, because bool is a subclass of int.
            info.attrs[field_name] = 1 if value else 0
        elif isinstance(value, (int, float, str)):
            info.attrs[field_name] = value
        elif isinstance(value, pd.DataFrame):
            write_dataframe_to_h5(filename, f"{group_name}{field_name}", value)
        elif is_dataclass(value):
            write_dataclass_recursive(filename, f"{group_name}{field_name}/", value)
        else:
            raise NotImplementedError(f"Unsupported field type: {field_name} ({type(value)})")

    info.to_netcdf(filename, group=f"{group_name}info", mode='a')
