"""Utility script for converting old SMS spectral H5 files."""

# 1. Standard Python modules
import argparse
import os
import traceback

# 2. Third party modules
import h5py
import numpy as np

# 3. Aquaveo modules

# 4. Local modules

# Process exit codes returned by convert_spectral_h5_file.
SUCCESS_EXIT_CODE: int = 0  # The conversion completed.
OOPSY_EXIT_CODE: int = -1  # The conversion raised an exception.


def convert_spectral_h5_file(arg_str=None):
    """Convert an old SMS spectral H5 file so it has a single dataset per grid.

    Args:
        arg_str (list): Command strings, parsed positionally as ``command input_file output_file``. When None,
            argparse falls back to ``sys.argv[1:]``.

    Returns:
        (int): SUCCESS_EXIT_CODE if the conversion completed, OOPSY_EXIT_CODE if it raised an exception. Invalid
            arguments make argparse raise SystemExit, which is not caught here.
    """
    try:
        arguments = argparse.ArgumentParser(description='Converter for old SMS spectral H5 files.')
        arguments.add_argument(dest='command', type=str, help='script to run')
        arguments.add_argument(dest='input_file', type=str, help='Path to old file to convert.')
        arguments.add_argument(dest='output_file', type=str, help='Path to where the converted file will be written.')
        parsed_args = arguments.parse_args(arg_str)

        # Convert the old file so there is a single dataset per grid. Used to have a dataset per timestep.
        # Surrounding double quotes (e.g. from a quoted path argument) are stripped from both paths.
        converter = SpectralH5Converter(parsed_args.input_file.strip('"'), parsed_args.output_file.strip('"'))
        converter.convert()
    except Exception:  # Give SMS a non-zero exit code, but keep the failure reason visible instead of discarding it.
        traceback.print_exc()
        return OOPSY_EXIT_CODE
    return SUCCESS_EXIT_CODE


