"""Classes to store and interface with an SMS spectral grid coverage."""

__copyright__ = '(C) Copyright Aquaveo 2020'
__license__ = 'All rights reserved'

# 1. Standard Python modules
import uuid

# 2. Third party modules
import h5py
import numpy as np

# 3. Aquaveo modules
from xms.api._xmsapi.dmi import DataDumpIOBase
from xms.data_objects.parameters import Coverage, Dataset, FilterLocation, RectilinearGrid
from xms.datasets.dataset_writer import DatasetReader, DatasetWriter

# 4. Local modules
from xms.coverage.windCoverage import AutoNumber


class PLANE_TYPE_enum(AutoNumber):  # noqa: N801
    """Enum for the supported spectral plane types.

    Member values are integers. SpectralCoverage writes a member's ``value`` to the ``ReadAsSpecGrid``
    property of each spectral grid in a dump file, and rebuilds the member from that integer when reading.
    The integers themselves are assigned by ``AutoNumber`` -- presumably sequentially in declaration order
    (TODO confirm in xms.coverage.windCoverage). If so, reordering the members below would change the stored
    integers and cause previously written files to be misread.
    """
    # NOTE(review): ``__order__`` appears to be the enum34-style explicit member order; keep it in sync with
    # the member declarations below.
    __order__ = 'FULL_LOCAL_PLANE FULL_GLOBAL_PLANE HALF_PLANE BOTH_PLANE'
    FULL_LOCAL_PLANE = ()  # Default plane type for a new SpectralGrid
    FULL_GLOBAL_PLANE = ()
    HALF_PLANE = ()
    BOTH_PLANE = ()


class SpectralGrid:
    """A rectilinear spectral grid together with its spectral values at each timestep."""
    def __init__(self, a_reftime=0.0):
        """Create an empty spectral grid.

        Args:
            a_reftime (float): Reference time of the spectral point. This may differ from the reference time
                of the datasets.
        """
        self.m_rectGrid = RectilinearGrid()
        self.m_refTime = a_reftime
        # Units of the time offsets passed to add_timestep: 'Seconds', 'Minutes', 'Hours', or 'Days'.
        self.m_timeUnits = 'Seconds'
        self.m_planeType = PLANE_TYPE_enum.FULL_LOCAL_PLANE
        # Maps str(timestep offset) to the list of spectral values for that timestep, in insertion order.
        self._m_spectra = {}

    def add_timestep(self, time, values):
        """Record the spectral values for one timestep.

        Adding the same offset twice replaces the earlier values.

        Args:
            time (float): Offset of the timestep from the reference time
            values (list of float): Spectral dataset location values at this timestep
        """
        timestep_key = str(time)
        self._m_spectra[timestep_key] = values

    def get_dataset(self, temp_file):
        """Export the spectral values as a node-based dataset.

        Args:
            temp_file (xms.data_objects._data_objects.parameters.FileLocation): Where the dataset is written

        Returns:
            xms.data_objects.parameters.Dataset: View of the exported dataset
        """
        exported = Dataset()
        exported.time_units, exported.ref_time = self.m_timeUnits, self.m_refTime
        exported.data_basis = 'NODE'
        exported.set_data(temp_file, self._m_spectra)
        return exported


