"""A writer for spectral coverages for CMS-Wave."""

# 1. Standard Python modules
from collections import OrderedDict

# 2. Third party modules

# 3. Aquaveo modules
from xms.api.dmi import Query
from xms.api.tree import tree_util
from xms.constraint import read_grid_from_file
from xms.core.filesystem import filesystem as io_util
from xms.coverage.spectral import PLANE_TYPE_enum
from xms.data_objects.parameters import Dataset, FilterLocation, julian_to_datetime
from xms.guipy.time_format import string_to_datetime

# 4. Local modules
from xms.cmswave.data import cmswave_consts as const
from xms.cmswave.data.simulation_data import SimulationData
from xms.cmswave.dmi.xms_data import XmsData
import xms.cmswave.file_io.eng_writer


class SpectralCoverageWriter:
    """A class for writing spectral coverages."""
    def __init__(self, query=None):
        """Initializes the writer; no XMS communication happens here.

        Args:
            query (:obj:`Query`, optional): Query used to communicate with XMS. When None, one is created the first
                time the ``query`` property is accessed.
        """
        self._data = None  # SimulationData for the current simulation; loaded by export_coverages()
        self._sim_name = ''  # Simulation tree item name; used as the prefix of the output file basenames
        self._query = query
        # True when the side currently being written should be written as a zero spectrum. Set per side by
        # _write_all() and forwarded to the EngWriter by _write_single().
        self.zero_spectrum = False

    @property
    def query(self):
        """Create the Query the first time we need it - This is only for non-recorded tests."""
        if self._query is None:
            self._query = Query()
        return self._query

    def _write_all(self, do_ugrid, times, wind_dirs, wind_mags, surges, using_wind_d_set):
        """Writes all of a simulation's spectral coverages to file.

        Side 1 is always written. Side 3 is also written when the simulation's plane option is
        ``const.CBX_TEXT_FULL_PLANE_REVERSE``.

        Args:
            do_ugrid: The simulation grid data object (as returned by ``XmsData.do_ugrid``). Its ``cogrid_file`` is
                read to build the grid and its ``projection`` is passed to the global spectral parameters.
            times (:obj:`list[float]`): list of times in julian double format
            wind_dirs (:obj:`list[float]`): list of wind directions
            wind_mags (:obj:`list[float]`): list of wind speeds
            surges (:obj:`list[float]`): list of water levels
            using_wind_d_set (:obj:`bool`): True to write wind data to file, False to not
        """
        eng_writer = xms.cmswave.file_io.eng_writer
        attrs = self._data.info.attrs

        # One case per simulation time, carrying that time's wind and water level values.
        cases = [
            eng_writer.CMSWAVECase(time, wind_dir, wind_mag, w_lvl)
            for time, wind_dir, wind_mag, w_lvl in zip(times, wind_dirs, wind_mags, surges)
        ]

        # Computational frequency definition: num_freqs evenly spaced bins starting at min_freq.
        num_freqs = attrs['num_frequencies']
        delta_freq = attrs['delta_frequency']
        min_freq = attrs['min_frequency']
        max_freq = ((num_freqs - 1) * delta_freq) + min_freq

        # Build the simulation grid definition.
        dummy_d_set = Dataset()
        sim_grid = read_grid_from_file(do_ugrid.cogrid_file)
        sim_plane_type_sms = eng_writer.PlaneTypes.HALF  # CMS-Wave only supports half plane spectra
        global_params = eng_writer.SpectralParams(
            sim_plane_type_sms, sim_grid, dummy_d_set, 0, [], [], do_ugrid.projection
        )
        global_params.set_freqs_const(min_freq, max_freq, delta_freq)
        global_params.set_angles_const(-85, 85, 5)

        # Side 1: wind-only boundaries always use a zero spectrum.
        if attrs['boundary_source'] == 'Wind Only':
            self.zero_spectrum = True
        else:
            self.zero_spectrum = attrs['side1'] == 'Zero spectrum'
        self._write_single(
            attrs['spectral_uuid'], f'{self._sim_name}-side1', global_params, cases, using_wind_d_set, False
        )

        # Side 3 is only written for the full plane with reverse spectra option.
        if attrs['plane'] == const.CBX_TEXT_FULL_PLANE_REVERSE:
            self.zero_spectrum = attrs['side3'] == 'Zero spectrum'
            self._write_single(
                attrs['spectral2_uuid'], f'{self._sim_name}-side3', global_params, cases, using_wind_d_set, True
            )

    @staticmethod
    def _reverse_angle(angle):
        """Rotates a grid angle by 180 degrees for side 3 spectra.

        Angles below 180 are increased by 180; all other angles are decreased by 180.

        Args:
            angle (:obj:`float`): The grid angle in degrees.

        Returns:
            (:obj:`float`): The rotated angle.
        """
        return angle + 180.0 if angle < 180.0 else angle - 180.0

    @staticmethod
    def _plane_type(spec_grid):
        """Maps a spectral grid's plane type to the corresponding EngWriter plane type.

        Args:
            spec_grid: A spectral grid from the spectral coverage (has an ``m_planeType`` enum member).

        Returns:
            The ``PlaneTypes`` value: GLOBAL for full global plane grids, HALF for half plane grids, LOCAL otherwise.
        """
        plane_types = xms.cmswave.file_io.eng_writer.PlaneTypes
        plane_value = spec_grid.m_planeType.value
        if plane_value == PLANE_TYPE_enum.FULL_GLOBAL_PLANE.value:
            return plane_types.GLOBAL
        if plane_value == PLANE_TYPE_enum.HALF_PLANE.value:
            return plane_types.HALF
        return plane_types.LOCAL

    @staticmethod
    def _get_point_params(spec_cov, spec_pt_id, side_three):
        """Builds the chronologically sorted spectral parameters for every spectral grid at one spectral point.

        Args:
            spec_cov: The spectral coverage the point belongs to.
            spec_pt_id (:obj:`int`): ID of the spectral point.
            side_three (:obj:`bool`): True to rotate each grid angle by 180 degrees (side 3 export).

        Returns:
            (:obj:`list[SpectralParams]`): One entry per distinct time step, sorted by time. If two grids at this
            point share a time step (after rounding), the one processed last wins.
        """
        eng_writer = xms.cmswave.file_io.eng_writer
        # Keyed by the Julian double time (rounded to 7 decimals) so the params can be sorted chronologically.
        params_by_time = {}
        for spec_grid in spec_cov.GetSpectralGrids(spec_pt_id):
            # Extract the dataset once per grid so we only have one file per dataset.
            spec_dset = spec_grid.get_dataset(io_util.temp_filename())
            plane_type = SpectralCoverageWriter._plane_type(spec_grid)  # Same for every time step of this grid
            for i in range(spec_dset.num_times):
                spec_dset.ts_idx = i
                ts_time = julian_to_datetime(spec_dset.ts_time)
                spec_params = eng_writer.SpectralParams(
                    plane_type, spec_grid.m_rectGrid, spec_dset, i, [], [], spec_grid.m_rectGrid.projection, ts_time
                )
                if side_three:
                    spec_params.grid_angle = SpectralCoverageWriter._reverse_angle(spec_params.grid_angle)
                params_by_time[round(spec_dset.ts_time, 7)] = spec_params
        # Sort once, after all of this point's grids are collected. Sorting/appending inside the grid loop would
        # re-append the params of earlier grids for every additional grid.
        sorted_params = OrderedDict(sorted(params_by_time.items()))
        return list(sorted_params.values())

    def _write_single(self, cov_uuid, basename, global_params, cases, using_wind_d_set, side_three):
        """Writes a single spectral coverage to file.

        Args:
            cov_uuid (:obj:`str`): UUID of the spectral coverage
            basename (:obj:`str`): Base name of the output file (without extension)
            global_params (:obj:`SpectralParams`): The simulation spectral parameters
            cases (:obj:`list[CMSWAVECase]`): The case data for the simulation
            using_wind_d_set (:obj:`bool`): True to write wind data to file, False to not.
            side_three (:obj:`bool`): True if exporting spectral coverage for side 3, False for side 1
        """
        spec_cov = self.query.item_with_uuid(cov_uuid, generic_coverage=True)
        pt_id_map = {}
        pt_param_map = {}
        # A missing coverage yields no spectral points; the EngWriter is still created and written below.
        spec_pts = spec_cov.m_cov.get_points(FilterLocation.PT_LOC_DISJOINT) if spec_cov else []
        for spec_pt in spec_pts:
            spec_pt_id = spec_pt.id
            pt_id_map[spec_pt_id] = spec_pt
            pt_param_map[spec_pt_id] = self._get_point_params(spec_cov, spec_pt_id, side_three)

        attrs = self._data.info.attrs
        reftime = string_to_datetime(attrs['reftime'])
        time_units = attrs['reftime_units']
        writer = xms.cmswave.file_io.eng_writer.EngWriter(
            basename, pt_id_map, pt_param_map, global_params, cases, using_wind_d_set, reftime, time_units,
            attrs['angle_convention'], attrs['date_format']
        )
        writer.zero_spectrum = self.zero_spectrum
        writer.write()

    def export_coverages(self):
        """Writes the current simulation's spectral coverages to file.

        Nothing is written when the simulation's boundary source is ``const.TEXT_NONE``.
        """
        xms_data = XmsData(self.query)
        # Get the simulation tree item
        sim_uuid = self.query.current_item_uuid()
        sim_item = tree_util.find_tree_node_by_uuid(self.query.project_tree, sim_uuid)
        self._sim_name = sim_item.name
        # Get the simulation's hidden component data.
        sim_comp = self.query.item_with_uuid(sim_uuid, model_name='CMS-Wave', unique_name='Sim_Component')
        self._data = SimulationData(sim_comp.main_file)
        attrs = self._data.info.attrs

        # Check if we need to write the .eng file
        if attrs['boundary_source'] == const.TEXT_NONE:
            return

        # get case times for simulation
        times = self._data.case_times['Time']
        num_times = len(times)

        # get the simulation grid
        do_grid = xms_data.do_ugrid

        # get the wind data for the cases
        using_wind_d_set = False
        if attrs['source_terms'] == 'Propagation only':
            wind_dirs = [0.0] * num_times
            wind_mags = [0.0] * num_times
        elif attrs['wind'] == const.TEXT_CONSTANT:  # constant wind values
            wind_dirs = self._data.case_times['Wind Direction']
            wind_mags = self._data.case_times['Wind Magnitude']
        else:  # wind comes from a dataset; per-case values are zero placeholders
            using_wind_d_set = True
            wind_dirs = [0.0] * num_times
            wind_mags = [0.0] * num_times

        # get the water levels for the cases
        if attrs['surge'] == "Constant value":  # const tidal surge values
            water_lvl = self._data.case_times['Water Level']
        else:  # tidal data comes from a dataset; per-case values are zero placeholders
            water_lvl = [0.0] * num_times

        # write the spectral energy file
        self._write_all(do_grid, times, wind_dirs, wind_mags, water_lvl, using_wind_d_set)
