"""Utility methods used by the mapping modules."""

# 1. Standard Python modules
import binascii
import cmath
import io
import math
import os
import sys

# 2. Third party modules
import numpy as np
from shapely.geometry import LineString, Point

# 3. Aquaveo modules
from xms.api.dmi import XmsEnvironment as XmEnv
from xms.interp.interpolate import InterpIdw

# 4. Local modules

# Meters-to-feet conversion factor (international foot: 1 ft = 0.3048 m, so 1 m = 1 / 0.3048 ft).
FEET_PER_METER = 3.28083989501


def _factor_from_latitude(latitude):
    """Computes meters, decimal degrees conversion factor from latitude.

    Args:
        latitude (:obj:`float`): the latitude

    Returns:
        (:obj:`float`): conversion factor
    """
    return 111.32 * 1000 * math.cos(latitude * (math.pi / 180))


def _buffer_segment_side(pt1_coords, pt2_coords, extend_length):
    """Extend a line a specified length beyond the last point in the segment.

    Args:
        pt1_coords (:obj:`tuple`): The coordinates of the first node in the levee node pair
        pt1_coords (:obj:`tuple`): The coordinates of the second node in the levee node pair
        extend_length (:obj:`float`): The distance to extend each endpoint of the segment. Provides a tolerance for
            intersections.

    Returns:
        (:obj:`tuple(float, float)`): The coordinates of the second point, extended along the segment by extend_length
    """
    # Calculate the x,y change between the points
    difference = complex(pt2_coords[0], pt2_coords[1]) - complex(pt1_coords[0], pt1_coords[1])
    # Convert the difference to polar coordinates
    distance, angle = cmath.polar(difference)
    # Calculate a new x,y change based on the angle and desired distance
    displacement = cmath.rect(extend_length, angle)
    # Add the displacement end point
    xy3 = displacement + complex(pt2_coords[0], pt2_coords[1])
    return xy3.real, xy3.imag


def grid_projection_mismatches_display(native_projection, display_projection):
    """Report whether the mesh's native projection differs from the current display projection.

    Notes:
        This is important to us because the state of the mapped BC/tides is dependent on the mesh data at the time
        of mapping. Need to always map and export while working in the mesh object's projection.

    Args:
        native_projection (:obj:`Projection`): The mesh's native data_objects projection
        display_projection (:obj:`Projection`): The current display data_objects projection

    Returns:
        (:obj:`str`): An error message to log if the coordinate system, zone, horizontal/vertical units or vertical
        datum differ; an empty string if they all match
    """
    # Compared in this order; the first difference found short-circuits the check.
    compared_attrs = (
        'coordinate_system',
        'coordinate_zone',
        'horizontal_units',
        'vertical_units',
        'vertical_datum',
    )
    mismatch = any(
        getattr(native_projection, attr) != getattr(display_projection, attr) for attr in compared_attrs
    )
    if not mismatch:
        return ''
    return (
        'Display projection does not match the ADCIRC mesh. \n'
        '\t\tCannot export unless the projections match. \n'
        '\t\tEnsure mesh projection is the desired projection for ADCIRC and \n'
        '\t\tset display projection to match to allow exporting.'
    )


def buffer_segment(segment):
    """Extend a two-point line segment by its own length beyond each endpoint, effectively tripling its length.

    Args:
        segment (:obj:`LineString`): The LineString to buffer. Only its first two coordinates are used, so it should
            be a single segment.

    Returns:
        (:obj:`LineString`): The buffered segment. When the input coordinates carry a Z value, each new endpoint keeps
        the Z of the original endpoint it was extended from; 2D input produces a 2D segment.
    """
    pt1_coords = segment.coords[0]
    pt2_coords = segment.coords[1]
    extend_length = segment.length
    pt1x, pt1y = _buffer_segment_side(pt2_coords, pt1_coords, extend_length)  # Extend beyond first node in segment
    pt2x, pt2y = _buffer_segment_side(pt1_coords, pt2_coords, extend_length)  # Extend beyond last node in segment
    if len(pt1_coords) > 2:
        # Preserve the Z just so it gets passed along, usually need to recompute it anyway
        return LineString([Point(pt1x, pt1y, pt1_coords[2]), Point(pt2x, pt2y, pt2_coords[2])])
    # 2D input has no Z to carry over (indexing coords[2] would raise IndexError)
    return LineString([Point(pt1x, pt1y), Point(pt2x, pt2y)])


