"""A writer for CMS-Wave spatial datasets."""

# 1. Standard Python modules
import datetime
from io import StringIO
import math
import os
import shutil

# 2. Third party modules

# 3. Aquaveo modules
from xms.constraint import read_grid_from_file
from xms.core.filesystem import filesystem as xfs
from xms.data_objects.parameters import datetime_to_julian
from xms.datasets.dataset_reader import DatasetReader
from xms.guipy.time_format import string_to_datetime

# 4. Local modules
from xms.cmswave.data.simulation_data import convert_time_into_seconds, SimulationData

# -------------------------------------------------------------------------------
# Code adapted from WriteDatasetFile class in SMS
# -------------------------------------------------------------------------------


# Express a cartesian-convention angle relative to shore normal, wrapped into [-180, 180] degrees
def cart_to_shore_normal(a_from, a_to):
    """Converts a cartesian angle to a shore-normal angle.

    ``a_to`` is subtracted from ``a_from`` and whole turns (360 degrees) are added or removed until the
    difference lies in the closed range [-180, 180].

    Args:
        a_from (:obj:`float`): Starting angle in degrees.
        a_to (:obj:`float`): Ending angle in degrees.

    Returns:
        (:obj:`float`): The resulting degree angle.

    """
    angle = a_from - a_to
    # Both comparisons are False for NaN, so a NaN input falls straight through unchanged
    while angle > 180.0 or angle < -180.0:
        angle += -360.0 if angle > 180.0 else 360.0
    return angle


# Express a meteorological-convention angle relative to shore normal (-180 to 180 degrees)
def meteor_to_shore_normal(a_from, a_to):
    """Converts a meteorological angle to a shore-normal angle.

    The angle is first mapped to the cartesian convention (270 - angle) and then handed to
    :func:`cart_to_shore_normal`.

    Args:
        a_from (:obj:`float`): Starting angle in degrees.
        a_to (:obj:`float`): Ending angle in degrees.

    Returns:
        (:obj:`float`): The resulting degree angle.

    """
    cartesian_angle = 270.0 - a_from
    return cart_to_shore_normal(cartesian_angle, a_to)


# Express an oceanographic-convention angle relative to shore normal (-180 to 180 degrees)
def ocean_to_shore_normal(a_from, a_to):
    """Converts an oceanographic angle to a shore-normal angle.

    The angle is first mapped to the cartesian convention (90 - angle) and then handed to
    :func:`cart_to_shore_normal`.

    Args:
        a_from (:obj:`float`): Starting angle in degrees.
        a_to (:obj:`float`): Ending angle in degrees.

    Returns:
        (:obj:`float`): The resulting degree angle.
    """
    cartesian_angle = 90.0 - a_from
    return cart_to_shore_normal(cartesian_angle, a_to)


# get time in 'yyyymmddhhmm' (12 digits) or 'yymmddhh' (8 digits) format
def get_time_string(a_time, date_format='12 digits'):
    """Gets the formatted string version of a time object.

    CMS-Wave only keeps minute resolution, so the time is rounded to the nearest minute. If rounding would carry
    into the next day, the hour and minute are truncated instead so the date portion is unchanged.

    Args:
        a_time (:obj:`datetime.datetime`): The time object to reference.
        date_format (:obj:`str`): The time format, either '8 digits' ('yymmddhh') or '12 digits' ('yyyymmddhhmm')

    Returns:
        (:obj:`str`): The formatted string showing time.

    Raises:
        ValueError: If ``date_format`` is not one of the supported formats.

    """
    date = a_time.date()
    clock = a_time.time()
    hours = clock.hour
    minutes = clock.minute
    if clock.second >= 30:
        minutes += 1
        if minutes >= 60:
            hours += 1
            minutes = 0
            if hours >= 24:  # truncate instead if rounding pushes us into another day
                hours = clock.hour
                minutes = clock.minute

    if date_format == '12 digits':
        return f"{date.year}{date.month:0>2d}{date.day:0>2d}{hours:0>2d}{minutes:0>2d}"
    elif date_format == '8 digits':
        # Two-digit year; modulo keeps e.g. 2000 -> '00' and 1999 -> '99' so the field is always 2 characters
        year = date.year % 100
        return f"{year:0>2d}{date.month:0>2d}{date.day:0>2d}{hours:0>2d}"
    raise ValueError('Wrong date format')


