"""A writer for STWAVE spatial datasets."""

# 1. Standard Python modules
import copy
import datetime
from io import StringIO
import logging
import math
import os
import shutil

# 2. Third party modules

# 3. Aquaveo modules
from xms.api.dmi import Query
from xms.api.tree import tree_util
from xms.constraint import read_grid_from_file
from xms.core.filesystem import filesystem as xfs
from xms.data_objects.parameters import datetime_to_julian
from xms.datasets.dataset_reader import DatasetReader
from xms.guipy.time_format import ISO_DATETIME_FORMAT

# 4. Local modules
from xms.stwave.data import stwave_consts as const
from xms.stwave.data.simulation_data import SimulationData


# -------------------------------------------------------------------------------
# Code adapted from WriteDatasetFile class in SMS
# -------------------------------------------------------------------------------


# convert cartesian convention angles to shore normal (-180 to 180 degrees)
def cart_to_shore_normal(a_from, a_to):
    """Express a cartesian angle relative to a reference angle.

    The reference angle is subtracted from the input angle and the difference is shifted by whole turns
    (360 degrees) until it lies in the closed interval [-180, 180].

    Args:
        a_from (float): Cartesian angle to convert, in degrees.
        a_to (float): Reference angle that is subtracted from a_from, in degrees.

    Returns:
        (float): The wrapped difference between a_from and a_to, in degrees.

    """
    relative = a_from - a_to
    # Bring values above the upper bound down one full turn at a time.
    while relative > 180.0:
        relative = relative - 360.0
    # Bring values below the lower bound up one full turn at a time.
    while relative < -180.0:
        relative = relative + 360.0
    return relative


# convert meteorologic convention angles to shore normal (-180 to 180 degrees)
def meteor_to_shore_normal(a_from, a_to):
    """Express a meteorological-convention angle relative to a reference angle.

    The input is first mapped to ``270 - a_from`` and the result is handed to cart_to_shore_normal, which
    subtracts the reference angle and wraps the difference.

    Args:
        a_from (float): Meteorological angle to convert, in degrees.
        a_to (float): Reference angle passed through to cart_to_shore_normal, in degrees.

    Returns:
        (float): The value returned by cart_to_shore_normal, in degrees.

    """
    return cart_to_shore_normal(270.0 - a_from, a_to)


# convert oceanographic convention angles to shore normal (-180 to 180 degrees)
def ocean_to_shore_normal(a_from, a_to):
    """Express an oceanographic-convention angle relative to a reference angle.

    The input is first mapped to ``90 - a_from`` and the result is handed to cart_to_shore_normal, which
    subtracts the reference angle and wraps the difference.

    Args:
        a_from (float): Oceanographic angle to convert, in degrees.
        a_to (float): Reference angle passed through to cart_to_shore_normal, in degrees.

    Returns:
        (float): The value returned by cart_to_shore_normal, in degrees.
    """
    return cart_to_shore_normal(90.0 - a_from, a_to)


# get time in 'yyyymmddhhmm' format
def get_time_string(a_time):
    """Format a time as a 'yyyymmddhhmm' string, rounded to the nearest minute.

    Seconds of 30 or more round the minute up, carrying into the hour when needed. If the carry would move
    the time into the next day, the hour and minute are truncated instead so the date never changes.

    Args:
        a_time (datetime.datetime): The time object to format.

    Return:
        (str): Year, zero-padded month, day, hour and minute concatenated without separators.

    """
    day_part = a_time.date()
    clock = a_time.time()
    hours, minutes = clock.hour, clock.minute

    # STWAVE only keeps minute resolution, so round to the nearest one.
    if clock.second >= 30:
        rounded_hours, rounded_minutes = hours, minutes + 1
        if rounded_minutes >= 60:
            rounded_hours, rounded_minutes = hours + 1, 0
        # Only accept the rounded value when it stays within the same day; otherwise keep the truncated one.
        if rounded_hours < 24:
            hours, minutes = rounded_hours, rounded_minutes

    return f"{day_part.year}{day_part.month:0>2d}{day_part.day:0>2d}{hours:0>2d}{minutes:0>2d}"


