"""Code to store interface with the SMS spectral grid coverage."""

__copyright__ = '(C) Copyright Aquaveo 2020'
__license__ = 'All rights reserved'

# 1. Standard Python modules
import collections
import datetime

# 2. Third party modules
import h5py

# 3. Aquaveo modules
from xms.api._xmsapi.dmi import DataDumpIOBase
from xms.data_objects.parameters import Coverage, FilterLocation

# 4. Local modules


# Format used to store a time series reference time as an H5 string.
_REF_TIME_FORMAT = "%m/%d/%Y %H:%M:%S"


def _mark_generic(group):
    """Tag an H5 group with the "Grouptype" = "Generic" attribute.

    Args:
        group (h5py.Group): The H5 group to tag.
    """
    group.attrs.create("Grouptype", b"Generic", dtype="S8", shape=(1, ))


def _write_text(group, name, text):
    """Write a string as a one-element, fixed-length, latin-1 encoded H5 dataset.

    Args:
        group (h5py.Group): The H5 group to create the dataset in.
        name (str): Name of the dataset to create.
        text (str): The string to store.
    """
    encoded = text.encode("latin-1")
    # The fixed-length string type is one byte wider than the encoded text.
    dataset = group.create_dataset(name, (1, ), dtype=f'S{len(encoded) + 1}')
    dataset[:] = encoded


def _write_array(group, name, values, dtype, length=1):
    """Write a scalar or sequence as a one-dimensional H5 dataset.

    Args:
        group (h5py.Group): The H5 group to create the dataset in.
        name (str): Name of the dataset to create.
        values: Scalar (broadcast over the dataset) or sequence of ``length`` values.
        dtype (str): Element type of the dataset, e.g. 'i4', 'f8' or 'u1'.
        length (int): Number of elements in the dataset.
    """
    dataset = group.create_dataset(name, (length, ), dtype=dtype)
    dataset[:] = values