# finds most appropriate timestep for given time, interpolating if needed
def get_timestep_vals(dset, time, reftime):  # time and reftime are Julian doubles
    """Gets the data from a particular timestep in a given dataset.

    A steady-state (single timestep) dataset always returns its only timestep. Otherwise a timestep within 1e-5
    days of ``time`` is returned directly, a ``time`` between two timesteps is linearly interpolated, and a ``time``
    outside the dataset's range returns the nearest boundary timestep (the first or the last).

    Args:
        dset (:obj:`DatasetReader`): Dataset to reference
        time (:obj:`float`): Julian time to get data for.
        reftime (:obj:`float`): Julian reference time used to offset the timesteps if the dataset has none.

    Returns:
        The values of the matching/boundary timestep as stored by ``dset``, or a :obj:`list` of interpolated values.

    """
    if dset.num_times == 1:  # steady-state dataset
        return dset.values[0]

    # Set the reference time to offset by based on global time or dset reader's reftime
    ref_time = reftime if dset.ref_time is None else datetime_to_julian(dset.ref_time)
    one_day = datetime.timedelta(days=1)

    # Nearest timesteps on either side of `time` (the early break assumes timesteps are in ascending order)
    before_idx = before_time = None
    after_idx = after_time = None
    for i in range(dset.num_times):
        ts_time = dset.timestep_offset(i) / one_day + ref_time  # timestep time in Julian days
        if math.isclose(time, ts_time, abs_tol=0.00001):  # found an exact match
            return dset.values[i]
        if ts_time < time:  # a timestep before the one we are interested in
            before_idx, before_time = i, ts_time
        elif ts_time > time:  # the first timestep after the one we are interested in
            after_idx, after_time = i, ts_time
            break

    # Compare against None, not truthiness: 0.0 is a legitimate timestep time.
    if before_time is None:  # requested time precedes the dataset's first timestep
        return dset.values[0]
    if after_time is None:  # requested time is past the dataset's last timestep
        return dset.values[-1]
    return interpolate_timesteps(time, dset.values[before_idx], dset.values[after_idx], before_time, after_time)


def interpolate_timesteps(a_to_time, a_data1, a_data2, a_ts1, a_ts2):
    """Linearly interpolates dataset values between two timesteps.

    The weights are clamped to [0, 1], so a time outside [a_ts1, a_ts2] returns the nearer timestep's values
    rather than extrapolating.

    Args:
        a_to_time (:obj:`float`): The time to interpolate to.
        a_data1: Values at a_ts1; a sequence of scalars, or a sequence of component lists/tuples for vector data.
        a_data2: Values at a_ts2, laid out the same way as a_data1.
        a_ts1 (:obj:`float`): Time of the timestep just before a_to_time.
        a_ts2 (:obj:`float`): Time of the timestep just after a_to_time.

    Returns:
        (:obj:`list`): The interpolated values; a list of per-cell component lists for vector data.

    """
    mult_before = min(max((a_ts2 - a_to_time) / (a_ts2 - a_ts1), 0.0), 1.0)
    mult_after = 1.0 - mult_before
    if len(a_data1) == 0:  # nothing to interpolate
        return []
    if isinstance(a_data1[0], (list, tuple)):  # vector dataset
        return [
            [(c1 * mult_before) + (c2 * mult_after) for c1, c2 in zip(row1, row2)]
            for row1, row2 in zip(a_data1, a_data2)
        ]
    # scalar dataset
    return [(v1 * mult_before) + (v2 * mult_after) for v1, v2 in zip(a_data1, a_data2)]