# finds most appropriate timestep for given time, interpolating if needed
def get_timestep_vals(dset, time, reftime):  # time and reftime are Julian doubles
    """Gets the data of a dataset at a given time, interpolating between timesteps if needed.

    A dataset with a single timestep is treated as steady state and its only timestep is returned. Otherwise
    the timesteps are scanned in index order (the scan stops at the first timestep later than the requested
    time, so timesteps are assumed to be in ascending time order):

    - a timestep matching the requested time (within tolerance) is returned as is,
    - a time before the first timestep yields the first timestep,
    - a time after the last timestep yields the last timestep,
    - anything in between is linearly interpolated from the two bracketing timesteps.

    Args:
        dset (DatasetReader): Dataset to reference
        time (double): Time to get data for, as a Julian double
        reftime (double): Reference time to use if the dataset has none, as a Julian double

    Returns:
        (:obj:`list` of double): Interpolated data for the requested time,
        or the dataset's own values at the selected timestep when no interpolation is needed.
    """
    if dset.num_times == 1:  # steady-state dataset
        return dset.values[0]

    # Offset timesteps by the dataset's own reference time when it has one, else by the supplied one.
    ref_time = reftime if dset.ref_time is None else datetime_to_julian(dset.ref_time)
    one_day = datetime.timedelta(days=1)

    before_idx = None
    before_time = None
    after_idx = None
    after_time = None
    for i in range(dset.num_times):
        ts_time = dset.timestep_offset(i) / one_day + ref_time  # timestep time in days
        if math.isclose(time, ts_time, abs_tol=0.00001):  # found an exact match
            return dset.values[i]
        if ts_time < time:  # latest timestep seen so far that precedes the requested time
            before_idx = i
            before_time = ts_time
        elif ts_time > time:
            # Either this is the timestep immediately after the requested time, or every timestep of the
            #   dataset is after the requested time.
            after_idx = i
            after_time = ts_time
            break

    # Compare against None explicitly: a bracketing time of exactly 0.0 is valid and must not be
    #   mistaken for "not found".
    if after_idx is None:  # requested time is past the end of the dataset (or nothing was comparable)
        return dset.values[-1]
    if before_idx is None:  # requested time precedes the whole dataset
        return dset.values[0]
    return interpolate_timesteps(time, dset.values[before_idx], dset.values[after_idx], before_time, after_time)


def interpolate_timesteps(a_to_time, a_data1, a_data2, a_ts1, a_ts2):
    """Linearly blends the data of two timesteps for a time between them.

    The weight of the earlier timestep is the fraction of the interval that lies after the requested time,
    clamped to [0, 1]; the later timestep receives the remainder.

    Args:
        a_to_time (double): The time to interpolate to.
        a_data1 (Sequence): Data from a_ts1. Treated as vector data when its first item is a list.
        a_data2 (Sequence): Data from a_ts2, indexed the same way as a_data1.
        a_ts1 (double): Time of the timestep just before a_to_time.
        a_ts2 (double): Time of the timestep just after a_to_time.

    Returns:
        (:obj:`list` of double): The interpolated data (a list of lists for vector data).

    """
    weight_before = (a_ts2 - a_to_time) / (a_ts2 - a_ts1)
    if weight_before < 0.0:
        weight_before = 0.0
    elif weight_before > 1.0:
        weight_before = 1.0
    weight_after = 1.0 - weight_before

    count = len(a_data1)
    if isinstance(a_data1[0], list):  # vector dataset
        # Every row is read with the component count of the first row.
        width = len(a_data1[0])
        return [
            [(a_data1[row][col] * weight_before) + (a_data2[row][col] * weight_after) for col in range(width)]
            for row in range(count)
        ]
    # scalar dataset
    return [(a_data1[idx] * weight_before) + (a_data2[idx] * weight_after) for idx in range(count)]