class SpatialCoverage(DataDumpIOBase):
    """Class to store geometry and attributes of an SMS spatial grid coverage."""
    class TimeSeries:
        """Store time series for spatial data."""
        def __init__(self):
            """Construct TimeSeries class."""
            self.curve_name = None
            self.group_name = None
            self.ref_time = None  # datetime.datetime or None; only written when set
            self.series = None  # list of columns, each a list of values

        def write(self, plugin_group):
            """Write time series plugin group to H5 file.

            An empty (or unset) series is written as zero rows and zero columns.

            Args:
                plugin_group(h5py.Group): The time series H5 plugin group.
            """
            _write_text(plugin_group, "PluginType", "Time Series")

            time_series_group = plugin_group.create_group("TS_1")
            _mark_generic(time_series_group)

            _write_text(time_series_group, "crv_name", self.curve_name)
            _write_text(time_series_group, "grp_name", self.group_name)

            columns = self.series or []
            # The row count is taken from the first column; guard against having no columns at all.
            row_cnt = len(columns[0]) if columns else 0
            _write_array(time_series_group, "row_cnt", row_cnt, 'i4')
            _write_array(time_series_group, "col_count", len(columns), 'i4')

            if self.ref_time is not None:
                _write_text(time_series_group, "ref_time", self.ref_time.strftime(_REF_TIME_FORMAT))

            # Columns are stored as 1-based datasets: col_1, col_2, ...
            for col_number, data in enumerate(columns, start=1):
                time_series_group.create_dataset(f"col_{col_number}", (row_cnt, ), dtype='f8', data=data)

    class CompassData:
        """Store compass data for spatial data."""
        def __init__(self):
            """Initialize CompassData instance."""
            self.name = None
            self.name_show = None
            self.vectors = []  # list of SpatialCoverage.CompassVectorData
            self.legend_show = None
            self.legend_location = None
            self.legend_min_max_only = None
            self.legend_precision = None
            self.ring_percentages = []
            self.radius = None
            self.direction_only = None
            self.connection_lines = None
            self.background_filled = None
            self.background_color = None  # four unsigned byte components
            self.specify_min_max = None
            self.ring_min = None
            self.ring_max = None
            self.arrow_style = None
            self.position = None  # two integer components

        def write(self, compass_data_group):
            """Write compass data plugin group to H5 file.

            Args:
                compass_data_group(h5py.Group): The compass data H5 group.
            """
            _write_text(compass_data_group, "PluginType", "Compass Plot")
            _write_text(compass_data_group, "Name", self.name)
            _write_array(compass_data_group, "NameShow", self.name_show, 'i4')

            # Vectors are stored as 1-based subgroups: dataset_vector_1, dataset_vector_2, ...
            for vector_number, vector in enumerate(self.vectors, start=1):
                vector_group = compass_data_group.create_group(f"dataset_vector_{vector_number}")
                _mark_generic(vector_group)
                vector.write(vector_group)

            _write_array(compass_data_group, "LegendShow", self.legend_show, 'i4')
            _write_array(compass_data_group, "LegendLocation", self.legend_location, 'i4')
            _write_array(compass_data_group, "LegendMinMaxOnly", self.legend_min_max_only, 'i4')
            _write_array(compass_data_group, "LegendPrecision", self.legend_precision, 'i4')
            _write_array(
                compass_data_group, "RingPercentages", self.ring_percentages, 'i4', length=len(self.ring_percentages)
            )
            _write_array(compass_data_group, "Radius", self.radius, 'i4')
            _write_array(compass_data_group, "DirectionOnly", self.direction_only, 'i4')
            _write_array(compass_data_group, "ConnectionLines", self.connection_lines, 'i4')
            _write_array(compass_data_group, "BackgroundFilled", self.background_filled, 'i4')
            _write_array(compass_data_group, "BackgroundColor", self.background_color, 'u1', length=4)
            _write_array(compass_data_group, "SpecifyMinMax", self.specify_min_max, 'i4')
            _write_array(compass_data_group, "RingMin", self.ring_min, 'f8')
            _write_array(compass_data_group, "RingMax", self.ring_max, 'f8')
            _write_array(compass_data_group, "ArrowStyle", self.arrow_style, 'i4')
            _write_array(compass_data_group, "Position", self.position, 'i4', length=2)

    class CompassVectorData:
        """Store data for a single vector of a compass plot."""
        def __init__(self):
            """Initialize CompassVectorData instance."""
            self.curve_group = None
            self.curve_name = None
            self.enabled = None
            self.color = None  # four unsigned byte components

        def write(self, vector_group):
            """Write compass vector data group to H5 file.

            Args:
                vector_group(h5py.Group): The compass vector data H5 group.
            """
            _write_text(vector_group, "Curve Group", self.curve_group)
            _write_text(vector_group, "Curve Name", self.curve_name)
            _write_array(vector_group, "Enabled", self.enabled, 'i4')
            _write_array(vector_group, "Color", self.color, 'u1', length=4)

    def __init__(self, a_filename=""):
        """Construct the spatial coverage.

        Args:
            a_filename (str): Optional path to a dump file. When given, the file is read lazily the first time
                the ``coverage`` property is accessed.
        """
        # NOTE(review): the base initializer is called explicitly rather than through super().__init__();
        # presumably required by the DataDumpIOBase binding - confirm before changing.
        DataDumpIOBase.__init__(self)
        super().SetSelf(self)
        self._coverage = None
        # Maps node ID -> plugin ID -> TimeSeries/CompassData. Looking up a missing key inserts a placeholder.
        self.spatial_data = collections.defaultdict(lambda: collections.defaultdict(lambda: None))
        self.has_read = False
        self.file_name = a_filename

    @property
    def coverage(self):
        """Get the spatial coverage geometry.

        Triggers a read of the dump file if one was given at construction and it has not been read yet.

        Returns:
            xms.data_objects.parameters.Coverage: The spatial coverage geometry
        """
        if not self.has_read and self.file_name:
            self.ReadDump(self.file_name)
        return self._coverage

    @coverage.setter
    def coverage(self, val):
        """Set the spatial coverage geometry.

        Args:
            val (xms.data_objects.parameters.Coverage): The spatial coverage geometry
        """
        self._coverage = val

    def WriteDump(self, filename):  # noqa: N802
        """Write a spatial coverage H5 file that XMS can read.

        Args:
            filename (str): Path to the output file
        """
        # Get the points before dumping the coverage geometry to disk.
        points = self._coverage.get_points(FilterLocation.LOC_ALL)
        self._coverage.write_h5(filename)

        # Append the spatial data to the file just written; the context manager closes it even on error.
        with h5py.File(filename, "a") as h5_file:
            cov_group = h5_file['Map Data/Coverage1']
            spatial_group = cov_group.create_group("SpatialData")
            _mark_generic(spatial_group)

            for point in points:  # loop through the coverage points
                # Membership test (not indexing) so the defaultdict does not grow while writing.
                if point.id not in self.spatial_data:
                    continue
                node_group = spatial_group.create_group(f"node_{point.id}")
                _mark_generic(node_group)
                for plugin_id, plugin_data in self.spatial_data[point.id].items():
                    if plugin_data is None:
                        # Placeholder auto-inserted by a lookup of a missing plugin ID; nothing to write.
                        continue
                    plugin_group = node_group.create_group(f"plugin_{plugin_id}")
                    _mark_generic(plugin_group)
                    plugin_data.write(plugin_group)

    def ReadDump(self, filename):  # noqa: N802
        """Populate from an H5 file written by XMS.

        Only the first coverage under "Map Data" is read. If the file contains no coverage, nothing is loaded.
        A coverage without a "SpatialData" group is loaded with geometry only.

        Args:
            filename (str): Path to the dump file to read
        """
        self.has_read = True
        with h5py.File(filename, "r") as h5_file:
            map_data = h5_file["Map Data"]
            cov_name = next(iter(map_data), None)
            if cov_name is None:
                return

            cov_group = map_data[cov_name]
            if "SpatialData" in cov_group:
                self._read_spatial_data(cov_group["SpatialData"])

        # Read the geometry only after the H5 handle above has been closed.
        self._coverage = Coverage(filename, "/Map Data/" + cov_name)  # dump file deleted as soon as coverage is loaded
        self._coverage.get_points(FilterLocation.LOC_NONE)  # force geometry to load from H5

    def _read_spatial_data(self, spatial_group):
        """Read every node/plugin item stored under a coverage's spatial data group.

        Plugins whose type is neither "Compass Plot" nor "Time Series" are ignored.

        Args:
            spatial_group (h5py.Group): The "SpatialData" group of the coverage.
        """
        for node_name, node_group in spatial_group.items():  # node_1, node_2, etc.
            node = int(node_name.split("_")[1])
            for plugin_name, plugin_group in node_group.items():  # plugin_1, plugin_2, etc.
                plugin_id = int(plugin_name.split("_")[1])
                plugin_type = plugin_group["PluginType"][0]
                if plugin_type == b"Compass Plot":
                    self._read_compass_plot(plugin_group, node, plugin_id)
                elif plugin_type == b"Time Series":
                    self._read_time_series(plugin_group, node, plugin_id)

    def _read_compass_plot(self, compass_plot_group, node, plugin_id):
        """Reads the compass plot group for a spatial data plugin item.

        Args:
            compass_plot_group (h5py.Group): Compass plot group.
            node (int): The coverage node ID.
            plugin_id(int): The plugin ID.
        """
        compass_data = SpatialCoverage.CompassData()
        compass_data.name = _h5_to_str(compass_plot_group, "Name")
        compass_data.name_show = compass_plot_group["NameShow"][0]

        # Vector subgroups are numbered consecutively from 1; stop at the first missing number.
        compass_data.vectors = []
        count = 1
        vector_name = f"dataset_vector_{count}"
        while vector_name in compass_plot_group:
            vector_group = compass_plot_group[vector_name]
            vector_data = SpatialCoverage.CompassVectorData()
            vector_data.curve_group = _h5_to_str(vector_group, "Curve Group")
            vector_data.curve_name = _h5_to_str(vector_group, "Curve Name")
            vector_data.enabled = vector_group["Enabled"][0]
            vector_data.color = vector_group["Color"][:].tolist()
            compass_data.vectors.append(vector_data)

            count += 1
            vector_name = f"dataset_vector_{count}"

        compass_data.legend_show = compass_plot_group["LegendShow"][0]
        compass_data.legend_location = compass_plot_group["LegendLocation"][0]
        compass_data.legend_min_max_only = compass_plot_group["LegendMinMaxOnly"][0]
        compass_data.legend_precision = compass_plot_group["LegendPrecision"][0]
        compass_data.ring_percentages = compass_plot_group["RingPercentages"][:].tolist()
        compass_data.radius = compass_plot_group["Radius"][0]
        compass_data.direction_only = compass_plot_group["DirectionOnly"][0]
        compass_data.connection_lines = compass_plot_group["ConnectionLines"][0]
        compass_data.background_filled = compass_plot_group["BackgroundFilled"][0]
        compass_data.background_color = compass_plot_group["BackgroundColor"][:].tolist()
        compass_data.specify_min_max = compass_plot_group["SpecifyMinMax"][0]
        compass_data.ring_min = compass_plot_group["RingMin"][0]
        compass_data.ring_max = compass_plot_group["RingMax"][0]
        compass_data.arrow_style = compass_plot_group["ArrowStyle"][0]
        compass_data.position = compass_plot_group["Position"][:].tolist()

        # Register only once fully populated so a failed read leaves no half-filled entry behind.
        self.spatial_data[node][plugin_id] = compass_data

    def _read_time_series(self, plugin_group, node, plugin_id):
        """Reads the time series group for a spatial data plugin item.

        Args:
            plugin_group (h5py.Group): Plugin group containing the time series.
            node (int): The coverage node ID.
            plugin_id(int): The plugin ID.
        """
        time_series = SpatialCoverage.TimeSeries()
        time_series_group = plugin_group["TS_1"]

        time_series.curve_name = _h5_to_str(time_series_group, "crv_name")
        time_series.group_name = _h5_to_str(time_series_group, "grp_name")

        # The reference time is optional in the file.
        if "ref_time" in time_series_group:
            ref_time_string = _h5_to_str(time_series_group, "ref_time")
            time_series.ref_time = datetime.datetime.strptime(ref_time_string, _REF_TIME_FORMAT)

        # Columns are stored as 1-based datasets: col_1, col_2, ...
        col_count = int(time_series_group["col_count"][0])
        time_series.series = [
            time_series_group[f"col_{col_number}"][:].tolist() for col_number in range(1, col_count + 1)
        ]

        # Register only once fully populated so a failed read leaves no half-filled entry behind.
        self.spatial_data[node][plugin_id] = time_series

    def Copy(self):  # noqa: N802
        """Return a reference to this object (no copy is made)."""
        return self

    def GetDumpType(self):  # noqa: N802
        """Get the XMS coverage dump type."""
        return "xms.coverage.spatial"


def ReadDumpWithObject(filename):  # noqa: N802
    """Create a spatial coverage bound to a dump file.

    The returned object is only constructed with the file path here; no read is performed by this function.

    Args:
        filename (str): Filepath to the dumped coverage to read

    Returns:
        SpatialCoverage: The spatial coverage constructed for ``filename``
    """
    return SpatialCoverage(filename)


def _h5_to_str(group, key):
    """Decode an H5 string value into a string.

    Args:
        group (h5py.Group): The H5 group.
        key (str): The H5 dataset name.

    Returns:
        (str): Decoded UTF-8 string.
    """
    return group[key][0].decode("utf-8")
