"""Methods for reading and writing param objects to disk."""

__copyright__ = "(C) Copyright Aquaveo 2024"
__license__ = "All rights reserved"

# 1. Standard Python modules

# 2. Third party modules
import pandas as pd
import param
import xarray as xr

# 3. Aquaveo modules

# 4. Local modules
from xms.guipy.param import param_util


def write_to_h5_file(filename, param_cls):
    """Write a param class, including any nested param classes, to an h5 file.

    The file is created with mode='w', so an existing file at ``filename`` is replaced.

    Args:
        filename (str): file name
        param_cls (param.Parameterized): param class whose values are written
    """
    # Create the file with an empty 'base/info' group; every later write appends to it.
    xr.Dataset().to_netcdf(filename, group='base/info', mode='w')
    _write_params_recursive(filename, group_name='base/', cls=param_cls)


def read_from_h5_file(filename, param_cls):
    """Load a param class, including any nested param classes, from an h5 file.

    The 'base/info' group must carry a ``file_type`` attribute equal to ``param_cls.file_type``
    before any values are read.

    Args:
        filename (str): file name
        param_cls (param.Parameterized): param class that receives the values read from the file

    Raises:
        IOError: if the file_type attribute is missing or does not match ``param_cls.file_type``.
    """
    # Validate the file header before modifying any params.
    header_attrs = xr.load_dataset(filename, group='base/info').attrs
    if 'file_type' not in header_attrs:
        raise IOError(f'File attributes do not include file_type in file : {filename}')

    stored_type = header_attrs['file_type']
    if stored_type != param_cls.file_type:
        raise IOError(f'Unrecognized file_type "{stored_type}" attribute in file: {filename}')  # noqa: B028

    _read_params_recursive(filename, group_name='base/', cls=param_cls)


def _supported_types():
    """Return the param types whose values are stored as h5 attributes without conversion.

    Returns:
        (list): param classes written to and read from the info attributes as-is.
    """
    return [param.FileSelector, param.Integer, param.Number, param.ObjectSelector, param.parameterized.String]


def _write_params_recursive(filename, group_name, cls):
    """Recursively writes param classes to an h5 file.

    Layout of what gets written:
        - Scalar params are stored as attributes of the ``{group_name}info`` group; Booleans are stored as 1/0.
        - DataFrame values are written to their own ``{group_name}{param name}`` group.
        - Parameterized objects held by ClassSelector params are written recursively under
          ``{group_name}{param name}/``.

    The param named 'name' and params whose value is None are skipped. The info group is always written,
    even when it has no attributes, so every class level in the file has one.

    Args:
        filename (str): file name; groups are added with mode='a', so the file is expected to exist already
            (write_to_h5_file creates it)
        group_name (str): h5 group prefix, ending in '/'
        cls (param.Parameterized): param class

    Raises:
        NotImplementedError: if a param has a type that cannot be written.
    """
    info = xr.Dataset()
    supported_types = _supported_types()

    pnames, params = param_util.names_and_params_from_class(cls)
    for pname, prm in zip(pnames, params):
        if pname == 'name':
            continue
        value = getattr(cls, pname)
        if value is None:
            continue
        # DataFrames are detected by value type (regardless of param type) and stored in their own group.
        if type(value) is pd.DataFrame:
            value.to_xarray().to_netcdf(filename, group=f'{group_name}{pname}', mode='a')
            continue

        obj_type = type(prm)
        if obj_type in supported_types:
            info.attrs[pname] = value
        elif obj_type is param.Boolean:
            # netCDF has no boolean attribute type, so store 1/0.
            info.attrs[pname] = 1 if value else 0
        elif obj_type is param.ClassSelector:
            # Only nested Parameterized objects are serialized; other ClassSelector values are skipped.
            if issubclass(type(value), param.Parameterized):
                _write_params_recursive(filename, f'{group_name}{pname}/', value)
        else:
            raise NotImplementedError

    # Always create the info group (even with no attributes) so it can be loaded for every class level.
    info.to_netcdf(filename, group=f'{group_name}info', mode='a')


def _try_load_dataset(filename, group):
    """Load an h5 group as an xarray Dataset, or return None if the group cannot be opened.

    Args:
        filename (str): file name
        group (str): h5 group path

    Returns:
        (xr.Dataset | None): the loaded dataset, or None when the group is missing.
    """
    try:
        return xr.load_dataset(filename, group=group)
    except OSError:
        # xarray raises OSError ("group not found") when a group is absent from the file. Absent groups
        # occur for DataFrames that were None at write time and for info groups that were never created,
        # and mean "nothing stored" rather than a corrupt file.
        return None


def _read_params_recursive(filename, group_name, cls):  # noqa: C901
    """Recursively reads param classes from an h5 file.

    Where each kind of param is restored from:
        - Scalar params: attributes of the ``{group_name}info`` group.
        - param.DataFrame params: the ``{group_name}{param name}`` group.
        - Parameterized objects held by ClassSelector params: recursively from ``{group_name}{param name}/``.

    Missing groups are skipped and the affected params keep their current values. Examples are a
    DataFrame that was None when the file was written, or an info group that was never created.

    Args:
        filename (str): file name
        group_name (str): h5 group prefix, ending in '/'
        cls (param.Parameterized): param class

    Raises:
        NotImplementedError: if a param found in the info attributes has an unsupported type.
    """
    info = _try_load_dataset(filename, f'{group_name}info')
    attrs = info.attrs if info is not None else {}

    pnames, params = param_util.names_and_params_from_class(cls)
    if attrs:
        supported_types = _supported_types()
        # do this twice so that any watcher that may change a value of another param
        # won't be fired the second time and the other param value won't change
        for _ in range(2):
            for pname, prm in zip(pnames, params):
                if pname == 'name' or pname not in attrs:
                    continue
                value = getattr(cls, pname)
                # for some reason call getattr for a param.DataFrame messes up the data frame
                if type(value) is pd.DataFrame:
                    ds = _try_load_dataset(filename, f'{group_name}{pname}')
                    if ds is not None:
                        setattr(cls, pname, ds.to_dataframe())
                    continue

                obj_type = type(prm)
                if obj_type in supported_types:
                    setattr(cls, pname, attrs[pname])
                elif obj_type is param.Boolean:
                    # Booleans are stored as integers; any non-zero value means True.
                    setattr(cls, pname, bool(attrs[pname] != 0))
                else:
                    raise NotImplementedError

    for pname, prm in zip(pnames, params):
        prm_type = type(prm)
        if prm_type is param.ClassSelector:
            val = getattr(cls, pname)
            if issubclass(type(val), param.Parameterized):
                _read_params_recursive(filename, f'{group_name}{pname}/', val)
        elif prm_type is param.DataFrame:
            ds = _try_load_dataset(filename, f'{group_name}{pname}')
            if ds is not None:
                setattr(cls, pname, ds.to_dataframe())