class DatasetWriter:
    """Builds STWAVE spatial dataset text in an in-memory buffer and writes it to ``<grid_name>.*.in`` files."""

    def __init__(self, grid_name, grid, case_times):
        """Constructor.

        Args:
            grid_name (str): Name of the geometry. Also used as the base name of every file written.
            grid (CoGrid): The dataset geometry
            case_times (Sequence[float]): The target timestep offsets in seconds
        """
        self.grid_name = grid_name
        self.grid = grid
        self.case_times = case_times
        self.num_i = 0  # number of cells in the i direction, set by write_datadims_list
        self.num_j = 0  # number of cells in the j direction, set by write_datadims_list
        self.ss = StringIO()  # text of the file currently being built

    def _write_buffer_to_file(self, extension):
        """Copy the text buffer to the file ``<grid_name><extension>``.

        Args:
            extension (str): Suffix appended to the grid name to form the file name, e.g. ".dep.in"
        """
        # The context manager closes the file even if the copy raises.
        with open(self.grid_name + extension, "w") as out:
            self.ss.seek(0)
            shutil.copyfileobj(self.ss, out)

    def _write_case_time_file(self, extension, field_names, field_units, dataset, simref_time, mag_dir=False):
        """Write a dataset file holding one record per simulation case time.

        Args:
            extension (str): Suffix appended to the grid name to form the file name
            field_names (:obj:`list` of str): Field names, one per field
            field_units (:obj:`list` of str): Field units, one per field
            dataset (xms.datasets.dataset_reader.DatasetReader): Data to be written
            simref_time (datetime.datetime): The reference time
            mag_dir (bool): True to write vector values as magnitude and direction
        """
        self.ss = StringIO()
        self.write_datadims_list(len(field_names))
        self.write_dataset_list(field_names, field_units, simref_time)
        self.write_data_vals(dataset, simref_time, mag_dir)
        self._write_buffer_to_file(extension)

    def _reorder_timestep_vals(self, vals):
        """Reorder a timestep so it is first ascending by i and descending by j.

        Uses num_i and num_j, which are set by write_datadims_list.

        Args:
            vals (Sequence): The timestep in cell index order

        Returns:
            Sequence: The timestep in i-ascending, j-descending order
        """
        # The copy provides the output container; every entry visited below is overwritten.
        ordered_vals = copy.deepcopy(vals)
        cell_idx = 0
        for j in range(self.num_j, 0, -1):  # j-descending
            for i in range(self.num_i):  # i-ascending
                ordered_vals[cell_idx] = vals[self.grid.get_cell_index_from_ij(i + 1, j)]
                cell_idx += 1
        return ordered_vals

    def write_datadims_list(self, a_num_fields, a_num_recs=-1):
        """Writes the dataset dimensions and header to the text buffer.

        Also updates num_i and num_j from the grid's edge locations.

        Args:
            a_num_fields (int): Number of Fields
            a_num_recs (int, optional): Number of Records. Defaults to -1 to use number of simulation times.
        """
        i_sizes = self.grid.locations_x
        j_sizes = self.grid.locations_y
        self.num_i = len(i_sizes) - 1
        self.num_j = len(j_sizes) - 1
        # The cell size is taken from the first pair of locations in each direction.
        d_i = i_sizes[1] - i_sizes[0] if len(i_sizes) > 1 else 0.0
        d_j = j_sizes[1] - j_sizes[0] if len(j_sizes) > 1 else 0.0
        if a_num_recs == -1:  # use the number of simulation times
            a_num_recs = len(self.case_times)
        self.ss.write(f"#STWAVE_SPATIAL_DATASET\n&DataDims\n  DataType = 0,\n  NumRecs = {a_num_recs},\n")
        self.ss.write(
            f"  NumFlds = {a_num_fields},\n  NI = {self.num_i},\n  NJ = {self.num_j},\n  DX = {d_i},\n  DY = {d_j},\n"
        )
        self.ss.write(f'  GridName = "{self.grid_name}"\n/\n')

    def write_dataset_list(self, a_field_names, a_field_units, simref_time, a_time_units="hh",
                           a_ind_time=False, a_ref_time=None):
        """Writes the dataset namelist (field names, units, and optional time info) to the text buffer.

        Args:
            a_field_names (:obj:`list` of str): Field Names.
            a_field_units (:obj:`list` of str): Field Units.
            simref_time (datetime.datetime): The reference time.
            a_time_units (str): Time units.
            a_ind_time (bool): Set to True to skip writing RecInc, RecUnits, and Reftime.
            a_ref_time (float): The dataset reference time in seconds. Defaults to the first case time.
        """
        self.ss.write("&Dataset\n")
        for idx, fld_name in enumerate(a_field_names):
            self.ss.write(f'  FldName({idx + 1}) = "{fld_name}",\n')
        for idx, fld_unit in enumerate(a_field_units):
            self.ss.write(f'  FldUnits({idx + 1}) = "{fld_unit}",\n')
        if not a_ind_time and len(self.case_times) > 1:  # only write out this part if multiple case times
            # I believe we always want the time increment to be 1 so dataset timesteps
            #   are treated as non-regularly spaced.
            if a_ref_time is None:
                a_ref_time = self.case_times[0]
            dt = simref_time + datetime.timedelta(seconds=float(a_ref_time))
            self.ss.write(f'  RecInc = 1,\n  RecUnits = "{a_time_units}",\n  Reftime = {get_time_string(dt)}\n')
        self.ss.write("/\n")

    def write_data_vals(self, dataset, simref_time, mag_dir=False):
        """Writes one record of dataset values per case time to the text buffer.

        Args:
            dataset (xms.datasets.dataset_reader.DatasetReader): Data to be written.
            simref_time (datetime.datetime): The reference time.
            mag_dir (bool): True if in magnitude, direction format.
        """
        if len(self.case_times) == 0:
            return
        # The fallback reference time (first case time) is the same for every record, so convert it once.
        first_case_dt = simref_time + datetime.timedelta(seconds=float(self.case_times[0]))
        fallback_reftime = datetime_to_julian(first_case_dt)
        is_scalar = dataset.num_components == 1
        for time in self.case_times:
            dt = simref_time + datetime.timedelta(seconds=float(time))
            self.ss.write(f"IDD {get_time_string(dt)}\n")
            vals = get_timestep_vals(dataset, datetime_to_julian(dt), fallback_reftime)
            vals = self._reorder_timestep_vals(vals)
            if is_scalar:
                for val in vals:
                    self.ss.write(f"{val:.12f}\n")
                continue
            # Vector data: one line per cell, components separated by two spaces.
            for vec in vals:
                for comp_idx, comp in enumerate(vec):
                    if mag_dir:
                        if comp_idx == 0:  # compute mag
                            comp = math.hypot(comp, vec[1])
                        elif comp_idx == 1:  # compute dir and convert to shore normal
                            comp = math.degrees(math.atan2(comp, vec[0]))
                            comp = cart_to_shore_normal(comp, self.grid.angle)
                    self.ss.write(f"{comp:.12f}  ")
                self.ss.write("\n")

    # Writes out a transient depth dataset specified by the user. Each timestep of the
    #   dataset is written with its corresponding values. If dataset times do not match
    #   simulation case times, no interpolation is performed. This is how SMS was doing
    #   it, but I am not sure why. The model control options for specifying a depth
    #   dataset as well as the option for a "coupled" depth are not documented by STWAVE.
    def write_depth_file_no_interp(self, dset, simref_time):
        """Writes every timestep of a depth dataset to ``<grid_name>.dep.in`` without interpolation.

        Values equal to the dataset's null value are written as -9999.0; all others are negated.

        Args:
            dset (xms.datasets.dataset_reader.DatasetReader): Data containing depth info.
            simref_time (datetime.datetime): The reference time.
        """
        one_sec = datetime.timedelta(seconds=1)
        self.ss = StringIO()
        self.write_datadims_list(1, dset.num_times)
        ts_time = dset.timestep_offset(0) / one_sec  # offset of the first timestep, in seconds
        self.write_dataset_list(["Depth"], [""], simref_time, "hh", False, ts_time)
        # sms does not interpolate transient depth datasets, we will just write out as is
        dset_null = dset.null_value if dset.null_value is not None else -9999.0
        for i in range(dset.num_times):
            self.ss.write(f"IDD {get_time_string(dset.timestep_offset(i) + simref_time)}\n")
            ts_data = self._reorder_timestep_vals(dset.values[i])
            for val in ts_data:
                depth = -9999.0 if val.item() == dset_null else val.item() * -1.0
                self.ss.write(f"{depth:.12f}\n")  # Depths below sea level are negative in SMS
        self._write_buffer_to_file(".dep.in")

    # writes out a non-transient z dataset for depth. First case time in the simulation
    #   will have a set of z-values
    def write_depth_file(self, zvals, simref_time):
        """Writes a single record of depths, built from z values, to ``<grid_name>.dep.in``.

        Values equal to -9999.0 are written unchanged; all others are negated.

        Args:
            zvals (:obj:`list` of float): List of point z values.
            simref_time (datetime.datetime): The reference time.

        Raises:
            OverflowError: If the first case time cannot be converted to a date/time.
        """
        self.ss = StringIO()
        self.write_datadims_list(1, 1)
        self.write_dataset_list(["Depth"], [""], simref_time, "hh", True)
        # Non-transient depth, write out z-values once (at first simulation case time)
        cur_time = 0.0 if not self.case_times else float(self.case_times[0])
        try:
            dt = simref_time + datetime.timedelta(seconds=float(cur_time))
        except OverflowError as err:
            # Keep the original failure attached so the offending value can be diagnosed.
            raise OverflowError('Error converting case times to date time value.  Check case time values.') from err
        self.ss.write(f"IDD {get_time_string(dt)}\n")
        zvals = self._reorder_timestep_vals(zvals)
        for z in zvals:
            depth = z if z == -9999.0 else z * -1.0
            self.ss.write(f"{depth:.12f}\n")  # Depths below sea level are negative in SMS
        self._write_buffer_to_file(".dep.in")

    def write_current_file(self, a_dataset, simref_time):
        """Writes the current dataset (u and v fields) to ``<grid_name>.curr.in``.

        Args:
            a_dataset (xms.datasets.dataset_reader.DatasetReader): Dataset to write.
            simref_time (datetime.datetime): The reference time.
        """
        self._write_case_time_file(".curr.in", ["Current - u", "Current - v"], ["m/s", "m/s"], a_dataset,
                                   simref_time)

    def write_surge_file(self, a_dataset, simref_time):
        """Writes the surge dataset to ``<grid_name>.surge.in``.

        Args:
            a_dataset (xms.datasets.dataset_reader.DatasetReader): Dataset for surge.
            simref_time (datetime.datetime): The reference time.
        """
        self._write_case_time_file(".surge.in", ["Surge"], [""], a_dataset, simref_time)

    def write_wind_file(self, a_dataset, simref_time):
        """Writes the wind dataset as speed and direction fields to ``<grid_name>.wind.in``.

        Args:
            a_dataset (xms.datasets.dataset_reader.DatasetReader): Dataset for wind.
            simref_time (datetime.datetime): The reference time.
        """
        self._write_case_time_file(".wind.in", ["Wind speed", "Wind direction"], ["m/s", "deg"], a_dataset,
                                   simref_time, True)

    def write_ice_file(self, a_dataset, simref_time):
        """Writes the ice dataset to ``<grid_name>.ice.in``.

        Args:
            a_dataset (xms.datasets.dataset_reader.DatasetReader): Dataset for ice.
            simref_time (datetime.datetime): The reference time.
        """
        self._write_case_time_file(".ice.in", ["Ice"], [""], a_dataset, simref_time)

    # Friction files are the only datasets SMS exports as time-independent. The
    #   values of the non-transient friction dataset are only written out once
    #   instead of once for each simulation case time (which we do with other
    #   non-transient dataset inputs)
    def write_fric_file(self, dataset, simref_time):
        """Writes the first timestep of a friction dataset to ``<grid_name>.fric.in``.

        Args:
            dataset (xms.datasets.dataset_reader.DatasetReader): Dataset for friction.
            simref_time (datetime.datetime): The reference time.
        """
        self.ss = StringIO()
        # NOTE(review): NumRecs defaults to the number of case times here although only one record is
        #   written below - confirm this is what STWAVE expects for time-independent friction.
        self.write_datadims_list(1)
        self.write_dataset_list(["Friction"], [""], simref_time, "n/a", True)
        self.ss.write(f"IDD {get_time_string(simref_time)}\n")
        vals = self._reorder_timestep_vals(dataset.values[0])
        for val in vals:
            self.ss.write(f"{val.item():.12f}\n")
        self._write_buffer_to_file(".fric.in")