class CmsWaveDatasetWriter:
    """Writes CMS-Wave spatial dataset files (.dep, .cur, .eta, .wind, .fric, .bref, .fref, .mud) for a grid."""
    def __init__(self, grid_name, grid, case_times, time_units):
        """Constructor.

        Args:
            grid_name (:obj:`str`): Base path of the output files; each writer appends its own file extension.
            grid: The simulation grid; its ``locations_x``, ``locations_y`` and ``angle`` attributes are read.
            case_times: Simulation case times, as offsets in ``time_units`` from the simulation reference time.
            time_units (:obj:`str`): Units of ``case_times``.
        """
        self.grid_name = grid_name
        self.grid = grid
        self.case_times = case_times
        self.time_units = time_units
        self.locations_x = None  # grid x locations; fetched from the grid by write_datadims() if not assigned
        self.locations_y = None  # grid y locations; fetched from the grid by write_datadims() if not assigned
        self.write_cell_sizes = False  # set by write_datadims() when cell sizes are not uniform
        self.ss = StringIO()  # text buffer for the file currently being written

    @staticmethod
    def _is_uniform(locations, spacing):
        """Returns True if each gap between consecutive ``locations`` (after the first) is within 1e-5 of ``spacing``."""
        return all(
            math.isclose(locations[i + 1] - locations[i], spacing, abs_tol=0.00001)
            for i in range(1, len(locations) - 1)
        )

    def _write_buffer_to_file(self, extension):
        """Copies the text buffer to ``<grid_name><extension>``, overwriting any existing file.

        Args:
            extension (:obj:`str`): File extension including the leading dot, e.g. '.dep'.
        """
        self.ss.seek(0)
        with open(self.grid_name + extension, "w") as out:  # 'with' closes the file even if copying fails
            shutil.copyfileobj(self.ss, out)

    def _write_dataset_file(self, dataset, simref_time, extension, write_date=False, num_vals_per_line=5):
        """Writes the header and dataset values into a fresh buffer and saves it as ``<grid_name><extension>``.

        Args:
            dataset (:obj:`xms.datasets.dataset_reader.DatasetReader`): Dataset to write.
            simref_time (:obj:`datetime.datetime`): The simulation reference time.
            extension (:obj:`str`): File extension including the leading dot.
            write_date (:obj:`bool`): True to write every case time with a timestamp line, False for one timestep.
            num_vals_per_line (:obj:`int`): Maximum number of values written per line.
        """
        self.ss = StringIO()
        self.write_datadims()
        self.write_data_vals(dataset, simref_time, write_date=write_date, num_vals_per_line=num_vals_per_line)
        self._write_buffer_to_file(extension)

    def write_datadims(self):
        """Writes the header line ``NI NJ DI DJ`` (cell counts and cell sizes) to the buffer.

        If the cell sizes vary in either direction, DJ is written as 999.0 and ``self.write_cell_sizes`` is set so
        that :meth:`write_ij_sizes` appends the individual cell sizes to the .dep file.
        """
        if self.locations_x is None:  # Only unwrap these once
            self.locations_x = self.grid.locations_x
            self.locations_y = self.grid.locations_y
        d_i = self.locations_x[1] - self.locations_x[0] if len(self.locations_x) > 1 else 0.0
        d_j = self.locations_y[1] - self.locations_y[0] if len(self.locations_y) > 1 else 0.0
        if not (self._is_uniform(self.locations_x, d_i) and self._is_uniform(self.locations_y, d_j)):
            # Non-uniform cell sizes: export 999.0 for DJ to indicate the sizes are variable
            self.write_cell_sizes = True
            d_j = 999.0
        self.ss.write(f'{len(self.locations_x) - 1} {len(self.locations_y) - 1} {d_i:.6f} {d_j:.6f}\n')

    def write_data_vals(self, dataset, simref_time, write_date=False, num_vals_per_line=5):
        """Writes dataset values to the buffer for each case time.

        Rows are written from the highest j index down to 0. Each i-row ends with a newline and is also wrapped
        every ``num_vals_per_line`` values. Vector values are rotated by the grid angle before writing.

        Args:
            dataset (:obj:`xms.datasets.dataset_reader.DatasetReader`): Data to be written.
            simref_time (:obj:`datetime.datetime`): The simulation reference time.
            write_date (:obj:`bool`): True to write a timestamp line before each case time's values and write
                every case time; False to write only the first case time without a timestamp.
            num_vals_per_line (:obj:`int`): Number of values to write per line
        """
        if len(self.case_times) == 0:
            return  # no case times, nothing to write
        num_i = len(self.locations_x) - 1  # number of cells in the i direction
        num_j = len(self.locations_y) - 1  # number of cells in the j direction
        # Julian time of the first case time; used as the reference for datasets without their own reference time
        first_sec = convert_time_into_seconds(self.time_units, float(self.case_times[0]))
        first_julian = datetime_to_julian(simref_time + datetime.timedelta(seconds=first_sec))
        is_scalar = dataset.num_components == 1
        cosangle = sinangle = None
        if not is_scalar:
            grid_angle = self.grid.angle * math.pi / 180.0
            cosangle = math.cos(grid_angle)
            sinangle = math.sin(grid_angle)

        for time in self.case_times:
            time_sec = convert_time_into_seconds(self.time_units, float(time))
            dt = simref_time + datetime.timedelta(seconds=time_sec)
            if write_date:
                self.ss.write(f"    {get_time_string(dt)}\n")
            vals = get_timestep_vals(dataset, datetime_to_julian(dt), first_julian)
            if is_scalar:
                self._write_scalar_vals(vals, num_i, num_j, num_vals_per_line)
            else:
                self._write_vector_vals(vals, num_i, num_j, num_vals_per_line, cosangle, sinangle)
            if not write_date:
                break  # If we are not a transient dataset type, only write one timestep

    def _write_scalar_vals(self, vals, num_i, num_j, num_vals_per_line):
        """Writes one timestep of scalar cell values (cell index ``j * num_i + i``), highest j row first."""
        for j in range(num_j - 1, -1, -1):
            row_start = j * num_i
            for ii in range(num_i):
                self.ss.write(f' {vals[row_start + ii]:12.6f}')
                count = ii + 1
                # Newline at the end of each i-row, and every num_vals_per_line values within a row
                if count == num_i or count % num_vals_per_line == 0:
                    self.ss.write('\n')

    def _write_vector_vals(self, vals, num_i, num_j, num_vals_per_line, cosangle, sinangle):
        """Writes one timestep of 2D vector cell values (x then y per cell), highest j row first.

        Each vector is rotated by minus the grid angle, i.e. expressed in axes rotated by the grid angle.
        """
        row_len = 2 * num_i  # an x and a y component per cell
        for j in range(num_j - 1, -1, -1):
            count = 0
            for ii in range(num_i):
                cell_vals = vals[j * num_i + ii]
                x_val = cell_vals[0]
                y_val = cell_vals[1]
                x_out = x_val * cosangle + y_val * sinangle
                y_out = x_val * -sinangle + y_val * cosangle

                self.ss.write(f'{x_out:14f} ')
                count += 1
                if count % num_vals_per_line == 0:
                    self.ss.write('\n')

                self.ss.write(f'{y_out:14f} ')
                count += 1
                if count == row_len or count % num_vals_per_line == 0:
                    self.ss.write('\n')

    def write_depth_file(self, zvals, x_sizes, y_sizes):
        """Writes the grid depths to ``<grid_name>.dep``.

        Values are negated on output (presumably converting elevations to depths — confirm) and written from the
        highest j row down to 0, five per line. The cell sizes follow if the grid cells are not uniform.

        Args:
            zvals (:obj:`list[float]`): Cell z values, ordered with i varying fastest (index ``j * x_sizes + i``).
            x_sizes (:obj:`int`): Number of cells in the i direction.
            y_sizes (:obj:`int`): Number of cells in the j direction.
        """
        self.ss = StringIO()
        self.write_datadims()

        for j in range(y_sizes - 1, -1, -1):
            row_start = j * x_sizes
            count = 0
            for i in range(x_sizes):
                self.ss.write(f' {-1.0 * zvals[row_start + i]:>12.6f}')
                count += 1
                if count % 5 == 0:
                    self.ss.write('\n')
            if count % 5 != 0:  # terminate a partially filled last line of the row
                self.ss.write('\n')
        self.write_ij_sizes()
        self._write_buffer_to_file(".dep")

    def write_ij_sizes(self):
        """Writes the i cell sizes, then the j cell sizes, five per line, if write_datadims() found non-uniform cells."""
        if not self.write_cell_sizes:
            return
        for locations in (self.locations_x, self.locations_y):
            count = 0
            for i in range(len(locations) - 1):
                self.ss.write(f' {locations[i + 1] - locations[i]:>12.6f}')
                count += 1
                if count % 5 == 0:
                    self.ss.write('\n')
            if count % 5 != 0:  # terminate a partially filled last line
                self.ss.write('\n')

    def write_current_file(self, a_dataset, simref_time):
        """Writes the current dataset, for every case time, to ``<grid_name>.cur``.

        Args:
            a_dataset (:obj:`xms.datasets.dataset_reader.DatasetReader`): Dataset to write.
            simref_time (:obj:`datetime.datetime`): The simulation reference time.
        """
        self._write_dataset_file(a_dataset, simref_time, ".cur", write_date=True)

    def write_surge_file(self, a_dataset, simref_time):
        """Writes the surge dataset, for every case time and ten values per line, to ``<grid_name>.eta``.

        Args:
            a_dataset (:obj:`xms.datasets.dataset_reader.DatasetReader`): Dataset for surge.
            simref_time (:obj:`datetime.datetime`): The simulation reference time.

        """
        self._write_dataset_file(a_dataset, simref_time, ".eta", write_date=True, num_vals_per_line=10)

    def write_wind_file(self, a_dataset, simref_time):
        """Writes the wind dataset, for every case time, to ``<grid_name>.wind``.

        Args:
            a_dataset (:obj:`xms.datasets.dataset_reader.DatasetReader`): Dataset for wind.
            simref_time (:obj:`datetime.datetime`): The simulation reference time.

        """
        self._write_dataset_file(a_dataset, simref_time, ".wind", write_date=True)

    def write_fric_file(self, dataset, simref_time):
        """Writes the friction dataset (first case time only) to ``<grid_name>.fric``.

        Args:
            dataset (:obj:`xms.datasets.dataset_reader.DatasetReader`): Dataset for friction.
            simref_time (:obj:`datetime.datetime`): The simulation reference time.

        """
        self._write_dataset_file(dataset, simref_time, ".fric")

    def write_backward_reflection_file(self, a_dataset, simref_time):
        """Writes the backward reflection dataset (first case time only) to ``<grid_name>.bref``.

        Args:
            a_dataset (:obj:`xms.datasets.dataset_reader.DatasetReader`): Dataset to write.
            simref_time (:obj:`datetime.datetime`): The simulation reference time.
        """
        self._write_dataset_file(a_dataset, simref_time, ".bref")

    def write_forward_reflection_file(self, a_dataset, simref_time):
        """Writes the forward reflection dataset (first case time only) to ``<grid_name>.fref``.

        Args:
            a_dataset (:obj:`xms.datasets.dataset_reader.DatasetReader`): Dataset to write.
            simref_time (:obj:`datetime.datetime`): The simulation reference time.
        """
        self._write_dataset_file(a_dataset, simref_time, ".fref")

    def write_mud_file(self, a_dataset, simref_time):
        """Writes the muddy bed dataset (first case time only) to ``<grid_name>.mud``.

        Args:
            a_dataset (:obj:`xms.datasets.dataset_reader.DatasetReader`): Dataset to write.
            simref_time (:obj:`datetime.datetime`): The simulation reference time.
        """
        self._write_dataset_file(a_dataset, simref_time, ".mud")


