"""Simulation data class."""

__copyright__ = "(C) Copyright Aquaveo 2020"
__license__ = "All rights reserved"

# 1. Standard Python modules
import datetime
from importlib.metadata import version
import os

# 2. Third party modules
import numpy as np
import xarray as xr

# 3. Aquaveo modules
from xms.components.bases.xarray_base import XarrayBase
from xms.guipy.time_format import datetime_to_string

# 4. Local modules
from xms.cmswave.data import cmswave_consts as const


def case_data_table(times, wind_dir, wind_mag, water_level):
    """Build the column data for a case times table.

    Args:
        times (:obj:`list`): The simulation times.
        wind_dir (:obj:`list`): The wind direction of each case.
        wind_mag (:obj:`list`): The wind magnitude of each case.
        water_level (:obj:`list`): The water level of each case.

    Returns:
        (:obj:`dict`): Maps the column names 'Time', 'Wind Direction', 'Wind Magnitude', and 'Water Level' (in that
        order) to ``xarray.DataArray`` objects. The three value columns are coerced to float64; the times are
        wrapped as given.
    """
    def _float_column(values):
        # Coerce a value column to double precision before wrapping it.
        return xr.DataArray(data=np.array(values, dtype=np.float64))

    table = {'Time': xr.DataArray(data=times)}
    table['Wind Direction'] = _float_column(wind_dir)
    table['Wind Magnitude'] = _float_column(wind_mag)
    table['Water Level'] = _float_column(water_level)
    return table


def convert_time_into_seconds(time_units, time_val):
    """Convert a time value expressed in the given units into seconds.

    Args:
        time_units (:obj:`str`): The simulation time units. Expected to be one of ``const.TIME_UNITS_MINUTES``,
            ``const.TIME_UNITS_HOURS``, or ``const.TIME_UNITS_DAYS``. Any other value is treated as already being
            in seconds, so the value is returned unscaled.
        time_val (:obj:`float`): The time value to convert.

    Returns:
        (:obj:`float`): ``time_val`` converted to seconds.
    """
    seconds_per_unit = {
        const.TIME_UNITS_MINUTES: 60.0,
        const.TIME_UNITS_HOURS: 3600.0,
        const.TIME_UNITS_DAYS: 3600.0 * 24.0,
    }
    # Unrecognized units fall back to a factor of 1.0 (no scaling), matching the historical behavior.
    return time_val * seconds_per_unit.get(time_units, 1.0)