def retrieve_xms_data(query=None, sim_export=False, xms_data=None):
    """Collect from XMS everything the spatial dataset export needs.

    The returned dict receives 'sim_name', 'sim_data' (the simulation component main file), 'cogrid_file', and
    one entry per spatial dataset enabled in the simulation data ('dep', 'curr', 'wind', 'surge', 'ice',
    'fric_JONSWAP' or 'fric_manning'). A dataset that cannot be retrieved is logged as an error and left out.

    Args:
        query (Optional[Query]): The XMS interprocess communicator. A new one is created if not provided.
        sim_export (bool): True when exporting the entire simulation (Context is at the simulation level),
            False when called from the partial export dialog (Context is at the simulation component level).
        xms_data (dict): Output data dict. May carry 'sim_uuid' and 'sim_comp' overrides for a simulation export.

    Returns:
        dict: The XMS data dict
    """
    xms_data = {} if xms_data is None else xms_data
    query = Query() if query is None else query

    # Locate the simulation and its hidden component
    if sim_export is not False:  # Exporting the entire simulation, Context is at the simulation level
        current_uuid = query.current_item_uuid()
        sim_uuid = xms_data.get('sim_uuid', current_uuid)
        sim_comp = xms_data.get('sim_comp', None)
        if sim_comp is None:
            sim_comp = query.item_with_uuid(sim_uuid, model_name='STWAVE', unique_name='Sim_Component')
    else:  # Partial export dialog, Context is at the simulation component level
        sim_uuid = query.parent_item_uuid()
        sim_comp = query.current_item()

    sim_node = tree_util.find_tree_node_by_uuid(query.project_tree, sim_uuid)
    xms_data['sim_name'] = sim_node.name
    main_file = sim_comp.main_file
    xms_data['sim_data'] = main_file
    attrs = SimulationData(main_file).info.attrs

    # Domain grid
    xms_data['cogrid_file'] = query.item_with_uuid(attrs['grid_uuid']).cogrid_file

    def _store_dataset(uuid_attr, dict_key, as_file_info, label):
        """Fetch one spatial dataset and record it in the output dict.

        Args:
            uuid_attr: Name of the simulation attribute holding the dataset's UUID.
            dict_key: Key under which the dataset is stored in the output dict.
            as_file_info: True to store an (h5 file name, group path) tuple instead of the dataset object.
            label: Name of the dataset used in the error message.
        """
        found = query.item_with_uuid(attrs[uuid_attr])
        if found:
            xms_data[dict_key] = (found.h5_filename, found.group_path) if as_file_info else found
            return
        logging.getLogger('xms.stwave').error(
            f'Unable to retrieve spatial dataset: {label}. Ensure a dataset has been selected in the '
            'Model Control.'
        )

    # Spatial datasets, each gated by its model control option
    if attrs['depth'] == const.DEP_OPT_TRANSIENT:
        _store_dataset('depth_uuid', 'dep', False, 'Depth')
    if attrs['current_interaction'] == 'Use dataset':
        _store_dataset('current_uuid', 'curr', True, 'Current interaction')
    if attrs['source_terms'] != const.SOURCE_PROP_ONLY:
        if attrs['wind'] != const.OPT_CONST:
            _store_dataset('wind_uuid', 'wind', True, 'Wind fields')
    if attrs['surge'] != const.OPT_CONST:
        _store_dataset('surge_uuid', 'surge', True, 'Surge fields')
    if attrs['ice'] == 'Use dataset':
        _store_dataset('ice_uuid', 'ice', False, 'Ice fields')

    friction_opt = attrs['friction']
    if friction_opt == const.FRIC_OPT_JONSWAP_DSET:
        _store_dataset('JONSWAP_uuid', 'fric_JONSWAP', False, 'JONSWAP bottom friction')
    elif friction_opt == const.FRIC_OPT_MANNING_DSET:
        _store_dataset('manning_uuid', 'fric_manning', False, 'Manning bottom friction')
    return xms_data