def export_all_datasets(xms_data):
    """Export all datasets.

    Removes output files left by previous runs, writes the depth (.dep) file, then writes each optional spatial
    dataset file (current, wind, surge, friction, forward/backward reflection, muddy bed) present in ``xms_data``.

    Args:
        xms_data (:obj:`dict`): The XMS data JSON dict passed on the command line

    Raises:
        RuntimeError: If ``xms_data`` has no 'sim_data' entry.
    """
    sim_mainfile = xms_data.get('sim_data')
    if not sim_mainfile:
        raise RuntimeError('Unable to load simulation data for spatial dataset export script.')
    data = SimulationData(sim_mainfile)

    sim_name = xms_data.get('sim_name', '')
    # Clear out files from previous runs if they exist
    for ext in ['cur', 'dep', 'fric', 'fref', 'bref', 'eta', 'mud', 'wind']:
        xfs.removefile(os.path.join(os.getcwd(), f'{sim_name}.{ext}'))

    # get case times for simulation
    times = data.case_times['Time']
    time_units = data.info.attrs['reftime_units']

    # get the simulation grid
    sim_grid = read_grid_from_file(xms_data['cogrid_file'])
    locations_x = sim_grid.locations_x
    locations_y = sim_grid.locations_y
    i_size = len(locations_x) - 1  # number of cells in the i direction
    j_size = len(locations_y) - 1  # number of cells in the j direction

    d_set_exporter = CmsWaveDatasetWriter(sim_name, sim_grid, times, time_units)
    d_set_exporter.locations_x = locations_x
    d_set_exporter.locations_y = locations_y

    reftime = string_to_datetime(data.info.attrs['reftime'])

    # Write the depth file. The writer expects values ordered with i varying fastest (index j * i_size + i).
    cell_elevations = sim_grid.cell_elevations
    if sim_grid.numbering == 0:
        # NOTE(review): numbering 0 appears to store cells with j varying fastest (index i * j_size + j); reorder.
        # The stride must be j_size (the old i_size stride broke every non-square grid).
        pts_z = [cell_elevations[i * j_size + j] for j in range(j_size) for i in range(i_size)]
    else:
        pts_z = cell_elevations
    d_set_exporter.write_depth_file(pts_z, i_size, j_size)

    # Optional dataset files, written in this order when XMS supplied the dataset
    dataset_writers = [
        ('curr', d_set_exporter.write_current_file),
        ('wind', d_set_exporter.write_wind_file),
        ('surge', d_set_exporter.write_surge_file),
        ('friction', d_set_exporter.write_fric_file),
        ('forward_reflection', d_set_exporter.write_forward_reflection_file),
        ('backward_reflection', d_set_exporter.write_backward_reflection_file),
        ('muddy_bed', d_set_exporter.write_mud_file),
    ]
    for key, write_file in dataset_writers:
        dset = xms_data.get(key)
        if dset:
            # dset[0] is presumably the dataset file path and dset[1] its group path — confirm with XMS
            reader = DatasetReader(dset[0], group_path=dset[1])
            write_file(reader, reftime)