class SpectralH5Converter:
    """Class to handle conversion of an old SMS spectral H5 file.

    The old files hold one XMDF dataset per timestep under each 'Datasets' group. The conversion copies every other
    group and dataset (with its attributes) as-is, and collapses the children of each 'Datasets' group into a single
    dataset whose 'Values' and 'Times' are the per-timestep arrays stacked together in ascending time order.
    """

    # 'Grouptype' attribute value that marks an XMDF multi-dataset group.
    _MULTI_DATASETS = 'MULTI DATASETS'

    def __init__(self, input_file, output_file):
        """Constructor.

        Args:
            input_file: The input file path of the spectral data.
            output_file: The output file path for the converted spectral data.
        """
        self.input_file = input_file
        self.output_file = output_file
        self.fin = None  # Input h5py.File while convert() runs.
        self.fout = None  # Output h5py.File while convert() runs.
        self._reset()

    def _reset(self):
        """Clear the state accumulated for the 'Datasets' group currently being collapsed."""
        self.first_node = None  # First timestep group stored; names the output dataset and supplies its attributes.
        self.first_properties = None  # First subgroup seen while storing timesteps; written out as 'PROPERTIES'.
        self.values_stack = []  # One 'Values' array per stored timestep group.
        self.times_stack = []  # One 'Times' array per stored timestep group.
        self.stacked_dsets = set()  # Not read anywhere in this class; kept so existing attribute access still works.

    @staticmethod
    def _stack_timesteps(values_parts, times_parts):
        """Combine per-timestep arrays into single arrays sorted by time.

        Args:
            values_parts (list[np.ndarray]): 'Values' arrays, expected to hold one row per time.
            times_parts (list[np.ndarray]): 'Times' arrays corresponding to ``values_parts``.

        Returns:
            tuple[np.ndarray, np.ndarray]: ``(values, times)``; values is 2D with one row per time, times is 1D, and
            both are ordered by ascending time.

        Raises:
            ValueError: If either list is empty, or the number of value rows differs from the number of times.
        """
        if not values_parts or not times_parts:
            raise ValueError('No Values/Times arrays were found to stack.')
        # Stack once rather than growing an array per timestep, which re-copies everything on every step.
        values = np.vstack(values_parts)  # always 2D, even for a single 1D part
        times = np.concatenate([np.ravel(part) for part in times_parts])
        if values.shape[0] != times.shape[0]:
            raise ValueError(f'Mismatched spectral data: {values.shape[0]} value rows but {times.shape[0]} times.')
        # A stable sort keeps timesteps that share a time in their original order.
        order = np.argsort(times, kind='stable')
        return values[order], times[order]

    def _create_dataset(self):
        """Write the stored timesteps as a single XMDF dataset, then reset the accumulation state.

        Does nothing if no timestep group has been stored.

        Raises:
            ValueError: If the stored 'Values' and 'Times' arrays are missing or do not line up.
        """
        if self.first_node is None:
            return
        # Validate and combine the data first, so a bad group fails before any of its output is written.
        values, times = self._stack_timesteps(self.values_stack, self.times_stack)

        # Recreate the multi-dataset group that held the per-timestep datasets.
        multi_dsets = self.fout.create_group(self.first_node.parent.name)
        self._copy_attrs(self.first_node.parent, multi_dsets)
        # first_node.name is an absolute path, so the new dataset takes the first timestep's name and location.
        dset_grp = multi_dsets.create_group(self.first_node.name)
        self._copy_attrs(self.first_node, dset_grp)
        if self.first_properties is not None:
            # Group.copy also brings over nested members and the attributes of every member.
            dset_grp.copy(self.first_properties, dset_grp, name='PROPERTIES')
        dset_grp.create_dataset(name='Values', data=values)
        dset_grp.create_dataset(name='Times', data=times)
        # Per-time extrema, taken across each row of values.
        dset_grp.create_dataset(name='Maxs', data=np.amax(values, axis=1))
        dset_grp.create_dataset(name='Mins', data=np.amin(values, axis=1))
        self._reset()

    def _store_dataset(self, node):
        """Store one timestep's dataset group so _create_dataset can write it later.

        Remembers the first node stored and the first subgroup seen (the properties), and appends the node's
        'Values' and 'Times' arrays to the accumulation lists. Other members are ignored; 'Maxs'/'Mins' are
        recomputed when the combined dataset is written.

        Args:
            node (h5py.Group): A timestep dataset group (a child of a 'Datasets' group).
        """
        if self.first_node is None:
            self.first_node = node
        for name, obj in node.items():
            if isinstance(obj, h5py.Group):
                # Keep the properties of the first dataset only.
                if self.first_properties is None:
                    self.first_properties = obj
            elif name == 'Values':
                self.values_stack.append(obj[...])
            elif name == 'Times':
                self.times_stack.append(obj[...])

    def _copy_attrs(self, src_node, dest_node):
        """Copy attributes from source node to destination node.

        Args:
            src_node (h5py.Dataset or h5py.Group): The source node from which attributes will be copied.
            dest_node (h5py.Dataset or h5py.Group): The destination node where attributes will be copied to.
        """
        for name, value in src_node.attrs.items():
            dest_node.attrs[name] = value

    @staticmethod
    def _group_type(node):
        """Return a node's 'Grouptype' attribute as text.

        Depending on how the string was stored (and the h5py version), the attribute can come back as bytes,
        ``np.bytes_``, ``str``, or a one-element array of one of those; all are normalized here.

        Args:
            node: An object with an ``attrs`` mapping (an h5py Group or Dataset).

        Returns:
            str: The attribute value, or '' when it is absent or not a single string.
        """
        value = node.attrs.get('Grouptype', '')
        if isinstance(value, np.ndarray):
            value = value.item() if value.size == 1 else ''
        if isinstance(value, bytes):
            value = value.decode('utf-8', errors='replace')
        return value if isinstance(value, str) else ''

    def _check_if_xmdf_node(self, node):
        """Check whether the node's parent, grandparent, or great-grandparent is a multi-dataset group.

        Those three levels cover an XMDF dataset group, its direct members, and the members of its properties
        subgroup. All of them are handled through _store_dataset/_create_dataset rather than copied directly.

        Args:
            node: An h5py object in the input file.

        Returns:
            bool: True if any of the three ancestors has Grouptype 'MULTI DATASETS'.
        """
        parent = node.parent
        grandparent = parent.parent  # h5py's root group is its own parent, so this chain never runs out.
        great_grandparent = grandparent.parent
        return any(
            self._group_type(ancestor) == self._MULTI_DATASETS
            for ancestor in (parent, grandparent, great_grandparent)
        )

    def _walk_file(self, name, node):
        """Callback for h5py ``visititems`` that copies or converts one input node.

        Always returns None, because ``visititems`` stops the traversal at the first non-None return value.

        Args:
            name (str): The node's path relative to the input file root.
            node: The current h5py Group or Dataset.
        """
        if isinstance(node, h5py.Group) and name.rsplit('/', 1)[-1] == 'Datasets':
            # Collapse the per-timestep datasets under this group into one dataset.
            for timestep in node.values():
                self._store_dataset(timestep)
            self._create_dataset()
            return None
        if self._check_if_xmdf_node(node):
            # Members of a multi-dataset group are consumed by the 'Datasets' branch above.
            # NOTE(review): a MULTI DATASETS group not named 'Datasets' would have its members skipped without
            # conversion - confirm old files always use 'Datasets'.
            return None
        if isinstance(node, h5py.Group):
            grp = self.fout.create_group(name)
            self._copy_attrs(src_node=node, dest_node=grp)
        else:  # h5py.Dataset: copy its data directly.
            dset = self.fout.create_dataset(name=name, dtype=node.dtype, shape=node.shape, data=node[...])
            self._copy_attrs(src_node=node, dest_node=dset)
        return None

    def convert(self):
        """Converts a spectral H5 file to a new format."""
        self._reset()  # Discard anything left over from an earlier, failed conversion.
        with h5py.File(self.input_file, 'r') as self.fin, h5py.File(self.output_file, 'w') as self.fout:
            self.fin.visititems(self._walk_file)


if __name__ == '__main__':
    # Appears to be a developer harness: converts a fixed local sample file when this module is run directly.
    sample_converter = SpectralH5Converter(
        input_file='C:/temp/old_spec_spec_grds.h5',
        output_file='C:/temp/test_converter.h5',
    )
    sample_converter.convert()