def export_all_datasets(xms_data):
    """Export all spatial dataset input files of a simulation.

    Writes the depth file and, when the matching option is enabled and the dataset is available in xms_data,
    the current, wind, surge, ice and friction files. Files are named after the simulation.

    Args:
        xms_data (dict): The XMS data JSON dict passed on the command line. If empty, the data is retrieved
            from SMS.

    Raises:
        RuntimeError: If the dict does not provide the simulation main file under 'sim_data'.
    """
    logger = logging.getLogger('xms.stwave')
    if not xms_data:  # If we aren't testing, instantiate a Query and retrieve the data from SMS.
        logger.info('Retrieving spatial datasets from SMS...')
        xms_data = retrieve_xms_data()

    sim_mainfile = xms_data.get('sim_data')
    if not sim_mainfile:
        raise RuntimeError('Unable to load simulation data for spatial dataset export script.')
    data = SimulationData(sim_mainfile)

    sim_name = xms_data.get('sim_name', '')
    # Clear out files from previous runs if they exist
    logger.info('Cleaning old files...')
    for ext in ['cur', 'dep', 'fric', 'fref', 'bref', 'eta', 'mud', 'wind']:
        xfs.removefile(os.path.join(os.getcwd(), f'{sim_name}.{ext}'))

    grid_name = xms_data['sim_name']

    # get case times for simulation as second offsets from the simulation reftime
    times = data.times_in_seconds().tolist()

    # get the simulation grid
    logger.info('Reading domain grid from SMS...')
    sim_grid = read_grid_from_file(xms_data['cogrid_file'])
    locations_x = sim_grid.locations_x
    locations_y = sim_grid.locations_y

    d_set_exporter = DatasetWriter(grid_name, sim_grid, times)

    # write the depth file
    depth_source = data.info.attrs['depth']
    reftime = datetime.datetime.strptime(data.info.attrs['reftime'], ISO_DATETIME_FORMAT)
    if depth_source == const.DEP_OPT_TRANSIENT:
        logger.info('Writing transient depth dataset...')
        dep_d_set = xms_data['dep']
        # Prefer the dataset's own reference time over the simulation's.
        depth_reftime = dep_d_set.ref_time if dep_d_set.ref_time is not None else reftime
        d_set_exporter.write_depth_file_no_interp(dep_d_set, depth_reftime)
    elif depth_source == const.DEP_OPT_NONTRANSIENT:  # use z as depth
        logger.info('Writing non-transient depth dataset...')
        cell_elevations = sim_grid.cell_elevations
        if sim_grid.numbering == 0:
            i_size = len(locations_x) - 1
            j_size = len(locations_y) - 1
            # Transpose to i-fastest order. The stride between consecutive i values is the number of cells
            #   in j; using i_size as the stride only works for square grids and otherwise repeats or
            #   skips cells (or indexes out of range).
            # NOTE(review): write_depth_file reorders these values again through the grid's
            #   get_cell_index_from_ij - confirm the two reorderings are meant to be combined.
            cell_z = [cell_elevations[i * j_size + j] for j in range(j_size) for i in range(i_size)]
        else:
            cell_z = list(cell_elevations)
        d_set_exporter.write_depth_file(cell_z, reftime)
    # else: depth_source == "102", coupled? Nothing is written in that case.

    # write the current file
    dset = xms_data.get('curr')
    if dset:
        logger.info('Writing current dataset...')
        reader = DatasetReader(dset[0], group_path=dset[1])
        d_set_exporter.write_current_file(reader, reftime)

    # wind data
    terms = data.info.attrs['source_terms']
    if terms != const.SOURCE_PROP_ONLY:  # not propagation only
        dset = xms_data.get('wind')
        if dset:
            logger.info('Writing wind dataset...')
            reader = DatasetReader(dset[0], group_path=dset[1])
            d_set_exporter.write_wind_file(reader, reftime)

    # write surge file
    use_const_tidal = data.info.attrs['surge']
    if use_const_tidal != const.OPT_CONST:  # variable tidal surge values
        dset = xms_data.get('surge')
        if dset:
            logger.info('Writing surge dataset...')
            reader = DatasetReader(dset[0], group_path=dset[1])
            d_set_exporter.write_surge_file(reader, reftime)

    # write ice file
    use_ice = data.info.attrs['ice']
    if use_ice == "Use dataset":
        ice_d_set = xms_data.get('ice')
        if ice_d_set:
            logger.info('Writing ice dataset...')
            d_set_exporter.write_ice_file(ice_d_set, reftime)

    # write friction file
    use_fric = data.info.attrs['friction']
    if use_fric == const.FRIC_OPT_JONSWAP_DSET:
        fric_d_set = xms_data.get('fric_JONSWAP')
        if fric_d_set:
            logger.info('Writing JONSWAP friction dataset...')
            # to interpolate or to not interpolate? same with depth
            d_set_exporter.write_fric_file(fric_d_set, reftime)
    elif use_fric == const.FRIC_OPT_MANNING_DSET:
        d_set_manning = xms_data.get('fric_manning')
        if d_set_manning:
            logger.info('Writing Manning friction dataset...')
            d_set_exporter.write_fric_file(d_set_manning, reftime)
