"""SimData class."""

# 1. Standard Python modules
import datetime
import os

# 2. Third party modules
import numpy as np
import pandas as pd
import xarray as xr

# 3. Aquaveo modules
from xms.components.bases.xarray_base import XarrayBase

# 4. Local modules
# File name of the simulation component's main NetCDF data file (presumably what SimData receives as data_file).
SIM_COMP_MAIN_FILE = 'sim_comp.nc'

# Choices for the lateral boundary 'left_bc_type' / 'rt_bc_type' attributes.
LATERAL_BC_TYPES = [
    'Pinned',
    'Gated',
    'Moving',
]

# Choices for the lateral boundary 'left_disp_type' / 'rt_disp_type' attributes.
LATERAL_DISP_TYPES = [
    'Simulation Period',
    'Day',
    'Timestep',
]

# Choices for the Monte Carlo 'prob_function' attribute.
MONTE_CARLO_PROB_TYPES = [
    'Rayleigh Distribution',
    'Rayleigh+Weibull',
    'User Specified',
]

# Choices for the seaward 'wave_components' attribute.
WAVE_COMPONENTS = [
    'Primary (1)',
]


class SimData(XarrayBase):
    """Manages data file for the hidden simulation component.

    Each parameter group is stored in its own NetCDF group of the main file and is held in memory as an
    ``xarray.Dataset``. A group is either created with default values (when the data file does not exist yet) or
    read from the file via ``get_dataset`` the first time its property is accessed. ``commit`` writes every
    in-memory group back to the file.
    """

    # NetCDF groups written by commit(), in write order. The flag marks groups whose existing copy in the file is
    # dropped before being rewritten - only the two date tables, presumably because their row count can change
    # between commits. The in-memory copy of each group lives in the attribute named '_' + group.
    _COMMIT_GROUPS = (
        ('model', False),
        ('beach', False),
        ('seaward', False),
        ('lateral', False),
        ('adaptive', False),
        ('cross_shore', False),
        ('monte_carlo', False),
        ('water_level', False),
        ('wl_table', True),
        ('print_table', True),
    )

    def __init__(self, data_file):
        """Initializes the data class.

        Args:
            data_file (:obj:`str`): The netcdf file (with path) associated with this instance data. Probably the owning
                component's main file.
        """
        self._filename = data_file
        # In-memory copies of the NetCDF groups. None means the group has not been created or read yet.
        self._info = None
        self._model = None
        self._beach = None
        self._seaward = None
        self._lateral = None
        self._adaptive = None
        self._cross_shore = None
        self._monte_carlo = None
        self._water_level = None  # No public accessor in this class; only committed/vacuumed if set.
        self._wl_table = None
        self._print_table = None
        self._proj_units = None  # Not referenced elsewhere in this class.

        # Create the default file before calling super because we have our own attributes to write.
        self._get_default_datasets(data_file)
        super().__init__(data_file)
        self._check_for_simple_migrations()

    def _get_default_datasets(self, data_file):
        """Create default datasets if needed.

        The in-memory defaults are built only when ``data_file`` is not an existing regular file, and are written to
        disk immediately when nothing exists at that path yet.

        Args:
            data_file (:obj:`str`): Name of the data file. If it doesn't exist, it will be created.

        Returns:
            (:obj:`bool`): Returns True if datasets were created.
        """
        if os.path.isfile(data_file):
            return False  # Existing file: groups are read lazily through the properties instead.

        info = {
            'FILE_TYPE': 'GENCADE_SIMULATION',
            # 'VERSION': pkg_resources.get_distribution('xmsgencade').version,
            'proj_dir': '',  # Location of the saved project, if it exists.
            'units': ''      # Imperial or Metric, determined in Model Control
        }
        self._info = xr.Dataset(attrs=info)

        # All model variables will be exported to .gen file.
        model = {
            'title': 'SMS Simulation',
            'full_print': 1,   # 0 - unchecked, 1 - checked
            'start_date': '',  # Reference date used for simulation
            'end_date': '',
            'time_step': 1.0,
            'recording': 168.0,
            'enable_water_level': 0,   # I only put this here, because having its own attrs section wasn't working.
        }
        self._model = xr.Dataset(attrs=model)

        # All model variables will be exported to .gen file except those indicated below.
        beach = {
            'eff_grain_size': 0.2,
            'avg_berm_ht': 1.0,
            'closure_depth': 10.0,
            'regional_slc': 0.0,            # exported to .cntmcxsh file
            'regional_subsidence': 0.0,     # exported to .cntmcxsh file
            'beach_slope': 0.01,            # exported to .cntmcxsh file
            'K1': 0.4,
            'K2': 0.25,
            'KTIDE': 1.0,
        }
        self._beach = xr.Dataset(attrs=beach)

        # All model variables will be exported to .gen file.
        seaward = {
            'ht_amp_factor': 1.0,
            'ang_amp_factor': 1.0,
            'ang_offset': 0.0,
            'wave_components': 'Primary (1)',
            'num_cells_smoothing': 11,
        }
        self._seaward = xr.Dataset(attrs=seaward)

        # All model variables will be exported to .gen file.
        lateral = {
            'left_bc_type': 'Pinned',
            'left_length_to_tip': 0.0,
            'left_shoreline_disp_distance': 0.0,
            'left_disp_type': 'Simulation Period',
            'rt_bc_type': 'Pinned',
            'rt_length_to_tip': 0.0,
            'rt_shoreline_disp_distance': 0.0,
            'rt_disp_type': 'Simulation Period',
        }
        self._lateral = xr.Dataset(attrs=lateral)

        # All model variables will be exported to .gen file.
        adaptive = {
            'use_adaptive_ts': 0,
            'threshold_min': 0.0001,
            'threshold_max': 1.0,
            'days_stable': 28,
            'stability_min': 0.5,
            'stability_max': 0.5,
        }
        self._adaptive = xr.Dataset(attrs=adaptive)

        # All model variables will be exported to .cntmcxsh file.
        cross_shore = {
            'use_cross_shore': 0,
            'cross_shore_scale': 1.0,
            'onshore_trans_rate': 0.0,
            'use_var_berm_ht': 0,
            # NOTE(review): the other use_* flags default to 0; 0.01 matches beach_slope - confirm this is intended.
            'use_var_slope': 0.01,
        }
        self._cross_shore = xr.Dataset(attrs=cross_shore)

        # All model variables will be exported to .cntmcxsh file.
        monte_carlo = {
            'use_monte_carlo': 0,
            'prob_function': 'Rayleigh Distribution',
            'AL_mean': 0.0,
            'AL_sigma': 5.0,
            'AL_max': 45.0,
            'num_mc_sims': 2,
            'AAO_val': 0.6,
            'BBO_val': 2.4,
            'FKO_val': 1.1,
            'HCUT_val': 0.0,
            'H_min': 0.1,
            'H_max': 7.0,
            'H_mean': 1.0,
            'wave_interval': 1.0,
            'use_beach_fill': 0,
            'std_percent': 0.1,
        }
        self._monte_carlo = xr.Dataset(attrs=monte_carlo)

        print_table = {
            'print_date': xr.DataArray(data=np.array([datetime.date(2000, 1, 1)], dtype='datetime64[ns]')),
        }
        self._print_table = xr.Dataset(data_vars=print_table)

        wl_table = {
            'wl_date': xr.DataArray(data=np.array([datetime.date(2000, 1, 1)], dtype='datetime64[ns]')),
            'water_level': xr.DataArray(data=np.array([0.0], dtype=float)),
        }
        self._wl_table = xr.Dataset(data_vars=wl_table)

        # A path that exists but is not a regular file (e.g. a directory) keeps the defaults in memory only.
        if not os.path.exists(data_file):
            self.commit()
        return True

    def _check_for_simple_migrations(self):
        """Check for easy fixes we can make to the file to avoid doing it during DMI project migration.

        Currently a placeholder: no migrations are defined, so nothing is committed.
        """
        commit = False
        if commit:
            self.commit()

    def _load_group(self, group):
        """Return the in-memory dataset for a NetCDF group, reading it from the data file on first use.

        Args:
            group (:obj:`str`): NetCDF group name. Its in-memory copy lives in the attribute ``'_' + group``.

        Returns:
            (:obj:`xarray.Dataset`): The in-memory dataset for the group.
        """
        attr_name = '_' + group
        dset = getattr(self, attr_name)
        if dset is None:
            dset = self.get_dataset(group, False)
            setattr(self, attr_name, dset)
        return dset

    def commit(self):
        """Save current in-memory component parameters to data file.

        Groups that were never created or loaded (still None) are not written.
        """
        super().commit()  # Recreates the NetCDF file if vacuuming
        for group, drop_first in self._COMMIT_GROUPS:
            dset = getattr(self, '_' + group)
            if dset is None:
                continue  # Not in memory, so there is nothing to write.
            dset.close()
            if drop_first:
                self._drop_h5_groups([group])
            dset.to_netcdf(self._filename, group=group, mode='a')

    def vacuum(self):
        """Rewrite all SimData to a new/wiped file to reclaim disk space."""
        # Pull every group into memory first; anything still only on disk would be lost when the file is deleted.
        self._load_group('info')
        for group, _ in self._COMMIT_GROUPS:
            self._load_group(group)

        try:
            os.remove(self._filename)  # Delete the existing NetCDF file
        except OSError:
            # Best effort (e.g. the file is already gone); commit() below still writes every in-memory group.
            pass
        self.commit()  # Rewrite all datasets

    @property
    def model(self):
        """Load the model dataset from disk.

        Returns:
            (:obj:`xarray.Dataset`): Dataset interface to the model datasets in the main file
        """
        return self._load_group('model')

    @model.setter
    def model(self, dset):
        """Setter for the model attribute. ``None`` is ignored."""
        # Compare against None rather than using truthiness: bool() of an xarray.Dataset reflects its data_vars,
        # and this group (like the other parameter groups) is attrs-only, so a valid dataset would be rejected.
        if dset is not None:
            self._model = dset

    @property
    def beach(self):
        """Load the beach dataset from disk.

        Returns:
            (:obj:`xarray.Dataset`): Dataset interface to the beach datasets in the main file
        """
        return self._load_group('beach')

    @beach.setter
    def beach(self, dset):
        """Setter for the beach attribute. ``None`` is ignored."""
        if dset is not None:
            self._beach = dset

    @property
    def seaward(self):
        """Load the seaward dataset from disk.

        Returns:
            (:obj:`xarray.Dataset`): Dataset interface to the seaward datasets in the main file
        """
        return self._load_group('seaward')

    @seaward.setter
    def seaward(self, dset):
        """Setter for the seaward attribute. ``None`` is ignored."""
        if dset is not None:
            self._seaward = dset

    @property
    def lateral(self):
        """Load the lateral dataset from disk.

        Returns:
            (:obj:`xarray.Dataset`): Dataset interface to the lateral datasets in the main file
        """
        return self._load_group('lateral')

    @lateral.setter
    def lateral(self, dset):
        """Setter for the lateral attribute. ``None`` is ignored."""
        if dset is not None:
            self._lateral = dset

    @property
    def adaptive(self):
        """Load the adaptive dataset from disk.

        Returns:
            (:obj:`xarray.Dataset`): Dataset interface to the adaptive datasets in the main file
        """
        return self._load_group('adaptive')

    @adaptive.setter
    def adaptive(self, dset):
        """Setter for the adaptive attribute. ``None`` is ignored."""
        if dset is not None:
            self._adaptive = dset

    @property
    def cross_shore(self):
        """Load the cross_shore dataset from disk.

        Returns:
            (:obj:`xarray.Dataset`): Dataset interface to the cross_shore datasets in the main file
        """
        return self._load_group('cross_shore')

    @cross_shore.setter
    def cross_shore(self, dset):
        """Setter for the cross_shore attribute. ``None`` is ignored."""
        if dset is not None:
            self._cross_shore = dset

    @property
    def monte_carlo(self):
        """Load the monte_carlo dataset from disk.

        Returns:
            (:obj:`xarray.Dataset`): Dataset interface to the monte_carlo datasets in the main file
        """
        return self._load_group('monte_carlo')

    @monte_carlo.setter
    def monte_carlo(self, dset):
        """Setter for the monte_carlo attribute. ``None`` is ignored."""
        if dset is not None:
            self._monte_carlo = dset

    @property
    def print_table(self):
        """Load the print_table dataset from disk.

        Returns:
            (:obj:`xarray.Dataset`): Dataset interface to the print_table datasets in the main file.
        """
        return self._load_group('print_table')

    @print_table.setter
    def print_table(self, dset):
        """Setter for the print_table attribute. ``None`` is ignored."""
        if dset is not None:
            self._print_table = dset

    @staticmethod
    def default_print_table():
        """Creates an xarray.Dataset for a print table.

        Returns:
            (:obj:`xarray.Dataset`): One row with ``print_date`` = 2000-01-01, stored as datetime64 like the default
            table written to a new file, along the DataFrame's default ``index`` dimension.
        """
        # datetime.datetime (not datetime.date) so pandas infers a datetime64 column; plain date objects stay as
        # object dtype, which xarray cannot serialize to NetCDF.
        default_data = {
            'print_date': [datetime.datetime(2000, 1, 1)]
        }
        return pd.DataFrame(default_data).to_xarray()

    @property
    def wl_table(self):
        """Load the wl_table dataset from disk.

        Returns:
            (:obj:`xarray.Dataset`): Dataset interface to the wl_table datasets in the main file
        """
        return self._load_group('wl_table')

    @wl_table.setter
    def wl_table(self, dset):
        """Setter for the wl_table attribute. ``None`` is ignored."""
        if dset is not None:
            self._wl_table = dset

    @staticmethod
    def default_wl_table():
        """Creates an xarray.Dataset for a wl_table.

        Returns:
            (:obj:`xarray.Dataset`): One row with ``wl_date`` = 2000-01-01 and ``water_level`` = 0.0, along the
            DataFrame's default ``index`` dimension.
        """
        default_data = {
            'wl_date': [datetime.datetime(2000, 1, 1)],
            'water_level': [0.0]
        }
        return pd.DataFrame(default_data).to_xarray()