def get_parametric_lengths(nodestring_pts):
    """Get the parametric lengths along a nodestring for each node of the nodestring.

    The parametric length of a node is the cumulative x,y distance from the first node divided by the total length of
    the nodestring. Any Z component of the points is ignored.

    Args:
        nodestring_pts (:obj:`list`): List of x,y (or x,y,z) coordinate tuples of the snapped arc locations

    Returns:
        (:obj:`list`): List of the parametric lengths between 0.0 and 1.0, one per node, starting at 0.0. An empty
        input gives an empty list. If the nodestring has zero total length (a single node, or all nodes coincident),
        every entry is 0.0.
    """
    xy_pts = [(pt[0], pt[1]) for pt in nodestring_pts]
    if not xy_pts:
        return []

    # Running distance from the first node to each subsequent node
    cumulative_lengths = []
    total_length = 0.0
    for prev_pt, now_pt in zip(xy_pts, xy_pts[1:]):
        total_length += np.linalg.norm(np.subtract(prev_pt, now_pt))
        cumulative_lengths.append(total_length)

    if total_length == 0.0:
        # Degenerate nodestring: normalizing by a zero total length would produce NaN, so every node sits at the start.
        return [0.0] * len(xy_pts)
    return [0.0] + [length / float(total_length) for length in cumulative_lengths]


def map_levee_atts(sorted_levee, p_len):
    """Get the levee crest height at a parametric length from a levee attribute dataset.

    Args:
        sorted_levee (:obj:`xarray.Dataset`): The levee dataset sorted by parametric length. Must contain the
            'Parametric __new_line__ Length' and 'Zcrest (m)' variables.
        p_len (:obj:`float`): Parametric length to get the height at

    Returns:
        (:obj:`tuple`): (height, index of the levee row the height was taken from), as returned by get_levee_height
    """
    para_lengths = sorted_levee['Parametric __new_line__ Length'].data
    zcrests = sorted_levee['Zcrest (m)'].data
    return get_levee_height(p_len, para_lengths, zcrests)


def get_levee_height(p_len, para_lengths, zcrests):
    """Get the levee crest height at a parametric length.

    Heights are not blended: the Z-crest of the closer of the two bracketing parametric lengths is used (an exact tie
    goes to the upper one). Parametric lengths outside the range of ``para_lengths`` are clamped to the first/last
    entry.

    Args:
        p_len (:obj:`float`): Parametric length to get the height at
        para_lengths (:obj:`Sequence`): The parametric lengths of the levee, in ascending order
        zcrests (:obj:`Sequence`): The levee Z-crest curve, parallel to ``para_lengths``

    Returns:
        (:obj:`tuple`): (height, index into ``para_lengths``/``zcrests`` that the height was taken from). Returns
        (0.0, 0) when ``para_lengths`` is empty.
    """
    num_lengths = len(para_lengths)
    if num_lengths == 0:
        return 0.0, 0
    if num_lengths == 1:
        return zcrests[0], 0

    for i, levee_len in enumerate(para_lengths):
        if p_len == levee_len:
            return zcrests[i], i
        if p_len < levee_len:
            if i == 0:
                # Before the start of the curve: clamp to the first entry instead of wrapping to para_lengths[-1]
                # (which could also divide by zero when the first and last lengths are equal).
                return zcrests[0], 0
            prev_levee_len = para_lengths[i - 1]
            per1 = (p_len - prev_levee_len) / (levee_len - prev_levee_len)
            per0 = 1.0 - per1
            # For a blended average instead of nearest: height = (zcrests[i - 1] * per0) + (zcrests[i] * per1)
            if per0 > per1:
                return zcrests[i - 1], i - 1
            return zcrests[i], i

    # Past the end of the curve: clamp to the last entry and report that entry's index.
    return zcrests[-1], num_lengths - 1


