"""Xy series utilities."""

__copyright__ = "(C) Copyright Aquaveo 2022"
__license__ = "All rights reserved"

# 1. Standard Python modules
import logging
from typing import Sequence

# 2. Third party modules

# 3. Aquaveo modules

# 4. Local modules
from xms.coverage.xy.xy_series import XySeries

# Type aliases
# Dict of XySeries, keyed by their series ID (see find_match(), next_unique_id()).
XySeriesDict = dict[int, XySeries]
# A pair (x values, y values), e.g. the return value of get_step_function().
XyValues = tuple[Sequence[int | float], Sequence[int | float]]


def clamp(n: float, minn: float, maxn: float) -> float:
    """Return n limited to the closed range [minn, maxn].

    The lower bound is tested first, so when minn > maxn any n below minn yields minn. A NaN n is returned
    unchanged because both comparisons are false.

    Args:
        n: The value to constrain.
        minn: The lower bound.
        maxn: The upper bound.

    Returns:
        minn if n < minn, otherwise maxn if n > maxn, otherwise n itself.
    """
    return minn if n < minn else (maxn if n > maxn else n)


def scale_y_data(xy_series: XySeries, factor: float) -> None:
    """Multiply every y value of the series by 'factor', modifying it in place.

    Args:
        xy_series: The series whose y values are overwritten; its y container must support item assignment.
        factor (float): The multiplier applied to each y value.
    """
    y_values = xy_series.y
    for index, value in enumerate(y_values):
        y_values[index] = value * factor


def find_match(xy_series_dict: XySeriesDict, xy_series: XySeries) -> int | None:
    """Looks for an XySeries in xy_series_dict that matches xy_series, comparing only x and y.

    Args:
        xy_series_dict: Dict of XySeries, keyed by their ID.
        xy_series: An XySeries.

    Returns:
        (int|None): ID of a matching XySeries in xy_series_dict, or None.
    """
    for series_id, dict_series in xy_series_dict.items():
        if dict_series == xy_series:
            return series_id
    return None


def next_unique_id(xy_series_dict: XySeriesDict) -> int:
    """Return one more than the largest ID in xy_series_dict, or 1 when the dict is empty.

    Args:
        xy_series_dict: Dict of XySeries, keyed by their ID.

    Returns:
        (int): An ID that is not currently a key of xy_series_dict.
    """
    # An empty dict behaves as if the largest ID were 0, so the first ID handed out is 1.
    return max(xy_series_dict, default=0) + 1


def add_or_match(new_series: XySeries, xy_series_dict: XySeriesDict) -> int:
    """Adds new_series to xy_series_dict if an identical series isn't already there, otherwise returns id of match.

    Matching is done with find_match(), which is intended to consider only x and y. If a match is found, neither
    new_series nor xy_series_dict is modified and the match's ID is returned. Otherwise new_series.series_id is
    overwritten with next_unique_id(xy_series_dict) and new_series is stored in the dict under that ID.

    Args:
        new_series: An XySeries. Its series_id is reassigned if it gets added.
        xy_series_dict: Dict of XySeries, keyed by their ID. Modified in place when new_series is added.

    Returns:
        (int): ID of new_series or of a preexisting matching series.
    """
    match_id = find_match(xy_series_dict, new_series)
    if match_id is not None:
        return match_id

    new_series.series_id = next_unique_id(xy_series_dict)
    xy_series_dict[new_series.series_id] = new_series
    return new_series.series_id


def create_constant_xy_series(name: str, constant: float, factor: float) -> XySeries:
    """Build a single-point XySeries at x = 0.0 whose y value is constant scaled by factor.

    Args:
        name: Name for the new series.
        constant: The base y value.
        factor: Multiplier applied to constant.

    Returns:
        (XySeries): A new series with x=[0.0], y=[constant * factor] and series_id 1.
    """
    scaled_value = constant * factor
    return XySeries(
        x=[0.0],
        y=[scaled_value],
        name=name,
        series_id=1,
    )