class SimulationData(XarrayBase):
    """Manages data file for the hidden simulation component."""

    def __init__(self, filename):
        """Initializes the data class.

        Args:
            filename (:obj:`str`): The name of the main file that data is stored in. Any surrounding single or
                double quote characters are stripped first.
        """
        super().__init__(filename.strip('"\''))
        self._case_times = None
        self.info.attrs['FILE_TYPE'] = 'CMSWAVE_SIMULATION'
        # Fill in every attribute that is not already present in the info attributes.
        self._load_parameter_default_values()
        self._load_boundary_default_values()
        self._load_output_default_values()
        self._load_options_default_values()
        self._load_take_values()
        self.load_all()

    @property
    def case_times(self):
        """Get the case times dataset, loading it on first access.

        The dataset is requested from the file's 'case_times' group. If ``get_dataset`` returns None, an empty
        default dataset is created instead.

        Returns:
            (:obj:`xarray.Dataset`): The case times dataset.
        """
        if self._case_times is None:
            self._case_times = self.get_dataset('case_times', False)
            if self._case_times is None:
                self._case_times = self._default_case_times()
        return self._case_times

    @case_times.setter
    def case_times(self, value):
        """Sets the case time data.

        Args:
            value (:obj:`xarray.Dataset`): The case time data.
        """
        self._case_times = value

    @staticmethod
    def _default_case_times():
        """Create an empty case times dataset.

        Returns:
            (:obj:`xarray.Dataset`): A dataset with the case table columns and no rows.
        """
        case_time_data = case_data_table([], [], [], [])
        return xr.Dataset(data_vars=case_time_data)

    def _set_default_attrs(self, defaults):
        """Add each default to the info attributes unless the attribute is already present.

        Existing values are never overwritten, and new keys are inserted in the order given.

        Args:
            defaults (:obj:`dict`): Attribute names mapped to their default values.
        """
        attrs = self.info.attrs
        for key, value in defaults.items():
            attrs.setdefault(key, value)

    def _load_parameter_default_values(self):
        """Add the model parameter defaults that are missing from the info attributes."""
        self._set_default_attrs({'reftime_units': 'hours'})
        if 'reftime' not in self.info.attrs:
            # Only format the default reference time when it is actually needed.
            reftime = datetime.datetime(year=1950, month=1, day=1)
            self.info.attrs['reftime'] = datetime_to_string(reftime)
        self._set_default_attrs({
            'plane': 'Half plane',
            'source_terms': 'Source terms and propagation',
            'current_interaction': const.TEXT_NONE,
            'current_uuid': '',
            'matrix_solver': 'Gauss-Seidel',
            'num_threads': 1,
        })
        self._load_parameter_bottom_friction_default_values()
        self._load_parameter_fields_default_values()

    def _load_parameter_bottom_friction_default_values(self):
        """Add the bottom friction defaults that are missing from the info attributes."""
        self._set_default_attrs({
            'friction': const.TEXT_NONE,
            'darcy': 0.005,
            'darcy_uuid': '',
            'manning': 0.025,
            'manning_uuid': '',
        })

    def _load_parameter_fields_default_values(self):
        """Add the surge and wind field defaults that are missing from the info attributes."""
        self._set_default_attrs({
            'surge': const.TEXT_NONE,
            'surge_uuid': '',
            'wind': const.TEXT_NONE,
            'limit_wave_inflation': 0,
            'wind_uuid': '',
        })

    def _load_boundary_default_values(self):
        """Add the boundary defaults that are missing from the info attributes."""
        self._set_default_attrs({
            'boundary_source': 'Spectra (+ Wind)',
            'interpolation': 'IDW',
            'num_frequencies': 30,
            'delta_frequency': 0.01,
            'min_frequency': 0.04,
            'angle_convention': 'Shore normal',
        })
        self._load_boundary_spectral_default_values()

    def _load_boundary_spectral_default_values(self):
        """Add the boundary spectral defaults that are missing from the info attributes."""
        self._set_default_attrs({
            'location_coverage': 0,
            'location_coverage_uuid': '',
            'side1': 'Specified spectrum',
            'side3': 'Zero spectrum',
        })

    def _load_output_default_values(self):
        """Add the output defaults that are missing from the info attributes."""
        self._set_default_attrs({
            'limit_observation_output': 0,
            'rad_stress': 0,
            'breaking_type': const.TEXT_NONE,
            'nesting': 0,
            'nesting_uuid': '',
            'observation': 0,
            'observation_uuid': '',
        })

    def _load_options_default_values(self):
        """Add the model option defaults that are missing from the info attributes."""
        self._set_default_attrs({
            'wet_dry': 0,
            'sea_swell': 0,
            'infragravity': 0,
            'diffraction_option': 1,
            'diffraction_intensity': 4.0,
            'nonlinear_wave': 0,
            'roller': 0,  # 'OFF' by default
            'runup': 0,
            'fastmode': 0,
            'forward_reflection': const.TEXT_NONE,
            'forward_reflection_uuid': '',
            'forward_reflection_const': 0.5,
            'backward_reflection': const.TEXT_NONE,
            'backward_reflection_uuid': '',
            'backward_reflection_const': 0.3,
            'muddy_bed': const.TEXT_NONE,
            'muddy_bed_uuid': '',
            'wave_breaking_formula': 'Extended Goda',
            'gamma_bj78': 0.6,
        })
        # Unlike the other defaults, an empty/falsy date format is also replaced, not just a missing one.
        if not self.info.attrs.get('date_format'):
            self.info.attrs['date_format'] = '12 digits'

    def _load_take_values(self):
        """Ensure the spectral UUID attributes exist and remove the 'structures_uuid' and 'grid_uuid' attributes."""
        self._set_default_attrs({
            'spectral_uuid': '',
            'spectral2_uuid': '',
        })
        # NOTE(review): presumably these attributes are no longer used by this version; confirm.
        self.info.attrs.pop('structures_uuid', None)
        self.info.attrs.pop('grid_uuid', None)

    def load_all(self):
        """Loads all datasets from the file, then closes it."""
        _ = self.info
        _ = self.case_times
        self.close()

    def commit(self):
        """Save current in-memory component parameters to data file.

        Stamps the info attributes with the installed 'xmscmswave' package version, commits the base data, then
        replaces the 'case_times' group with the in-memory case times (if any are loaded).
        """
        self.info.attrs['VERSION'] = version('xmscmswave')
        super().commit()  # Recreates the NetCDF file if vacuuming
        self._drop_h5_groups(['case_times'])
        if self._case_times is not None:
            self._case_times.to_netcdf(self._filename, group='case_times', mode='a')

    def close(self):
        """Closes the H5 file and does not write any data that is in memory."""
        super().close()
        # The dataset object is kept (not reset to None) so a later commit() can still write it.
        if self._case_times is not None:
            self._case_times.close()

    def vacuum(self):
        """Rewrite all SimData to a new/wiped file to reclaim disk space."""
        # Ensure all datasets are loaded into memory.
        if self._info is None:
            self._info = self.get_dataset('info', False)
        if self._case_times is None:
            self._case_times = self.get_dataset('case_times', False)
        try:
            os.remove(self._filename)  # Delete the existing NetCDF file
        except OSError:
            # Best effort (e.g. the file may not exist yet); proceed to rewrite the data regardless.
            pass
        self.commit()  # Rewrite all datasets
