"""Class to write a WaveWatch3 spec.list spectral file."""

__copyright__ = "(C) Copyright Aquaveo 2021"
__license__ = "All rights reserved"

# 1. Standard Python modules
import datetime
from io import StringIO
import logging
import shutil

# 2. Third party modules
import netCDF4
import numpy as np

# 3. Aquaveo modules
from xms.core.filesystem import filesystem as io_util
from xms.coverage.spectral import PLANE_TYPE_enum
from xms.data_objects.parameters import FilterLocation, julian_to_datetime

# 4. Local modules


class WW3SpecListWriter:
    """Class to write a WaveWatch3 spec.list spectral file.

    For every disjoint point of every spectral coverage a netCDF4 spectral file is written to the
    current working directory, and its filename is recorded (one per line) in ``spec.list``.
    """
    def __init__(self, spectral_coverages):
        """Constructor.

        Args:
            spectral_coverages (:obj:`list[xms.coverage.spectral.SpectralCoverage]`): The spectral coverages.
        """
        self._ss = StringIO()  # Buffers the spec.list contents until _flush writes them to disk
        self._logger = logging.getLogger('xms.wavewatch3')
        self._spectral_coverages = spectral_coverages

    def _write_spec_list_file(self):
        """Writes the spec.list file and the netCDF4 files it lists.

        Nothing is written if any spectral grid is not a full global plane.
        """
        file_w_path = "spec.list"
        if not self._check_plane_types():
            return
        self._write_spectral_data()
        self._flush(file_w_path)

    def _check_plane_types(self):
        """Checks for spectral grid plane types.  Currently, only global is supported.

        Returns:
            (:obj:`bool`): True if every spectral grid is a full global plane (or there are no grids),
                False after logging an error otherwise.
        """
        for spectral_cov in self._spectral_coverages:
            for spec_pt in self._get_spectral_points(spectral_cov):
                for spec_grid in spectral_cov.GetSpectralGrids(spec_pt.id):
                    if spec_grid.m_planeType.value != PLANE_TYPE_enum.FULL_GLOBAL_PLANE.value:
                        self._logger.error("Spectral grid found that is not a global grid.  Cannot write spec.list.")
                        return False
        return True

    def _write_spectral_data(self):
        """Writes one WW3 netCDF4 file per spectral point and records each filename in the spec.list buffer."""
        for spectral_cov in self._spectral_coverages:
            spectral_point_list = self._get_spectral_points(spectral_coverage=spectral_cov)
            for index, spectral_point in enumerate(spectral_point_list):
                # Name the netCDF4 file after the spectral coverage; append the point index when there are several
                if len(spectral_point_list) > 1:
                    # TODO: We need a spectral coverage file with more than one point to test this
                    filename = f'{spectral_cov.m_cov.name}_{index}.nc'  # pragma no cover
                else:
                    filename = f'{spectral_cov.m_cov.name}.nc'
                self._ss.write(f"{filename}\n")
                self._write_point_netcdf(filename, spectral_cov, spectral_point)

    def _write_point_netcdf(self, filename, spectral_cov, spectral_point):
        """Writes the spectral data of a single spectral point to a WW3 netCDF4 file.

        Args:
            filename (:obj:`str`): Path of the netCDF4 file to create (opened in 'w' mode, so overwritten).
            spectral_cov (:obj:`xms.coverage.spectral.SpectralCoverage`): The spectral coverage.
            spectral_point: The spectral point whose data is written.
        """
        # Gather everything first so a failure here does not leave a half-written, open netCDF file behind.
        # Times are days since 1990-01-01, matching the 'time' variable units below.
        spectral_times = self._get_spectral_coverage_times(spectral_cov, spectral_point)
        spectral_directions = self.get_spectral_directions(spectral_cov, spectral_point)
        spectral_frequencies = self.get_spectral_frequencies(spectral_cov, spectral_point)
        latitudes, longitudes, station_names = self._get_spectral_point_info(spectral_cov, spectral_point)
        spectral_efth = self._get_efth_values_from_grid(
            spectral_coverage=spectral_cov,
            time_values=spectral_times,
            spec_pt=spectral_point,
            frequencies=spectral_frequencies,
            directions=spectral_directions
        )

        # The context manager closes (and saves) the dataset even if writing a variable raises
        with netCDF4.Dataset(filename=filename, mode='w', format='NETCDF4') as root_grp:
            root_grp.description = 'Created by xmswavewatch3'

            # Make some dimensions with the appropriate sizes
            root_grp.createDimension('time', len(spectral_times))
            root_grp.createDimension('station', 1)
            root_grp.createDimension('frequency', len(spectral_frequencies))
            root_grp.createDimension('direction', len(spectral_directions))
            root_grp.createDimension('string16', 16)

            # Make some variables to store the spectral data
            time_var = root_grp.createVariable('time', 'f8', ('time', ))
            time_var.units = 'days since 1990-01-01 00:00:00'
            time_var.calendar = 'gregorian'
            frequency_var = root_grp.createVariable('frequency', 'f8', ('frequency', ))
            direction_var = root_grp.createVariable('direction', 'f8', ('direction', ))
            efth_var = root_grp.createVariable(
                'efth', 'f8', (
                    'time',
                    'station',
                    'frequency',
                    'direction',
                ), fill_value=-999.99
            )

            station_name_var = root_grp.createVariable('station_name', 'S1', ('string16', ))
            longitude_var = root_grp.createVariable('longitude', 'f8')
            latitude_var = root_grp.createVariable('latitude', 'f8')
            x_var = root_grp.createVariable('x', 'f8')
            y_var = root_grp.createVariable('y', 'f8')

            # Store the spectral data in the variables
            time_var[:] = np.array(spectral_times)
            frequency_var[:] = np.array(spectral_frequencies)
            direction_var[:] = np.array(spectral_directions)
            efth_var[:] = np.array(spectral_efth)
            if station_names:
                station_name_var[:] = netCDF4.stringtochar(np.array(station_names, dtype='S16'))
            latitude_var[:] = np.array(latitudes)
            longitude_var[:] = np.array(longitudes)
            # x is the point's x coordinate (longitude) and y its y coordinate (latitude)
            x_var[:] = np.array(longitudes)
            y_var[:] = np.array(latitudes)

    def _get_spectral_coverage_times(self, spectral_coverage, spec_pt, ref_time=None):
        """Gets the spectral coverage times as days since the ref time.

        Time steps from all spectral grids of the point are merged, de-duplicated and sorted.

        Args:
            spectral_coverage (:obj:`xms.coverage.spectral.SpectralCoverage`): The spectral coverage.
            spec_pt: The spectral point.
            ref_time (:obj:`datetime.datetime`): The reference time to compute on. Defaults to 1 Jan 1990.

        Returns:
            (:obj:`list[float]`):  Sorted list of unique spectral coverage times as days since the reference time.
        """
        ref_time = ref_time if ref_time is not None else datetime.datetime(1990, 1, 1)
        seconds_per_day = 3600.0 * 24.0
        spectra_times = set()
        for spec_grid in spectral_coverage.GetSpectralGrids(spec_pt.id):
            # Read each grid's dataset once, then step through its time steps
            spec_dset = spec_grid.get_dataset(io_util.temp_filename())
            for i in range(spec_dset.num_times):
                spec_dset.ts_idx = i
                # The time values read are julian dates; convert them to days since the ref time
                ts_time = julian_to_datetime(spec_dset.ts_time)
                spectra_times.add((ts_time - ref_time).total_seconds() / seconds_per_day)
        return sorted(spectra_times)

    def _get_spectral_points(self, spectral_coverage):
        """Gets the spectral points in the spectral coverage.

        Args:
            spectral_coverage (:obj:`xms.coverage.spectral.SpectralCoverage`): The spectral coverage.

        Returns:
            (:obj:`list[Point]`):  List of disjoint points in the spectral coverage.
        """
        return spectral_coverage.m_cov.get_points(FilterLocation.PT_LOC_DISJOINT)

    def get_spectral_directions(self, spectral_coverage, spec_pt):
        """Gets the spectral directions in the spectral coverage.

        Only the first spectral grid of the point is used.

        Args:
            spectral_coverage (:obj:`xms.coverage.spectral.SpectralCoverage`): The spectral coverage.
            spec_pt: The spectral point.

        Returns:
            (:obj:`list[float]`):  List of directions (from the grid) in the spectral coverage, one per entry of the
                grid's j_sizes, or None if the point has no spectral grids.
        """
        for spec_grid in spectral_coverage.GetSpectralGrids(spec_pt.id):
            rect_grid = spec_grid.m_rectGrid
            angle = rect_grid.angle + 90.0  # Convert from cartesian to oceanographic
            j_sizes = rect_grid.j_sizes
            directions = []
            for j in range(len(j_sizes)):
                # Start at the grid angle and step down by the j sizes, wrapping negative values once by 360
                val = angle if j == 0 else directions[j - 1] - j_sizes[j]
                directions.append(val + 360.0 if val < 0 else val)
            return directions
        return None

    def get_spectral_frequencies(self, spectral_coverage, spec_pt):
        """Gets the spectral frequencies in the spectral coverage.

        Only the first spectral grid of the point is used.

        Args:
            spectral_coverage (:obj:`xms.coverage.spectral.SpectralCoverage`): The spectral coverage.
            spec_pt: The spectral point.

        Returns:
            (:obj:`list[float]`):  List of frequencies (from the grid) in the spectral coverage: the grid origin x
                followed by the cumulative sums of the i sizes, or None if the point has no spectral grids.
        """
        for spec_grid in spectral_coverage.GetSpectralGrids(spec_pt.id):
            rect_grid = spec_grid.m_rectGrid
            frequencies = [rect_grid.origin.x]
            for size in rect_grid.i_sizes:
                frequencies.append(frequencies[-1] + size)
            return frequencies
        return None

    def _get_spectral_point_info(self, spectral_coverage, spec_pt):
        """Gets the spectral point information in the spectral coverage.

        Args:
            spectral_coverage (:obj:`xms.coverage.spectral.SpectralCoverage`): The spectral coverage (unused).
            spec_pt: The spectral point; its y is used as latitude and its x as longitude.

        Returns:
            (:obj:`list[float]`):  List of latitude values of the points.
            (:obj:`list[float]`):  List of longitude values of the points.
            (:obj:`list[str]`):  List of station names of the points (currently always 'Point1').
        """
        return [spec_pt.y], [spec_pt.x], ['Point1']

    @staticmethod
    def _reorder_efth(values, num_frequencies, directions):
        """Converts one time step of spectral grid values to a (frequency, direction) array in the netCDF order.

        Args:
            values: Flat values for one time step; must hold (len(directions) + 1) * num_frequencies values,
                laid out direction-major.
            num_frequencies (:obj:`int`): The number of frequencies.
            directions (:obj:`list[float]`): The directions.

        Returns:
            (:obj:`numpy.ndarray`): 2D array of shape (num_frequencies, len(directions)).
        """
        grid = np.array(values).reshape(len(directions) + 1, num_frequencies)
        # Transpose to (frequency, direction), then remove the extra direction used in the spectral data
        # but not in the NC
        by_frequency = np.delete(np.transpose(grid), -1, 1)
        if len(directions) > 1:
            # Roll each frequency row to start at 90 (a single direction needs no roll and has no spacing)
            roll_count = -int(-90 / abs(directions[0] - directions[1]) + 1)
            by_frequency = np.roll(by_frequency, roll_count, axis=1)
        # Flip the direction order of every frequency row
        return np.flip(by_frequency, axis=1)

    def _get_efth_values_from_grid(self, spectral_coverage, time_values, spec_pt, frequencies, directions):
        """Gets the spectral energy (efth) values of a spectral point for every time step.

        Args:
            spectral_coverage (:obj:`xms.coverage.spectral.SpectralCoverage`): The spectral coverage.
            time_values (:obj:`list[float]`): The time steps; only their count is used, as dataset time indices.
            spec_pt: The spectral point.
            frequencies (:obj:`list[float]`): The frequencies.
            directions (:obj:`list[float]`): The directions

        Returns:
            (:obj:`list`):
                list of (time, station, frequency, direction) values, or [[[[]]]] if there are no time steps.
        """
        if not time_values:
            return [[[[]]]]  # Give it the right default size
        # Fetch each grid's dataset once; only the time step index changes inside the loop below
        spec_dsets = [
            spec_grid.get_dataset(io_util.temp_filename())
            for spec_grid in spectral_coverage.GetSpectralGrids(spec_pt.id)
        ]
        num_frequencies = len(frequencies)
        efth = []
        for time_idx in range(len(time_values)):
            stations = []
            for spec_dset in spec_dsets:
                # Read the spectral data for the time step
                spec_dset.ts_idx = time_idx
                stations.append(self._reorder_efth(spec_dset.data, num_frequencies, directions).tolist())
            efth.append(stations)
        return efth

    def _flush(self, file_w_path):
        """Writes the StringIO previously processed to a file.

        Args:
            file_w_path (:obj:`str`):  String of the filename to write to.
        """
        self._ss.seek(0)
        # 'with' guarantees the handle is closed even if the copy fails
        with open(file_w_path, 'w') as f:
            shutil.copyfileobj(self._ss, f, 100000)

    def write(self):
        """Top-level entry point for the WaveWatch3 spec.list file writer."""
        self._write_spec_list_file()