def populate_levee_atts_from_arcs(levee_locs, parametric_targets=None):
    """Populate levee attribute table from the coverage arc locations.

    Args:
        levee_locs (:obj:`list`): The levee arc locations as x,y,z tuples. A levee outflow is a list holding one arc's
            locations; a levee pair is a 2D list holding both arcs' locations.
        parametric_targets (:obj:`numpy.ndarray`): The target parametric lengths to compute Z-values for. If not
            provided, will compute at all arc locations.

    Returns:
        The (parametric length, Zcrest) rows for the levee pair or levee outflow. For a levee pair this is a
        :obj:`numpy.ndarray` of shape (n, 2) combining both arcs: the mean of the two arcs' rows when
        ``parametric_targets`` is given, otherwise their element-wise maximum. For a levee outflow it is a
        :obj:`list` of (parametric length, Zcrest) tuples.
    """
    is_levee_outflow = len(levee_locs) == 1
    parametric_lengths1 = get_parametric_lengths(levee_locs[0])
    parametric_lengths2 = get_parametric_lengths(levee_locs[1]) if not is_levee_outflow else None
    zcrests1 = np.array([location[2] for location in levee_locs[0]])
    zcrests2 = np.array([location[2] for location in levee_locs[1]]) if not is_levee_outflow else None

    if parametric_targets is not None:  # Already have existing rows of levee data, only update those.
        lengths_and_z1 = []
        lengths_and_z2 = []
        for para_length in parametric_targets:
            height, _ = get_levee_height(para_length, parametric_lengths1, zcrests1)
            lengths_and_z1.append((para_length, height))
            if not is_levee_outflow:
                height, _ = get_levee_height(para_length, parametric_lengths2, zcrests2)
                lengths_and_z2.append((para_length, height))
        if not is_levee_outflow:
            lengths_and_z1 = np.mean([lengths_and_z1, lengths_and_z2], axis=0)
        return lengths_and_z1

    # Compute at all arc locations
    lengths1_lookup = set(parametric_lengths1)
    parametric_lengths1 = np.array(parametric_lengths1)
    lengths_and_z1 = list(zip(parametric_lengths1, zcrests1))
    if is_levee_outflow:
        return lengths_and_z1

    lengths2_lookup = set(parametric_lengths2)
    parametric_lengths2 = np.array(parametric_lengths2)
    lengths_and_z2 = list(zip(parametric_lengths2, zcrests2))

    # Give each arc a row at every parametric length that only the other arc has, so the two row lists line up for the
    # element-wise max below. Each height is looked up on the arc's own vertices, not on rows inserted earlier, which
    # would let a nearest-neighbor copy propagate past the vertex it came from and make results order-dependent.
    missing_from_2 = [
        (len1, get_levee_height(len1, parametric_lengths2, zcrests2)[0])
        for len1 in parametric_lengths1 if len1 not in lengths2_lookup
    ]
    missing_from_1 = [
        (len2, get_levee_height(len2, parametric_lengths1, zcrests1)[0])
        for len2 in parametric_lengths2 if len2 not in lengths1_lookup
    ]
    if missing_from_2:
        lengths_and_z2.extend(missing_from_2)
        lengths_and_z2.sort()
    if missing_from_1:
        lengths_and_z1.extend(missing_from_1)
        lengths_and_z1.sort()
    return np.max([lengths_and_z1, lengths_and_z2], axis=0)


def linear_interp_with_idw_extrap(linear_interper, to_points, dset_reader, idw_interpers):
    """Linearly interpolate a dataset to points, substituting IDW values wherever the linear pass extrapolated.

    The scalars used are ``dset_reader.values[0]`` (presumably the first time step - confirm against the reader).

    Args:
        linear_interper (:obj:`InterpLinear`): Linear interpolator initially used
        to_points (:obj:`list`): List of x,y,z locations to interpolate to
        dset_reader (:obj:`xms.datasets.dataset_reader.DatasetReader`): The dataset to interpolate
        idw_interpers (:obj:`dict`): Cache of IDW interpolators keyed by geom UUID. On the first extrapolation for a
            geometry, a new IDW interpolator over the linear interpolator's points is created and stored here.

    Returns:
        The interpolated values. When any points were extrapolated this is a list with those entries replaced by IDW
        values; otherwise it is the sequence from ``linear_interper.interpolate_to_points`` returned unchanged.
    """
    linear_interper.scalars = dset_reader.values[0]
    values = linear_interper.interpolate_to_points(to_points)
    extrapolated = linear_interper.extrapolation_point_indexes
    if len(extrapolated) == 0:
        return values  # Everything was inside the linear interpolator's domain

    geom_uuid = dset_reader.geom_uuid
    if geom_uuid not in idw_interpers:
        # First extrapolation for this geometry: build and cache an IDW interpolator over the same source points.
        new_idw = InterpIdw(points=linear_interper.points, nodal_function='constant', number_nearest_points=2)
        idw_interpers[geom_uuid] = new_idw
        new_idw.set_search_options(2, False)

    idw_interp = idw_interpers[geom_uuid]
    idw_interp.scalars = linear_interper.scalars
    idw_values = idw_interp.interpolate_to_points([to_points[point_idx] for point_idx in extrapolated])

    values = list(values)  # The linear results come back immutable; copy so entries can be replaced
    for position, point_idx in enumerate(extrapolated):
        values[point_idx] = idw_values[position]
    return values


