"""Module to interpolate datasets through time."""

# 1. Standard Python modules
import copy
import datetime
import math

# 2. Third party modules

# 3. Aquaveo modules
from xms.data_objects.parameters import datetime_to_julian

# 4. Local modules


# Adapted from the dataset interpolation code in SMS
class DatasetInterpolator:
    """Class to interpolate a dataset through time given a target time increment and length.

    Dataset and simulation times are compared as Julian dates expressed in (fractional) days.
    """
    # Two times closer than this many days (0.001 seconds) are treated as the same instant.
    _MATCH_TOLERANCE_DAYS = 1.1574074074074e-8
    # Relative slack for deciding that (run duration / time increment) is a whole number of steps.
    _STEP_COUNT_REL_TOL = 1e-9

    def __init__(self, dset, reftime, timeinc, rundays, global_time):
        """Construct the interpolator.

        Args:
            dset (:obj:`xms.datasets.dataset_reader.DatasetReader`): The dataset to interpolate
            reftime (:obj:`float`): The simulation interpolation reference date (start of the ADCIRC run) as a Julian
            timeinc (:obj:`float`): The target timestep size in seconds
            rundays (:obj:`float`): The target time length to interpolate the dataset in days
            global_time (:obj:`datetime.datetime`): The SMS global time. Will be used as the dataset's reference time if
                one is not defined for it already.

        Raises:
            ValueError: If neither the dataset nor ``global_time`` supplies a reference time, or if ``timeinc`` is not
                positive while ``rundays`` is positive.
        """
        self._one_day = datetime.timedelta(days=1)
        self._dset = dset
        if self._dset.ref_time is not None:
            self._dset_reftime = datetime_to_julian(self._dset.ref_time)  # Use reference time defined in dataset
        elif global_time is not None:
            self._dset_reftime = datetime_to_julian(global_time)  # Use XMS global zero time
        else:
            raise ValueError('Must provide a global zero time if dataset reference date is undefined.')
        self._dset_firsttime = self._dset_time(0)
        self._dset_lasttime = self._dset_time(-1)
        self._sim_reftime = reftime
        self._sim_lasttime = 0.0
        self._target_times = []
        self._initialize_target_times(timeinc, rundays)
        self._trivial_data = self._get_trivial_case()  # Check if this a trivial case

    @staticmethod
    def _compute_time_weight(before, after, target):
        """Compute the time weighting factor for a target timestep, given times for timesteps before and after it.

        Args:
            before (:obj:`float`): Timestep time of the previous timestep
            after (:obj:`float`): Timestep time of the next timestep
            target (:obj:`float`): Timestep time of the target timestep

        Returns:
            (:obj:`tuple`): Tuple of two floats. First is multiplier for previous timestep's value, second is
            multiplier for next timestep's value. Both are clamped to [0.0, 1.0] and sum to 1.0.
        """
        span = after - before
        if span == 0.0:
            # Both bracketing timesteps share the same time; avoid dividing by zero and use the previous one.
            return 1.0, 0.0
        mult_before = min(1.0, max(0.0, (after - target) / span))
        return mult_before, 1.0 - mult_before

    def _dset_time(self, idx):
        """Get the absolute time of a dataset timestep.

        Args:
            idx (:obj:`int`): Index of the dataset timestep (passed straight to the dataset's ``timestep_offset``)

        Returns:
            (:obj:`float`): The timestep's offset converted to days and added to the dataset reference time
        """
        return self._dset.timestep_offset(idx) / self._one_day + self._dset_reftime

    def _initialize_target_times(self, time_increment_secs, run_duration_days):
        """Build up the list of target dataset times.

        The list always starts at the simulation reference time and is extended by whole time increments until the
        run duration is covered. Any previously built list is replaced.

        Args:
            time_increment_secs (:obj:`float`): The timestep increment size in seconds
            run_duration_days (:obj:`float`): The total run time duration in days

        Raises:
            ValueError: If the run duration is positive but the time increment is not, since the run could never
                be covered.
        """
        # Convert time increment to days for offsetting reference Julian date
        time_increment_days = datetime.timedelta(seconds=time_increment_secs) / self._one_day
        run_duration_seconds = datetime.timedelta(days=run_duration_days) / datetime.timedelta(seconds=1)

        num_steps = 0
        if run_duration_seconds > 0.0:
            if not time_increment_secs > 0.0:
                raise ValueError('Time increment must be positive to step through a non-zero run duration.')
            # Compute the step count once rather than accumulating elapsed seconds: repeated float addition can
            # fall just short of the duration (e.g. ten additions of 0.1 < 1.0) and yield a spurious extra step.
            exact_steps = run_duration_seconds / time_increment_secs
            nearest = round(exact_steps)
            if math.isclose(exact_steps, nearest, rel_tol=self._STEP_COUNT_REL_TOL, abs_tol=0.0):
                num_steps = nearest  # Duration is a whole number of increments (up to float round-off)
            else:
                num_steps = math.ceil(exact_steps)  # Last step overshoots the end of the run

        # Derive each time from the start by multiplication so round-off does not accumulate along the run.
        self._target_times = [self._sim_reftime + step * time_increment_days for step in range(num_steps + 1)]
        self._sim_lasttime = self._target_times[-1]

    def _get_trivial_case(self):
        """Get the dataset values to use for all timesteps if a trivial case.

        A trivial case would be when none of the dataset timesteps overlap the target time range, or when the
        dataset has fewer than two timesteps.

        Returns:
            (:obj:`list`): The dataset values to use for all timesteps. None if not a trivial case.
        """
        # check if all timesteps occur before the simulation run (use the last for all)
        if self._sim_reftime >= self._dset_lasttime:
            return self._dset.values[-1]
        # check if all timesteps occur after the simulation run (use first for all)
        if self._sim_lasttime <= self._dset_firsttime:
            return self._dset.values[0]
        # check if the dataset is steady state
        if self._dset.num_times < 2:
            return self._dset.values[0]
        return None

    def _get_interp_timestep_indices(self, target_time):
        """Get the indices of the timesteps before and after a target timestep.

        Args:
            target_time (:obj:`float`): Time of the target timestep as a Julian double

        Returns:
            (:obj:`tuple`): Tuple of two ints ``(first, last)``:

            - ``(0, -1)`` if the target is before the first dataset timestep.
            - ``(-1, num_times - 1)`` if the target is after the last dataset timestep.
            - ``(i, i)`` if the target matches dataset timestep ``i`` to within 0.001 seconds.
            - ``(i - 1, i)`` otherwise, where ``i`` is the first dataset timestep later than the target.
        """
        # Check if the dataset starts after the target - return first idx
        if target_time < self._dset_firsttime:
            return 0, -1

        # Check if the dataset ends before the target - return last idx
        num_times = self._dset.num_times
        if target_time > self._dset_lasttime:
            return -1, num_times - 1

        # Find the timesteps before and after the target
        for i in range(num_times):
            dset_time = self._dset_time(i)
            # rel_tol must be disabled: its default (1e-9) scaled by a Julian date (~2.4e6 days) would widen the
            # match window to several minutes instead of the intended 0.001 seconds.
            if math.isclose(target_time, dset_time, rel_tol=0.0, abs_tol=self._MATCH_TOLERANCE_DAYS):
                return i, i
            # Passed the target, so it lies between the previous timestep and this one
            if target_time < dset_time:
                return i - 1, i

        # Didn't pass the target - return the last idx
        return -1, num_times - 1

    def get_num_timesteps(self):
        """Get the number of target timesteps we will be interpolating to."""
        return len(self._target_times)

    def get_target_time(self, idx):
        """Get a target timestep time by index."""
        return self._target_times[idx]

    def interpolate_timestep(self, ts_idx):
        """Get the time-interpolated values for a single target timestep.

        Args:
            ts_idx (:obj:`int`): The index of the target timestep to interpolate to

        Returns:
            (:obj:`list`): Values of the target timestep. If vector, list of lists where each inner list contains
            an x and y component value. In trivial cases a deep copy of the single applicable dataset timestep is
            returned; when the target coincides with or lies outside the dataset timesteps, that dataset
            timestep's values are returned as read from the dataset.
        """
        if self._trivial_data is not None:  # Short-circuit trivial cases
            return copy.deepcopy(self._trivial_data)

        target_time = self._target_times[ts_idx]
        first_idx, last_idx = self._get_interp_timestep_indices(target_time)
        if first_idx == last_idx:  # exact match, no interpolation necessary
            return self._dset.values[max(first_idx, 0)]
        if last_idx == -1:  # target is before start of dataset timesteps, use the first
            return self._dset.values[0]
        if first_idx == -1:  # target is after end of dataset timesteps, use the last
            return self._dset.values[last_idx]

        # Target falls strictly between two dataset timesteps: blend them linearly in time
        ts1_data = self._dset.values[first_idx]
        ts2_data = self._dset.values[last_idx]
        mult_before, mult_after = self._compute_time_weight(
            self._dset_time(first_idx), self._dset_time(last_idx), target_time
        )

        numcomps = self._dset.num_components
        if numcomps == 1:  # Fill 1-D array
            return [(before * mult_before) + (after * mult_after) for before, after in zip(ts1_data, ts2_data)]
        # Fill 2D array
        return [
            [(before[comp] * mult_before) + (after[comp] * mult_after) for comp in range(numcomps)]
            for before, after in zip(ts1_data, ts2_data)
        ]
