"""STWAVE files writer."""
# 1. Standard python modules
import datetime

# 2. Third party modules

# 3. Aquaveo modules
from xms.guipy.time_format import ISO_DATETIME_FORMAT
from xms.stwave.data.simulation_data import SimulationData
from xms.stwave.file_io.export_sim import ExportSim

# 4. Local modules

__copyright__ = "(C) Copyright Aquaveo 2023"
__license__ = "All rights reserved"


class StwaveWriter:
    """Writer class for the STWAVE files from CSTORM."""

    def __init__(self, logger, xms_data, cstorm_data):
        """Constructor.

        Args:
            logger (:obj:`logging.Logger`): the logger
            xms_data (:obj:`XmsData`): Simulation data retrieved from SMS
            cstorm_data (:obj:`dict`): fills in data for CSTORM config file
        """
        self._logger = logger
        self._xms_data = xms_data
        self._cstorm_data = cstorm_data

    def write(self):
        """Writes the STWAVE simulations for the CSTORM model.

        Exports every simulation in ``xms_data.stwave_sims``. It records in ``cstorm_data`` the values
        the CSTORM configuration needs: 'stwgrids', 'simfile', 'stwprocs', 'st_times', 'st_grid_prj',
        and, from the first simulation, 'st_reftime' and 'st_times_seconds'. It then validates the
        STWAVE timing and grid coordinate systems. Problems are reported through the logger, not raised.
        """
        self._logger.info('Writing STWAVE simulations')
        stwave_sims = self._xms_data.stwave_sims
        if not stwave_sims:
            self._logger.error('No STWAVE simulations included in CSTORM simulation. Aborting.')
            return
        template = self._cstorm_data.get('template', False)
        self._cstorm_data['stwgrids'] = len(stwave_sims)
        self._cstorm_data['simfile'] = []
        self._cstorm_data['stwprocs'] = []
        self._cstorm_data['st_times'] = []
        self._cstorm_data['st_grid_prj'] = []
        for idx, (sim, sim_comp) in enumerate(stwave_sims):
            self._logger.info(f'Writing {sim.name} - STWAVE simulation')
            export_data = self._export_sim(sim, sim_comp, template)
            self._cstorm_data['simfile'].append(f'{sim.name}.sim')
            self._cstorm_data['stwprocs'].append(export_data['stwprocs'])
            self._cstorm_data['st_times'].append(export_data['st_times'])
            self._cstorm_data['st_grid_prj'].append(export_data['grid_prj'])
            if idx == 0:
                # The reference time and case times of the first simulation are used for the timing
                # checks; _check_stwave_timing verifies that the other simulations match it.
                simd = SimulationData(sim_comp.main_file)
                self._cstorm_data['st_reftime'] = simd.info.attrs['reftime']
                self._cstorm_data['st_times_seconds'] = simd.times_in_seconds()
            if export_data['propagation'] != 0:
                self._logger.error(
                    f'STWAVE simulation "{sim.name}" must set "Source terms" to "Source terms and propagation" option. '
                    'Change this option in the STWAVE model control.'
                )
        self._check_stwave_timing()
        self._check_grids_coordinate_system()

    def _export_sim(self, sim, sim_comp, template):
        """Export a single STWAVE simulation in CSTORM mode.

        Args:
            sim: the STWAVE simulation tree item (its ``uuid`` is passed to the exporter)
            sim_comp: the simulation's component
            template (:obj:`bool`): passed through to ``ExportSim``

        Returns:
            (:obj:`dict`): the dict handed to ``ExportSim``. The caller reads 'st_times', 'grid_prj'
            and 'propagation', which are not in the initial dict, so ExportSim presumably adds them
            during ``export()`` (verify against ExportSim).
        """
        export_data = {
            'query': self._xms_data.query,
            'logger': self._logger,
            'sim_export': True,
            'sim_uuid': sim.uuid,
            'sim_comp': sim_comp,
            'stwprocs': 1,
            'cstorm_export': True,
        }
        writer = ExportSim(export_data, template=template)
        writer.set_logger(self._logger)
        writer.export()
        return export_data

    def _check_stwave_timing(self):
        """Check that the STWAVE simulation times are compatible with CSTORM.

        Logs an error when:
          * the case times differ between STWAVE simulations,
          * the STWAVE reference time precedes the ADCIRC reference time,
          * the STWAVE case times are not separated by a constant increment, or
          * the last STWAVE case time falls after ``cstorm_data['adcfinish']``.

        Stores 'stwstart', 'stwfinish' and 'stwtiminc' in ``cstorm_data``. All three are times in
        seconds divided by ``cstorm_data['DTDP']``.
        """
        stwave_sims = self._xms_data.stwave_sims
        dtdp = self._cstorm_data['DTDP']

        # 1. every STWAVE model must use the same case times as the first one
        st_times = self._cstorm_data['st_times']
        for i, times in enumerate(st_times[1:], start=1):
            if st_times[0] != times:
                sim_0 = stwave_sims[0][0].name
                sim_1 = stwave_sims[i][0].name
                msg = 'Timing for all STWAVE models must be the same; these simulations do not match: ' \
                      f'"{sim_0}", "{sim_1}".'
                self._logger.error(msg)

        # 2. the STWAVE reference time must not precede the ADCIRC reference time
        adc_time = self._cstorm_data['adc_reftime']
        st_time = self._cstorm_data['st_reftime']
        adc_ref_time = datetime.datetime.strptime(adc_time, ISO_DATETIME_FORMAT)
        st_ref_time = datetime.datetime.strptime(st_time, ISO_DATETIME_FORMAT)
        dt = st_ref_time - adc_ref_time
        if dt.total_seconds() < 0:
            msg = 'The STWAVE reference time must be equal to or come after the ADCIRC reference time. ' \
                  f'STWAVE reference time: "{st_time}". ADCIRC reference time: "{adc_time}".'
            self._logger.error(msg)
        # STWAVE start time relative to the ADCIRC start time
        st_start = dt.total_seconds() / dtdp

        # 3. consecutive case times must be separated by a constant increment
        st_time_sec = self._cstorm_data['st_times_seconds']
        dt_set = {later - earlier for earlier, later in zip(st_time_sec, st_time_sec[1:])}
        if len(dt_set) > 1:
            sim = stwave_sims[0][0].name
            msg = 'The STWAVE case times must offset by a constant time increment. ' \
                  f'Simulation "{sim}" has the following time increments in seconds: {dt_set}.'
            self._logger.error(msg)

        # 4. the last case time must be within the ADCIRC simulation time
        # len() rather than truthiness so this also works for array-like sequences
        last_case_sec = st_time_sec[-1] if len(st_time_sec) > 0 else 0.0
        st_end = st_start + last_case_sec / dtdp
        adc_finish = self._cstorm_data['adcfinish']
        if st_end > adc_finish:
            msg = 'The STWAVE ending case time must be before or equal to the ADCIRC ending time. ' \
                  f'STWAVE end: {st_end}. ADCIRC end: {adc_finish}.'
            self._logger.error(msg)
        self._cstorm_data['stwstart'] = st_start
        self._cstorm_data['stwfinish'] = st_end
        # With fewer than two case times no increment exists; report 0 instead of raising IndexError.
        # When increments differ (already reported above), pick the smallest so the value is deterministic.
        self._cstorm_data['stwtiminc'] = min(dt_set) / dtdp if dt_set else 0.0

    def _check_grids_coordinate_system(self):
        """Check that the STWAVE grids are compatible with CSTORM.

        When the display projection (used for ADCIRC) is GEOGRAPHIC, every STWAVE grid projection in
        ``cstorm_data['st_grid_prj']`` must be UTM or STATEPLANE. An error is logged for each grid
        that is not.
        """
        if self._xms_data.query.display_projection.coordinate_system != 'GEOGRAPHIC':
            # TODO: when ADCIRC is not geographic, the STWAVE grids should use the same coordinate
            #  system as ADCIRC; add that check once an example is available.
            return
        allowed_systems = ('UTM', 'STATEPLANE')
        for stsim, prj in zip(self._xms_data.stwave_sims, self._cstorm_data['st_grid_prj']):
            if prj.coordinate_system not in allowed_systems:
                sim_name = stsim[0].name
                msg = 'CSTORM requires UTM or STATEPLANE coordinate system for STWAVE models when ' \
                      'ADCIRC is using GEOGRAPHIC coordinate system. Update the grid used by simulation ' \
                      f'"{sim_name}" to be in UTM or STATEPLANE.'
                self._logger.error(msg)