def interpolate(
    xy1: XySeries,
    xy2: XySeries | None = None,
    constant: float | None = None,
    t: float = 0.5,
    xy_dict: XySeriesDict | None = None
) -> XySeries | None:
    """Interpolates an xy series with another xy series or a constant and returns the resulting xy series.

        ::

            t is the parametric distance between xy1 and xy2 (or the constant).
            0.0<--------------------------------------->1.0
            |<----------- t ------------>|              |
            (xy1)                     (new_xy)     (xy2 or constant)

    The new series is computed as follows:
    1. With xy2, the new xs are the sorted union of the xs of xy1 and xy2. With a constant, the new xs are a
       copy of xy1's xs.
    2. With xy2, y_from_x() evaluates both series at every new x. With a constant, xy1's own y values are used.
    3. The two values are linearly blended: new_y = (1 - t) * y_from_xy1 + t * y_from_other.

    The new series is named '<xy1.name>-<xy2.name or constant>-<t>' and copies use_dates_times from xy1. Its
    series_id is 1, or a unique ID when xy_dict is given.

    Args:
        xy1: First xy series.
        xy2: Second xy series. Specify either this or constant, not both.
        constant: A constant value. Specify either this or xy2, not both.
        t: Value between 0.0 and 1.0 indicating distance between xy1 and xy2. Values outside are clamped.
        xy_dict: If provided (even if empty), the new series is added to it under a unique ID.

    Returns:
        The new XySeries, or None (after logging an error) if both or neither of xy2 and constant were given.
    """
    # Exactly one of xy2 / constant must be supplied.
    if (xy2 is None) == (constant is None):
        logger = logging.getLogger('xms.coverage')
        logger.error('xy_util.interpolate(): Provide either "xy2" or "constant" but not both.')
        return None

    t = clamp(t, 0.0, 1.0)
    u = 1.0 - t
    if xy2 is not None:
        new_name = f'{xy1.name}-{xy2.name}-{t}'
        # set() accepts lists and numpy arrays alike; sorted() always returns a plain list.
        new_x = sorted(set(xy1.x).union(xy2.x))
        y1_all = y_from_x(xy1, new_x)
        y2_all = y_from_x(xy2, new_x)
        new_y = [y1 * u + y2 * t for y1, y2 in zip(y1_all, y2_all)]
    else:  # constant is not None
        new_name = f'{xy1.name}-{constant}-{t}'
        new_x = xy1.x.copy()  # copy so the new series does not share xy1's x container
        new_y = [y * u + constant * t for y in xy1.y]
    new_series = XySeries(x=new_x, y=new_y, name=new_name, series_id=1, use_dates_times=xy1.use_dates_times)
    # Compare with None: an empty dict is falsy but is still a dict the caller asked us to add to.
    if xy_dict is not None:
        new_series.series_id = next_unique_id(xy_dict)
        xy_dict[new_series.series_id] = new_series
    return new_series


def y_from_x(xy_series: XySeries, x_values: Sequence[int | float]) -> list[float]:
    """Return the y values at the given x values, interpolating and extrapolating.

    See CTimeSeriesCurve::ValueFromTime() in XMS.

    Args:
        xy_series: The xy series.
        x_values: x values we want to get y values at.

    Returns:
        (list[float]): y values, size of x_values.
    """
    y_values = []
    x = xy_series.x
    y = xy_series.y
    len_x = len(x)
    for x_value in x_values:
        if x_value < x[0]:
            y_values.append(y[0])
        elif x_value > x[-1]:
            y_values.append(y[-1])
        elif len_x == 1:
            y_values.append(y[0])
        else:
            x_idx = 1
            while x_idx < len_x and x[x_idx] < x_value:
                x_idx += 1
            delta_y = (y[x_idx] - y[x_idx - 1])
            t = (x_value - x[x_idx - 1]) / (x[x_idx] - x[x_idx - 1])
            y_values.append(y[x_idx - 1] + (t * delta_y))
    return y_values


def get_step_function(x_values: Sequence[int | float], y_values: Sequence[int | float]) -> XyValues:
    r"""Returns new values making it a step function.

        ::

                  Before                          After

                      *                             *---*
                    /   \                           |   |
          y       *       *                     *---*   *
          |
          *--x

    Args:
        x_values: The x values.
        y_values: The y values

    Return:
        (tuple[list[any], list[any]]): New values as a step function.
    """
    if not x_values or len(x_values) < 2:
        return x_values, y_values

    new_x = []
    new_y = []
    for i, (x, y) in enumerate(zip(x_values, y_values)):
        new_x.append(x)
        new_y.append(y)
        if i < len(x_values) - 1:
            new_x.append(x_values[i + 1])
            new_y.append(y)
    return new_x, new_y