def levee_snap_to_segments(snap1, snap2):
    """Build the segments connecting corresponding snapped nodes of a levee pair, plus their combined extents.

    Args:
        snap1 (:obj:`dict`): The snap result for the first levee arc; its 'location' entry holds the node coordinates
        snap2 (:obj:`dict`): The snap result for the second levee arc; its 'location' entry holds the node coordinates

    Returns:
        (:obj:`tuple[list, list[LineString]]`): Extents of all the segments [min_x, min_y, max_x, max_y] and the list
        of segments between the snapped levee node pairs as shapely objects. Nodes are paired by position, so extra
        nodes on the longer arc are ignored. With no segments the extents are
        [sys.maxsize, sys.maxsize, -sys.maxsize, -sys.maxsize].
    """
    segments = [
        LineString([Point(*loc1), Point(*loc2)]) for loc1, loc2 in zip(snap1['location'], snap2['location'])
    ]
    min_x = min_y = sys.maxsize
    max_x = max_y = -sys.maxsize
    for segment in segments:
        seg_min_x, seg_min_y, seg_max_x, seg_max_y = segment.bounds
        min_x = min(min_x, seg_min_x)
        min_y = min(min_y, seg_min_y)
        max_x = max(max_x, seg_max_x)
        max_y = max(max_y, seg_max_y)
    return [min_x, min_y, max_x, max_y], segments


def meters_to_decimal_degrees(length_meters, latitude):
    """Convert a length in meters to decimal degrees at the given latitude.

    Args:
        length_meters (:obj:`float`): length in meters
        latitude (:obj:`float`): latitude in decimal degrees

    Returns:
        (:obj:`float`): length in decimal degrees
    """
    meters_per_degree = _factor_from_latitude(latitude)
    return length_meters / meters_per_degree


def feet_to_decimal_degrees(length_feet, latitude):
    """Convert a length in feet to decimal degrees at the given latitude.

    The length is converted to meters first, then to decimal degrees.

    Args:
        length_feet (:obj:`float`): length in feet
        latitude (:obj:`float`): latitude in decimal degrees

    Returns:
        (:obj:`float`): length in decimal degrees
    """
    return meters_to_decimal_degrees(length_feet / FEET_PER_METER, latitude)


def decimal_degrees_to_meters(length_degrees, latitude):
    """Convert a length in decimal degrees to meters based on the latitude.

    This is the inverse of meters_to_decimal_degrees.

    Args:
        length_degrees (:obj:`float`): length in decimal degrees
        latitude (:obj:`float`): latitude in decimal degrees

    Returns:
        (:obj:`float`): length in meters
    """
    return length_degrees * _factor_from_latitude(latitude)


def decimal_degrees_to_feet(length_degrees, latitude):
    """Convert a length in decimal degrees to feet based on the latitude.

    The length is converted to meters first, then to feet.

    Args:
        length_degrees (:obj:`float`): length in decimal degrees
        latitude (:obj:`float`): latitude in decimal degrees

    Returns:
        (:obj:`float`): length in feet
    """
    length_meters = decimal_degrees_to_meters(length_degrees, latitude)
    return length_meters * FEET_PER_METER


def check_levee_bc_file():
    """Return the path of the check levee BC coverage dump file used to view check levee tool results.

    The file is placed in the XMS process temp directory and its name includes the current process id.
    """
    return os.path.join(
        XmEnv.xms_environ_process_temp_directory(),
        f'adcirc_check_levee_coverage_{os.getpid()}.h5',
    )


def check_levee_results_df_file():
    """Return the path of the check levee results DataFrame file used to view check levee tool results.

    The file is placed in the XMS process temp directory and its name includes the current process id.
    """
    return os.path.join(
        XmEnv.xms_environ_process_temp_directory(),
        f'adcirc_check_levee_results_{os.getpid()}.pkl',
    )


def check_levee_results_html_file():
    """Return the path of the check levee results HTML file used to view check levee tool results.

    The file is placed in the XMS process temp directory and its name includes the current process id.
    """
    return os.path.join(
        XmEnv.xms_environ_process_temp_directory(),
        f'adcirc_check_levee_results_{os.getpid()}.html',
    )


def coordinate_hash(coords1, coords2):
    """Get a hash for two sets of x,y,z coordinates.

    Args:
        coords1: The first set of coordinates. Must be a contiguous bytes-like object (e.g. bytes or a C-contiguous
            numpy array); a plain Python list is not accepted.
        coords2: The second set of coordinates, with the same requirements as ``coords1``.

    Returns:
        (:obj:`str`): The hexadecimal representation of the CRC32 checksum of the raw bytes of ``coords1`` followed by
        the raw bytes of ``coords2``.
    """
    # Continue the running CRC of the first buffer over the second. This equals the CRC of the concatenation without
    # copying both buffers into an intermediate stream.
    checksum = binascii.crc32(coords2, binascii.crc32(coords1))
    return hex(checksum & 0xFFFFFFFF)
