"""Simulation data class."""

__copyright__ = "(C) Copyright Aquaveo 2020"
__license__ = "All rights reserved"

# 1. Standard Python modules
import datetime
from importlib.metadata import version
import os

# 2. Third party modules
import numpy as np
import pandas as pd
from PySide2.QtCore import QDateTime
import xarray as xr

# 3. Aquaveo modules
from xms.components.bases.xarray_base import XarrayBase
from xms.guipy.time_format import ISO_DATETIME_FORMAT

# 4. Local modules
from xms.stwave.data import simulation_data
from xms.stwave.data import stwave_consts as const
from xms.stwave.gui import gui_util


def case_data_table(times, wind_dir, wind_mag, water_level):
    """Build the column arrays that make up a case-times table.

    Args:
        times (:obj:`list`): The simulation times. Passed through without a dtype conversion.
        wind_dir (:obj:`list`): The case wind directions, converted to float64.
        wind_mag (:obj:`list`): The case wind magnitudes, converted to float64.
        water_level (:obj:`list`): The case water levels, converted to float64.

    Returns:
        (:obj:`dict`): Maps each column name ('Time', 'Wind Direction', 'Wind Magnitude', 'Water Level')
        to an ``xarray.DataArray``, in that order.
    """
    table = {'Time': xr.DataArray(data=times)}
    float_columns = (
        ('Wind Direction', wind_dir),
        ('Wind Magnitude', wind_mag),
        ('Water Level', water_level),
    )
    for column_name, column_values in float_columns:
        table[column_name] = xr.DataArray(data=np.array(column_values, dtype=np.float64))
    return table