class SpectralCoverage(DataDumpIOBase):
    """Class to store geometry and attributes of an SMS spectral grid coverage.

    If a dump filename is given at construction, the file is read lazily the first time the geometry or the
    point attributes are accessed.
    """
    def __init__(self, a_filename=""):
        """Construct the spectral coverage.

        Args:
            a_filename (str): Path to an XMS dump file to lazily read from. Empty for an empty coverage.
        """
        super().__init__()
        super().SetSelf(self)
        self.__m_cov = None
        self.__m_points = {}  # key is point id, value is list of SpectralGrid. See GetSpectralGrids.
        self.m_hasRead = False
        self.m_fileName = a_filename

    def _read_if_needed(self):
        """Read the dump file if one was provided and it has not been read yet."""
        if not self.m_hasRead and self.m_fileName:
            self.ReadDump(self.m_fileName)

    @property
    def m_cov(self):
        """Get the spectral coverage geometry.

        Returns:
            xms.data_objects.parameters.Coverage: The spectral coverage geometry
        """
        self._read_if_needed()
        return self.__m_cov

    @m_cov.setter
    def m_cov(self, val):
        """Set the spectral coverage geometry.

        Args:
            val (xms.data_objects.parameters.Coverage): The spectral coverage geometry
        """
        self.__m_cov = val

    @property
    def m_points(self):
        """Get the spectral coverage points.

        Returns:
            dict: Key is point id and value is list of SpectralGrids defined at that point
        """
        self._read_if_needed()
        return self.__m_points

    @m_points.setter
    def m_points(self, val):
        """Set the spectral coverage point attributes.

        Args:
            val (dict): Key is point id and value is list of SpectralGrids defined
                at that point
        """
        self.__m_points = val

    def GetSpectralGrids(self, point_id):  # noqa: N802
        """Get spectral grid definitions for a point.

        Args:
            point_id (int): Feature object id of the spectral point

        Returns:
            list: The spectral grids defined at this point, or an empty list if there are none
        """
        # The m_points property performs the lazy read.
        return self.m_points.get(point_id, [])

    def AddSpectralGrid(self, point_id, spec_grid):  # noqa: N802
        """Add a spectral grid definition.

        Args:
            point_id (int): Feature object id of the spectral point
            spec_grid (SpectralGrid): The spectral grid to add
        """
        self.m_points.setdefault(point_id, []).append(spec_grid)

    @staticmethod
    def _read_spectral_grids(f, filename, spec_path):
        """Load every spectral grid, and readers for its datasets, found under ``spec_path``.

        Args:
            f (h5py.File): Open handle to the dump file
            filename (str): Path to the dump file (the grid and dataset readers open it themselves)
            spec_path (str): H5 path of the coverage's SpectralData group, with a trailing '/'

        Returns:
            tuple(dict, dict): grid uuid -> (RectilinearGrid, PLANE_TYPE_enum), and
            grid uuid -> list of DatasetReader for that grid's datasets
        """
        rect_grids = {}
        dsets = {}
        for grid_name in f[spec_path].keys():
            if grid_name in ('DsetUuid', 'GridUuid', 'Nodes', 'NodeRefTime'):
                continue  # Point mapping arrays, not grid groups
            grid_path = f'{spec_path}{grid_name}/'
            grid_uuid = f[grid_path + 'PROPERTIES/GUID'][0].decode('UTF-8')
            plane_type = PLANE_TYPE_enum(int(f[grid_path + 'PROPERTIES/ReadAsSpecGrid'][0]))
            grid = RectilinearGrid(filename, grid_path)
            grid.uuid = grid_uuid
            _ = grid.origin  # Load the grid into memory. The dump file is deleted as soon as the coverage loads.
            rect_grids[grid_uuid] = (grid, plane_type)
            for dset_name in f[grid_path + 'Datasets'].keys():
                dset_path = f'{grid_path}Datasets/{dset_name}/'
                if dset_path + 'PROPERTIES/GUID' not in f:
                    continue  # Not a dataset
                dsets.setdefault(grid_uuid, []).append(DatasetReader(h5_filename=filename, group_path=dset_path))
        return rect_grids, dsets

    def ReadDump(self, filename):  # noqa: N802,C901
        """Populate from an H5 file written by XMS.

        Args:
            filename (str): Path to the dump file to read
        """
        # Set first so the lazy-loading properties used below (via AddSpectralGrid) don't re-enter ReadDump.
        self.m_hasRead = True
        with h5py.File(filename, 'r') as f:
            cov_names = list(f['Map Data'].keys())
            if not cov_names:
                return
            cov_name = cov_names[0]  # Only one coverage per dumpfile
            spec_path = '/Map Data/' + cov_name + '/SpectralData/'

            # Ensure coverage contains nodes
            if spec_path + 'Nodes' not in f:
                return

            # Parallel arrays mapping grids to points. There should be one grid per point and one dataset per
            # grid; the old format allowed multiple grids per point and wrote each timestep to its own dataset.
            node_ids = f[spec_path + 'Nodes'][:]
            ref_times = f[spec_path + 'NodeRefTime'][:]  # can differ from the reftime on the datasets
            grid_uuids = f[spec_path + 'GridUuid'][:].astype(str)
            rect_grids, dsets = self._read_spectral_grids(f, filename, spec_path)

            # Map the grids and datasets to their point locations.
            for idx, node_id in enumerate(node_ids):
                spec_grid = SpectralGrid()
                grid_uuid = grid_uuids[idx]
                spec_grid.m_rectGrid, spec_grid.m_planeType = rect_grids[grid_uuid]
                if len(ref_times) > 0:
                    # NodeRefTime is written once per point, so it may be shorter than the per-grid arrays.
                    # Reuse the last value for any remaining grids. If it is empty, keep the default reftime.
                    spec_grid.m_refTime = ref_times[min(idx, len(ref_times) - 1)]
                readers = dsets.get(grid_uuid)
                if readers:  # A grid without datasets gets no timesteps instead of raising KeyError.
                    reader = readers[0]
                    values = reader.values
                    for ts_idx in range(reader.num_times):
                        offset = reader.timestep_offset(ts_idx)  # timedelta from the dataset reftime
                        spec_grid.add_timestep(offset.total_seconds(), values=values[ts_idx].tolist())
                self.AddSpectralGrid(node_id, spec_grid)

        # Read the geometry after our H5 handle is closed. The dump file is deleted as soon as the coverage
        # loads, so force the points to load from the H5 file now.
        self.m_cov = Coverage(filename, '/Map Data/' + cov_name)
        self.m_cov.get_points(FilterLocation.LOC_NONE)

    @staticmethod
    def _write_grid_dataset(h5file, filename, spec_path, grid_name, grid_uuid, grid):
        """Write all of a grid's timesteps as a single XMDF dataset, plus the properties SMS expects.

        Clears ``grid._m_spectra`` once the values have been written.

        Args:
            h5file (h5py.File): Open, writable handle to the output file
            filename (str): Path to the output file
            spec_path (str): H5 path of the SpectralData group, with a trailing '/'
            grid_name (str): Name of the grid group under ``spec_path``
            grid_uuid (str): Uuid assigned to the grid
            grid (SpectralGrid): The spectral grid whose timesteps are written
        """
        dset_name = 'Spec_Dset_(1)'  # Only one dataset per grid in the current format
        writer = DatasetWriter(
            h5_filename=filename,
            name=dset_name,
            dset_uuid=str(uuid.uuid4()),
            geom_uuid=grid_uuid,
            ref_time=float(grid.m_refTime),
            time_units=grid.m_timeUnits,
            overwrite=False,
            geom_path=f'{spec_path}{grid_name}',
            h5_handle=h5file
        )
        times = np.array([float(ts) for ts in grid._m_spectra])
        values = np.array(list(grid._m_spectra.values()))
        writer.write_xmdf_dataset(times, values)
        grid._m_spectra = {}  # Drop the in-memory values now that they are on disk

        # SMS needs an 'Object Type' property on the dataset group.
        dset_group = h5file.require_group(f'{spec_path}{grid_name}/Datasets/{dset_name}')
        prop_group = dset_group.require_group('PROPERTIES')
        if 'Object Type' not in prop_group:
            type_prop = prop_group.create_dataset('Object Type', (1, ), dtype='S14')
            type_prop[0] = 'GRID2DDATAOBJ'

    def WriteDump(self, filename):  # noqa: N802
        """Write a spectral coverage H5 file that XMS can read.

        NOTE(review): assumes Coverage.write_h5 names the coverage group 'Coverage1' -- confirm.

        Args:
            filename (str): Path to the output file
        """
        # This is fast (at least to write the dump), but it would require changes to the SMS file format and reader.
        self.m_cov.write_h5(filename)
        spec_path = 'Map Data/Coverage1/SpectralData/'
        h5file = h5py.File(filename, 'a')
        try:
            h5file['Map Data/Coverage1'].create_group('SpectralData')
            # Parallel arrays. The old format allowed multiple grids per point and wrote each timestep as a
            # separate dataset.
            node_ids = []
            node_reftimes = []
            grid_uuids = []
            grid_count = 1
            for pt_id, grids in sorted(self.m_points.items()):  # loop through the coverage points
                if not grids:
                    continue
                node_reftimes.append(grids[0].m_refTime)  # one reftime per point
                for grid in grids:  # should only be one grid per point now
                    # data_objects won't return a grid's `name` until it has been flushed to disk, so track the
                    # name here instead of reading it back.
                    grid_name = f'Spectral Grid ({grid_count})'
                    grid_count += 1
                    grid_uuid = str(uuid.uuid4())
                    grid.m_rectGrid.name = grid_name
                    grid.m_rectGrid.uuid = grid_uuid

                    # write_h5 closes the H5 file, so release our handle first and reopen it afterwards.
                    h5file.close()
                    grid.m_rectGrid.write_h5(filename, spec_path, False)
                    h5file = h5py.File(filename, 'a')

                    prop_group = h5file.require_group(spec_path + grid_name + '/PROPERTIES')
                    read_as_prop = prop_group.create_dataset('ReadAsSpecGrid', (1, ), dtype='i')
                    read_as_prop[0] = grid.m_planeType.value
                    node_ids.append(pt_id)
                    grid_uuids.append(grid_uuid)

                    self._write_grid_dataset(h5file, filename, spec_path, grid_name, grid_uuid, grid)

            # Write the mapping between grids/datasets and the coverage points.
            spec_group = h5file[spec_path]
            spec_group.create_dataset('GridUuid', (len(grid_uuids), ), dtype='S37', data=grid_uuids)
            spec_group.create_dataset('Nodes', (len(node_ids), ), dtype='i', data=node_ids)
            spec_group.create_dataset('NodeRefTime', (len(node_reftimes), ), dtype='float64', data=node_reftimes)
        finally:
            # Closing an already-closed h5py file is a no-op, so this is safe after a failed reopen.
            h5file.close()

    def Copy(self):  # noqa: N802
        """Return a reference to this object."""
        return self

    def GetDumpType(self):  # noqa: N802
        """Get the XMS coverage dump type."""
        return 'xms.coverage.spectral'


def ReadDumpWithObject(filename):  # noqa: N802
    """Create a spectral coverage backed by a dump file.

    The file's contents are read lazily, the first time the coverage's geometry or point data is accessed.

    Args:
        filename (str): Filepath to the dumped coverage to read

    Returns:
        SpectralCoverage: The loaded spectral coverage
    """
    return SpectralCoverage(filename)