def average_y_from_x_range(xy_series: XySeries, start_x: float, end_x: float) -> tuple[float, bool]:
    r"""Returns a tuple: the average y of an xy series over the given x range, and an extrapolation flag.

        See xyAverageValueFromTimeRange in XMS.

        The average is the area beneath the piecewise-linear curve over [start_x, end_x], divided by the length
        of the range (end_x - start_x). The curve is treated as constant before the first point (y of the first
        point) and after the last point (y of the last point).

        EXAMPLE Given the series below, what is the average value over the range [0,20]?
             5,5 _____ 10,5
                /     \
               /       \  13,2
              /         \_____ 18,2
             /                \
        0,0 /                  \ 20,0

        Solution:
           Area = 12.5 + 25 + 10.5 + 10 + 2 = 60
           value = 60/(20-0) = 3

        Special cases:
           - Empty series: returns (-1.0, True).
           - start_x >= end_x: returns the interpolated y at the midpoint (start_x + end_x) / 2, with flag False.

    Args:
        xy_series: The xy_series. Its x values are presumably sorted ascending (the segment walk assumes it).
        start_x: Starting x value for the range.
        end_x: Ending x value for the range.

    Returns:
        The average value, and a bool that is true if the range starts before the first x or ends after the last
        x (i.e. extrapolation was done), or if the series is empty.
    """
    flag = False
    num_pts = xy_series.count()  # used as the number of points in x / y
    if num_pts == 0:
        flag = True
        return -1.0, flag

    if start_x >= end_x:
        # Empty or reversed range: no interval to average over, so use the value at the midpoint.
        return y_from_x(xy_series, [(start_x + end_x) / 2.0])[0], flag

    x = xy_series.x  # for short
    y = xy_series.y  # for short
    first_found = False  # becomes True once the walk reaches the segment containing start_x
    area = 0.0
    last_x = start_x  # (last_x, last_y) is the left end of the next trapezoid to accumulate
    last_y = y[0]

    # Range starts at or before the first point: add the constant region [start_x, x[0]] at height y[0].
    if start_x <= x[0]:
        if start_x < x[0]:
            flag = True
        if end_x <= x[0]:
            # The whole range lies at or before the first point, where the curve is constant y[0].
            return y[0], flag
        else:
            area += y[0] * (x[0] - start_x)

    # Walk the segments [x[i], x[i + 1]], adding the trapezoid area of the part that overlaps [start_x, end_x].
    for i in range(num_pts - 1):
        # Set beginning values
        if x[i] >= start_x:
            # Segment starts inside the range: the trapezoid starts at this point.
            last_x = x[i]
            last_y = y[i]
            first_found = True
        elif x[i + 1] > start_x:
            # start_x falls strictly inside this segment: start at the linearly interpolated point.
            t = (start_x - x[i]) / (x[i + 1] - x[i])
            last_x = start_x
            last_y = ((y[i + 1] * t) + (y[i] * (1.0 - t)))
            first_found = True

        # Set ending values
        if first_found:
            if x[i + 1] < end_x:
                # The rest of this segment is inside the range.
                area += ((last_y + y[i + 1]) / 2.0) * (x[i + 1] - last_x)
            else:
                # end_x falls within (or at the end of) this segment: add the partial trapezoid and stop.
                t = (end_x - x[i]) / (x[i + 1] - x[i])
                this_y = ((y[i + 1] * t) + (y[i] * (1.0 - t)))
                area += ((last_y + this_y) / 2.0) * (end_x - last_x)
                break

    # Range ends at or after the last point: add the constant region [x[-1], end_x] at height y[-1].
    if end_x >= x[num_pts - 1]:
        if end_x > x[num_pts - 1]:
            flag = True
        if start_x >= x[num_pts - 1]:
            # The whole range lies at or after the last point, where the curve is constant y[-1].
            return y[num_pts - 1], flag
        else:
            area += y[num_pts - 1] * (end_x - x[num_pts - 1])
    return area / (end_x - start_x), flag