class SimulationData(XarrayBase):
    """Manages data file for the hidden simulation component."""

    def __init__(self, filename):
        """
        Initializes the data class.

        Args:
            filename (str): The name of the main file that data is stored in. Surrounding single or double
                quote characters are stripped.
        """
        super().__init__(filename.strip('"\''))
        self._convert_times = False  # For migrating times that were stored as seconds
        self._case_times = None
        self.info.attrs['FILE_TYPE'] = 'STWAVE_SIMULATION'
        self._load_parameter_default_values()
        self._load_boundary_default_values()
        self._load_output_default_values()
        self._load_iteration_default_values()
        self._load_take_values()
        self.load_all()
        self._check_for_simple_migrations()

    def _set_attr_defaults(self, defaults):
        """Add each entry of ``defaults`` to the info attributes unless the key is already present.

        Args:
            defaults (dict): Mapping of attribute name to default value. Entries are added in iteration order.
        """
        attrs = self.info.attrs
        for key, value in defaults.items():
            attrs.setdefault(key, value)

    def _check_for_simple_migrations(self):
        """Check for migrations that do not need an XML-based migration script."""
        case_times = self.case_times
        # This addresses a bug in the Model Control that coerced float dtypes to int in some cases.
        # Cast the columns back to float64 in place on the dataset.
        if pd.api.types.is_integer_dtype(case_times['Time'].dtype):
            for column in ('Time', 'Wind Direction', 'Wind Magnitude', 'Water Level'):
                if column in case_times:
                    case_times[column] = case_times[column].astype(np.float64)
        # Seconds used to be the default but is not really an option.
        if self._convert_times or self.info.attrs['reftime_units'] == 'seconds':
            self._convert_times = False
            self.info.attrs['reftime_units'] = const.TIME_UNITS_HOURS
            # Out-of-place division so an integer 'Time' column cannot trigger a numpy casting error.
            case_times['Time'] = case_times['Time'] / 3600.0
        # Normalize the reftime so it is always stored in ISO_DATETIME_FORMAT.
        reftime = self._parse_reftime(self.info.attrs['reftime'])
        self.info.attrs['reftime'] = reftime.strftime(ISO_DATETIME_FORMAT)

    @staticmethod
    def _parse_reftime(reftime_str):
        """Convert a stored reference time string to a datetime.

        The string is first parsed with ``QDateTime.fromString`` (Qt's default format). If that is invalid or
        fails, ISO_DATETIME_FORMAT is tried. If both fail, 1950-01-01 00:00:00 is returned.

        Args:
            reftime_str (str): The stored reference time.

        Returns:
            (:obj:`datetime.datetime`): The parsed reference time.
        """
        try:
            qreftime = QDateTime.fromString(reftime_str)
            # fromString reports failure through an invalid object, not an exception.
            if qreftime.isValid():
                return gui_util.qdatetime_to_datetime(qreftime)
        except Exception:  # The gui_util conversion's failure modes are not known here; fall back below.
            pass
        try:
            return datetime.datetime.strptime(reftime_str, ISO_DATETIME_FORMAT)
        except (TypeError, ValueError):
            return datetime.datetime(1950, 1, 1)

    @property
    def case_times(self):
        """
        Get the case times dataset.

        The dataset is loaded from the 'case_times' group on first access. If that group is absent, an empty
        default dataset is created.

        Returns:
            (:obj:`xarray.Dataset`): The case_times list dataset.
        """
        if self._case_times is None:
            self._case_times = self.get_dataset('case_times', False)
            if self._case_times is None:
                self._case_times = self._default_case_times()
            else:
                # The case times are now ordered "Dir" and then "Mag". They used to be ordered "Mag" then "Dir",
                # and older code addressed columns by index (no headers). Rename the stored columns positionally
                # to the current default column names. Extra stored columns keep their names.
                default_columns = list(self._default_case_times().data_vars)
                df = self._case_times.to_dataframe()
                # DataFrame.rename returns a new frame; the result must be assigned or the rename is lost.
                df = df.rename(columns=dict(zip(df.columns, default_columns)))
                self._case_times = df.to_xarray()
        return self._case_times

    @case_times.setter
    def case_times(self, value):
        """
        Sets the case time data.

        Args:
            value: The case time data.
        """
        self._case_times = value

    @staticmethod
    def _default_case_times():
        """
        Creates a default case times data set.

        Returns:
            (:obj:`xarray.Dataset`): The case times dataset, with empty 'Time', 'Wind Direction',
            'Wind Magnitude' and 'Water Level' variables.
        """
        case_time_data = simulation_data.case_data_table([], [], [], [])
        return xr.Dataset(data_vars=case_time_data)

    def _load_parameter_default_values(self):
        """Loads all of the default values that are not a dataset and not in the info attributes already."""
        if 'reftime_units' not in self.info.attrs:
            # If units don't exist but the file does, the existing times are stored as seconds and need to be converted.
            if os.path.isfile(self._filename):
                self._convert_times = True
            self.info.attrs['reftime_units'] = const.TIME_UNITS_HOURS
        self._set_attr_defaults({
            'reftime': '1950-01-01 00:00:00',
            'plane': const.PLANE_TYPE_HALF,
            'source_terms': const.SOURCE_PROP_AND_TERMS,
            'depth': const.DEP_OPT_NONTRANSIENT,
            'depth_uuid': '',
            'current_interaction': const.OPT_NONE,
            'current_uuid': '',
            'processors_i': 1,
            'processors_j': 1,
        })
        self._load_parameter_bottom_friction_default_values()
        self._load_parameter_fields_default_values()

    def _load_parameter_bottom_friction_default_values(self):
        """Loads the bottom friction default values that are not in the info attributes already."""
        self._set_attr_defaults({
            'friction': const.OPT_NONE,
            'JONSWAP': 0.03,
            'JONSWAP_uuid': '',
            'manning': 0.03,
            'manning_uuid': '',
        })

    def _load_parameter_fields_default_values(self):
        """Loads the surge, wind, and ice field default values that are not in the info attributes already."""
        self._set_attr_defaults({
            'surge': const.OPT_CONST,
            'surge_uuid': '',
            'wind': const.OPT_CONST,
            'wind_uuid': '',
            'ice': const.OPT_NONE,
            'ice_uuid': '',
            'ice_threshold': 50.0,
        })

    def _load_boundary_default_values(self):
        """Loads all of the boundary default values that are not in the info attributes already."""
        self._set_attr_defaults({
            'boundary_source': const.SPEC_OPT_COV,
            'interpolation': const.INTERP_OPT_MORPHIC,
            'num_frequencies': 30,
            'delta_frequency': 0.01,
            'min_frequency': 0.04,
            'angle_convention': const.ANG_CONV_SHORE_NORMAL,
        })
        self._load_boundary_spectral_default_values()

    def _load_boundary_spectral_default_values(self):
        """Loads all of the boundary spectral default values that are not in the info attributes already."""
        self._set_attr_defaults({
            'location_coverage': 0,
            'location_coverage_uuid': '',
            'side1': const.I_BC_SPECIFIED,
            'side2': const.I_BC_LATERAL,
            'side3': const.I_BC_ZERO,
            'side4': const.I_BC_LATERAL,
        })

    def _load_output_default_values(self):
        """Loads all of the output default values that are not in the info attributes already."""
        self._set_attr_defaults({
            'rad_stress': 0,
            'c2shore': 0,
            'output_stations': 0,
            'output_stations_uuid': '',
            'nesting': 0,
            'nesting_uuid': '',
            'monitoring': 0,
            'monitoring_uuid': '',
            'breaking_type': const.BREAK_OPT_NONE,
        })

    def _load_iteration_default_values(self):
        """Loads all of the iteration default values that are not in the info attributes already."""
        self._set_attr_defaults({
            'max_init_iters': 20,
            'init_iters_stop_value': 0.1,
            'init_iters_stop_percent': 100.0,
            'max_final_iters': 20,
            'final_iters_stop_value': 0.1,
            'final_iters_stop_percent': 99.8,
        })

    def _load_take_values(self):
        """Load the take values."""
        self._set_attr_defaults({
            'spectral_uuid': '',
            'grid_uuid': '',
        })

    def load_all(self):
        """Loads all datasets from the file, then closes it."""
        _ = self.info
        _ = self.case_times
        self.close()

    def commit(self):
        """Save current in-memory component parameters to data file."""
        # Stamp the installed package version before the info dataset is written.
        self.info.attrs['VERSION'] = version('xmsstwave')
        super().commit()  # Recreates the NetCDF file if vacuuming
        # Replace any previously written case_times group with the in-memory dataset.
        self._drop_h5_groups(['case_times'])
        if self._case_times is not None:
            self._case_times.to_netcdf(self._filename, group='case_times', mode='a')

    def close(self):
        """Closes the H5 file and does not write any data that is in memory."""
        super().close()
        if self._case_times is not None:
            self._case_times.close()

    def vacuum(self):
        """Rewrite all SimData to a new/wiped file to reclaim disk space."""
        # Ensure all datasets are loaded into memory.
        if self._info is None:
            self._info = self.get_dataset('info', False)
        if self._case_times is None:
            self._case_times = self.get_dataset('case_times', False)
        try:
            os.remove(self._filename)  # Delete the existing NetCDF file
        except OSError:
            pass  # Best effort: a missing file is fine, and commit() below still runs if removal fails.
        self.commit()  # Rewrite all datasets

    def times_in_seconds(self):
        """Returns an array of the case times as second offsets from the simulation reference time.

        Times are scaled according to the 'reftime_units' attribute: hours and days are recognized
        explicitly, and any other unit is treated as minutes.

        Returns:
            np.ndarray: See description.
        """
        reftime_units = self.info.attrs['reftime_units']
        multiplier = 60.0  # reftime_units == const.TIME_UNITS_MINUTES:
        if reftime_units == const.TIME_UNITS_HOURS:
            multiplier = 3600.0
        elif reftime_units == const.TIME_UNITS_DAYS:
            multiplier = 3600.0 * 24.0
        return self.case_times['Time'].data * multiplier
