"""Utility methods for the tool runners."""
__copyright__ = "(C) Copyright Aquaveo 2020"
__license__ = "All rights reserved"

# 1. Standard Python modules
import logging
import math

# 2. Third party modules
from rtree import index
from shapely.geometry import LineString, Point

# 3. Aquaveo modules
from xms.coverage.polygons.polygon_orienteer import get_polygon_point_lists
from xms.data_objects.parameters import FilterLocation
from xms.grid.geometry import geometry as geom
from xms.grid.ugrid import UGrid
from xms.guipy.data.target_type import TargetType
from xms.tool.algorithms.coverage import polygons_from_arcs as pfa
from xms.tool_gui.xms_data_handler import convert_to_geodataframe

# 4. Local modules
from xms.ewn.data import ewn_cov_data_consts as consts


def _logger():
    """Get the ``xms.ewn`` logger shared by the functions in this module.

    Returns:
        (:obj:`logging.Logger`): The module logger
    """
    logger_name = 'xms.ewn'
    return logging.getLogger(logger_name)


def get_ewn_polygon_input(cov_geom, ewn_comp, inside_polys):
    """Collect the insertable polygons of an EWN Features coverage for the insertion tool.

    Polygons classified as unassigned, or with classification 1 or 2, are left out.

    Args:
        cov_geom (:obj:`data_objects.parameters.Coverage`): The coverage geometry dump
        ewn_comp (:obj:`EwnCoverageComponent`): The coverage's component data
        inside_polys (:obj:`bool`): True if the inside arc definitions should be included in the input data

    Returns:
        (:obj:`list[dict]`): One dict per polygon with its outside ring locations and attributes
    """
    # NOTE(review): 1 and 2 are bare classification indices; presumably feature types that must not be
    # inserted as polygons. Confirm, and replace them with named constants from ewn_cov_data_consts.
    skip_types = [consts.FEATURE_TYPE_UNASSIGNED, 1, 2]
    return _get_polygon_input(cov_geom, ewn_comp, inside_polys, types_to_skip=skip_types)


def get_ewn_arc_input(cov_geom, ewn_comp, top_width=None, elevation=None):
    """Collect the insertable arcs of an EWN Features coverage for the insertion tool.

    Args:
        cov_geom (:obj:`data_objects.parameters.Coverage`): The coverage geometry dump
        ewn_comp (:obj:`EwnCoverageComponent`): The coverage's component data
        top_width (:obj:`float`): If not None, used as the top width of every arc instead of its attribute
        elevation (:obj:`float`): If not None, used as the elevation of every arc's polygon points instead of
            its attribute

    Returns:
        (:obj:`list[dict]`): One dict per arc with its polygon locations and attributes
    """
    # None means "use each arc's own attribute value".
    return _get_arc_input(cov_geom, ewn_comp, a_top_width=top_width, a_elevation=elevation)


def get_levee_polygon_input(cov_geom, ewn_comp, inside_polys):
    """Collect the ADCIRC levee polygons of a Levee coverage for the insertion tool.

    Every classification except ``consts.FEATURE_TYPE_ADCIRC_LEVEE`` is skipped.

    Args:
        cov_geom (:obj:`data_objects.parameters.Coverage`): The coverage geometry dump
        ewn_comp (:obj:`EwnCoverageComponent`): The coverage's component data
        inside_polys (:obj:`bool`): True if the inside arc definitions should be included in the input data

    Returns:
        (:obj:`list[dict]`): One dict per polygon with its outside ring locations and attributes
    """
    # Classification 0 is 'Unassigned'; the rest follow the 'Feature type/Region' metadata entries.
    num_classifications = 1 + len(ewn_comp.data.meta_data['Feature type/Region'].tolist())
    skip_types = list(range(num_classifications))
    # Keep only the ADCIRC levee type (raises ValueError if it is not a known classification index).
    del skip_types[skip_types.index(consts.FEATURE_TYPE_ADCIRC_LEVEE)]
    return _get_polygon_input(cov_geom, ewn_comp, inside_polys, types_to_skip=skip_types)


def _find_transition_polygon_pts(poly, polys, test_pt):
    """Find the outside ring of the other polygon that contains a point of ``poly``.

    Args:
        poly (:obj:`data_objects.parameters.Polygon`): The polygon being inserted (never matched against itself)
        polys (:obj:`list`): All polygons of the coverage
        test_pt (:obj:`list`): A point of ``poly``'s outside ring used for the containment test

    Returns:
        (:obj:`list[list[float]]`): The ccw outside ring of the LAST other polygon for which
            ``geom.point_in_polygon_2d`` returns 1 (presumably "inside"), or None if there is none
    """
    match = None
    for other in polys:
        if other.id == poly.id:
            continue
        other_pts = get_polygon_outside_points_ccw(other)
        if geom.point_in_polygon_2d(other_pts, test_pt) == 1:
            match = other_pts  # later matches overwrite earlier ones
    return match


def _get_polygon_input(cov_geom, ewn_comp, inside_polys, types_to_skip):
    """Get the polygons of a coverage as input to the insertion tool.

    Shared by the EWN Features and Levee coverage entry points. A polygon is skipped when it has no valid
    component id, when its classification is in ``types_to_skip``, or when its insert feature flag is 0.

    Args:
        cov_geom (:obj:`data_objects.parameters.Coverage`): The coverage geometry dump
        ewn_comp (:obj:`EwnCoverageComponent`): The coverage's component data
        inside_polys (:obj:`bool`): True if the inside arc definitions should be included in the input data
        types_to_skip (:obj:`list`): list of ints of classifications to skip

    Returns:
        (:obj:`list[dict]`): List of the polygon outside ring locations and its attributes. Levee polygons also
            get 'levee_arcs'; polygons using the polygon transition method get 'transition_poly_pts' when a
            containing polygon is found; 'polygon_inside_pts_cw' is added when ``inside_polys`` is True.
    """
    all_polys = []
    polys = cov_geom.polygons
    for poly in polys:
        comp_id = ewn_comp.get_comp_id(TargetType.polygon, poly.id)
        if comp_id is None or comp_id < 0:
            continue
        poly_atts = ewn_comp.data.get_poly_atts(comp_id)
        if int(poly_atts.classification) in types_to_skip or int(poly_atts.insert_feature.item()) == 0:
            continue

        classification = poly_atts.classification.item()
        elevation = poly_atts.elevation.item()
        transition_method = poly_atts.transition_method.item()
        # The outside ring's z values are replaced by the polygon's elevation attribute.
        outside_pts = get_polygon_outside_points_ccw(poly, elevation)
        poly_dict = {
            'polygon_outside_pts_ccw': outside_pts,
            'polygon_atts': poly_atts,
            'polygon_name': poly_atts.polygon_name.item(),
            'classification': classification,
            'elevation': elevation,
            'specify_slope': poly_atts.specify_slope.item(),
            'slope': poly_atts.slope.item(),
            'max_distance': poly_atts.maximum_slope_distance.item(),
            'elevation_method': poly_atts.elevation_method.item(),
            'cov_uuid': cov_geom.uuid,
            'polygon_id': poly.id,
            'transition_method': transition_method,
            'transition_distance': poly_atts.transition_distance.item(),
            # NOTE(review): unlike the other entries this is not unwrapped with .item(); when present it is
            # presumably a DataArray rather than a float. Confirm what consumers expect.
            'quadtree_refinement_length': poly_atts.get('quadtree_refinement_length', 10),
            'arc_pts': None,
        }
        if classification == consts.FEATURE_TYPE_ADCIRC_LEVEE:
            poly_dict['levee_arcs'] = get_levee_arcs(poly)
        if transition_method == consts.TRANSITION_METHOD_POLYGON:
            transition_pts = _find_transition_polygon_pts(poly, polys, outside_pts[0])
            if transition_pts is not None:
                poly_dict['transition_poly_pts'] = transition_pts
        if inside_polys:
            poly_dict['polygon_inside_pts_cw'] = _get_polygon_inside_points_cw(poly)
        all_polys.append(poly_dict)
    return all_polys


def _get_arc_input(cov_geom, ewn_comp, a_top_width=None, a_elevation=None):
    """Get the arcs of an EWN Features coverage as input to the insertion tool.

    Each arc with a valid component id and a non-zero insert feature flag is converted to a polygon ring with
    ``get_polygon_from_arc`` so it can be inserted like a polygon feature.

    Args:
        cov_geom (:obj:`data_objects.parameters.Coverage`): The coverage geometry dump
        ewn_comp (:obj:`EwnCoverageComponent`): The coverage's component data
        a_top_width (:obj:`float`): If not None, used instead of every arc's top width attribute
        a_elevation (:obj:`float`): If not None, used instead of every arc's elevation attribute for the z of the
            generated polygon points

    Returns:
        (:obj:`list[dict]`): List of the arc polygon locations and their attributes. 'arc_pts' holds the arc's
            points only for arcs with a top width of 0.0 (None otherwise); those grids presumably need
            ``post_process_ugrid_for_line``.
    """
    all_arcs = []
    arcs = cov_geom.arcs
    for arc in arcs:
        comp_id = ewn_comp.get_comp_id(TargetType.arc, arc.id)
        if comp_id is None or comp_id < 0:
            continue
        arc_atts = ewn_comp.data.get_arc_atts(comp_id)
        if int(arc_atts.insert_feature.item()) == 0:
            continue

        top_width = a_top_width if a_top_width is not None else arc_atts.top_width.item()
        elevation = a_elevation if a_elevation is not None else arc_atts.elevation.item()
        poly_pts = get_polygon_from_arc(arc, top_width, cov_geom, elevation)

        arc_pts = None
        # A top width of 0.0 means the arc is inserted as a line. get_polygon_from_arc still builds a thin
        # polygon for it, so trim the ring so it comes to a single point at each end of the arc.
        if top_width == 0.0:
            arc_pts = get_arc_points(arc)
            # Remove the ring points directly before and after the arc's last point. Popping index + 1 first
            # leaves index - 1 unaffected.
            # NOTE(review): .index() raises ValueError if the arc's last point is not exactly (float ==) a ring
            # point, and an index at either end of the ring makes the pops wrap or fail.
            poly_pts_2d = [[pt[0], pt[1]] for pt in poly_pts]
            arc_last_pt_idx_in_poly = poly_pts_2d.index([arc_pts[-1][0], arc_pts[-1][1]])
            poly_pts.pop(arc_last_pt_idx_in_poly + 1)
            poly_pts.pop(arc_last_pt_idx_in_poly - 1)
            # Remove the ring points around the arc's first point. This assumes the ring from
            # get_polygon_from_arc starts/ends next to the arc's first point — TODO confirm.
            poly_pts.pop(-3)
            poly_pts.pop(-1)
            poly_pts.pop(0)
            # Prepend a copy of the current last point so the ring starts and ends on the same point.
            poly_pts.insert(0, [poly_pts[-1][0], poly_pts[-1][1], poly_pts[-1][2]])

        # NOTE(review): 'elevation' below reads the arc attribute even when a_elevation overrides the elevation
        # used for poly_pts above. Confirm this is intended.
        arc_dict = {
            'polygon_outside_pts_ccw': poly_pts,
            'arc_atts': arc_atts,
            'arc_name': arc_atts.arc_name.item(),
            'classification': '',
            'elevation': arc_atts.elevation.item(),
            'specify_slope': arc_atts.use_slope.item(),
            'slope': arc_atts.side_slope.item(),
            'max_distance': arc_atts.maximum_slope_distance.item(),
            'elevation_method': consts.ELEVATION_METHOD_CONSTANT,
            'cov_uuid': cov_geom.uuid,
            'polygon_id': arc.id,
            'transition_method': consts.TRANSITION_METHOD_FACTOR,
            'transition_distance': arc_atts.transition_distance.item() if 'transition_distance' in arc_atts else 0.0,
            'arc_pts': arc_pts,  # If the arc has top width of 0 we need postprocess
        }
        all_arcs.append(arc_dict)
    return all_arcs


def get_levee_arcs(poly):
    """Get the 2 arcs from the polygon that make up a levee.

    The levee polygon must be built from exactly 4 arcs. The two longest arcs (by 2D length along their points)
    are the levee sides, and they must have the same number of points.

    Args:
        poly (:obj:`data_objects.parameters.Polygon`): The data_objects polygon

    Returns:
        (:obj:`tuple(list[tuple(x,y,z)], list[tuple(x,y,z)])`): Points of the longest and second longest arcs

    Raises:
        RuntimeError: If the polygon does not have exactly 4 arcs, or the 2 longest arcs have different point
            counts. The message is also logged.
    """
    arcs = poly.arcs
    if len(arcs) != 4:
        msg = f'Polygon id: {poly.id} invalid. Levees must be defined using 4 arcs.'
        _logger().error(msg)
        raise RuntimeError(msg)

    points = [get_arc_points(arc) for arc in arcs]
    lengths = [sum(geom.distance_2d(p0, p1) for p0, p1 in zip(arc_pts, arc_pts[1:])) for arc_pts in points]
    # A stable descending sort picks the longest arcs (lowest index first on ties) and always yields two
    # distinct arcs, even when the remaining arcs have zero length.
    longest, second = sorted(range(len(points)), key=lengths.__getitem__, reverse=True)[:2]

    if len(points[longest]) != len(points[second]):
        msg = f'Polygon id: {poly.id} invalid. Levee arcs must have the same number of segments.'
        _logger().error(msg)
        raise RuntimeError(msg)
    return points[longest], points[second]


def _get_polygon_inside_points_cw(poly):
    """Get the inside (hole) rings of a polygon.

    These are all rings from ``get_polygon_point_lists`` after the first (outside) ring, in the order that
    helper returns them (presumably cw with last point = first point, per this function's name; confirm).

    Args:
        poly (:obj:`data_objects.parameters.Polygon`): The data_objects polygon

    Returns:
        (:obj:`list[list]`): One point list per inside ring; empty if the polygon has no inside rings
    """
    # Slice instead of pop(0): does not mutate the helper's result and tolerates an empty ring list.
    return get_polygon_point_lists(poly)[1:]


def get_polygon_outside_points_ccw(poly, override_z=None):
    """Get the outside ring of polygon points in ccw order (last point = first point).

    The outside ring is the first ring from ``get_polygon_point_lists``, returned in reversed order.

    Args:
        poly (:obj:`data_objects.parameters.Polygon`): The data_objects polygon
        override_z (:obj:`float`): If not None, used as the z value of every returned point

    Returns:
        (:obj:`list[list[float]]`): The outside polygon ring points as [x, y, z] lists
    """
    outside_ring = get_polygon_point_lists(poly)[0]
    return [[pt[0], pt[1], pt[2] if override_z is None else override_z] for pt in outside_ring][::-1]


def get_arc_points(arc):
    """Get all points of an xms arc (``FilterLocation.PT_LOC_ALL``) as a list of tuples.

    Args:
        arc (:obj:`data_objects.parameter.Arc`): Arc from xms coverage

    Returns:
        (:obj:`list[tuple(x,y,z)]`): xyz coords of arc points
    """
    return [(pt.x, pt.y, pt.z) for pt in arc.get_points(FilterLocation.PT_LOC_ALL)]


def get_polygon_from_arc(arc, poly_width, cov_geom, elevation) -> list[list[float]]:
    """Get a polygon ring around an arc using the PolygonsFromArcs tool.

    The number of elements across the width is ceil(width / average arc segment length), forced to be even
    and non-zero. A width of 0.0 is replaced by 10% of the average segment length so a thin polygon is still
    produced.

    Args:
        arc (:obj:`data_objects.parameter.Arc`): Arc from xms coverage
        poly_width (:obj:`float`): The width of the polygon
        cov_geom (:obj:`data_objects.parameters.Coverage`): The coverage geometry dump
        elevation (:obj:`float`): The z value given to every polygon point

    Returns:
        (:obj:`list[list[float]]`): [x, y, z] coords of the polygon ring. The arcs generated by the tool are
            chained, and every arc's last point except the final arc's is dropped to avoid duplicate nodes.
    """
    # Average (2D, per shapely) segment length of the arc.
    # NOTE(review): ZeroDivisionError if every arc point is coincident.
    arc_pts = get_arc_points(arc)
    coords = list(LineString(arc_pts).coords)
    spacings = [Point(p0).distance(Point(p1)) for p0, p1 in zip(coords, coords[1:])]
    average_spacing = sum(spacings) / len(spacings)

    num_elem = math.ceil(poly_width / average_spacing)
    if num_elem == 0:
        num_elem = 2
    elif num_elem % 2 != 0:
        num_elem += 1

    # A zero-width (line) feature still needs a thin polygon.
    if poly_width == 0.0:
        poly_width = average_spacing * 1.0e-1

    arc_data = [
        {
            'id': arc.id,
            'arc_pts': arc_pts,
            'cov_geom': convert_to_geodataframe(coverage=cov_geom, default_wkt=''),
            'start_node': arc.start_node.id,
            'end_node': arc.end_node.id,
            'element_width': poly_width / num_elem,
            'number_elements': num_elem,
        }
    ]

    # Use tool to generate a polygon from an arc
    arcs_to_polys = pfa.PolygonsFromArcs(
        arc_data=arc_data, new_cov_name='temp-name', logger=_logger(), pinch_ends=False, min_seg_length=None, wkt=None
    )
    gdf = arcs_to_polys.generate_coverage()
    line_strings = list(gdf[gdf['geometry_types'] == 'Arc']['geometry'])

    polygon_points = []
    last = len(line_strings) - 1
    for i, line in enumerate(line_strings):
        xs, ys = line.coords.xy
        line_pts = [[x, y, elevation] for x, y in zip(xs, ys)]
        # Consecutive arcs share an end node, so keep the last point only for the final arc.
        polygon_points.extend(line_pts if i == last else line_pts[:-1])
    return polygon_points


def post_process_ugrid_for_line(poly, ugrid) -> UGrid:
    """Post process the ugrid for a line (an arc inserted with a top width of 0).

    Such an arc is inserted as a very thin polygon, which leaves a "tear" of doubled points along the arc. This
    snaps the grid points nearest the arc points onto the arc, merges each group of nearest points into the
    lowest-index one, and drops cells touching a merged point whose area is no longer positive.

    Args:
        poly (:obj:`dict`): Polygon data; reads 'polygon_outside_pts_ccw' and 'arc_pts'
        ugrid (:obj:`xms.grid.ugrid.UGrid`): The ugrid to post process

    Returns:
        (:obj:`xms.grid.ugrid.UGrid`): The post processed ugrid
    """
    # Only grid points inside the thin polygon's bounding box are candidates.
    poly_x, poly_y, _ = zip(*poly['polygon_outside_pts_ccw'])
    bbox = [min(poly_x), min(poly_y), max(poly_x), max(poly_y)]
    locs = ugrid.locations
    idxs_in_box = [idx for idx, loc in enumerate(locs) if in_box(loc, bbox)]

    # The R-tree ids are indices into the full ``locs`` array, not into the in-box subset, so the
    # nearest-neighbor results below can index ``locs`` directly.
    def generator_func():
        for loc_idx in idxs_in_box:
            x, y = locs[loc_idx][0], locs[loc_idx][1]
            yield loc_idx, (x, y, x, y), None

    pt_rtree = index.Index(generator_func())

    # Arc end points map to 1 grid point; interior arc points map to the 2 points on either side of the tear.
    # rtree's nearest() may return more than k ids when there are distance ties.
    arc_pts = poly['arc_pts']
    close_pt_idxs = [list(pt_rtree.nearest((arc_pts[0][0], arc_pts[0][1]), 1))]
    close_pt_idxs.extend(list(pt_rtree.nearest((p[0], p[1]), 2)) for p in arc_pts[1:-1])
    close_pt_idxs.append(list(pt_rtree.nearest((arc_pts[-1][0], arc_pts[-1][1]), 1)))

    # Zip the tear back together: keep the lowest index of each group and merge the others into it.
    min_idxs = [min(idxs) for idxs in close_pt_idxs]
    keep_idxs = set(min_idxs)
    remove_lookup = {}
    for idxs, keep_idx in zip(close_pt_idxs, min_idxs):
        for idx in idxs:
            if idx not in keep_idxs:
                remove_lookup[idx] = keep_idx
    remove_idx = set(remove_lookup)

    for arc_pt, keep_idx in zip(arc_pts, min_idxs):
        locs[keep_idx] = arc_pt
    new_locs = [loc for idx, loc in enumerate(locs) if idx not in remove_idx]

    # Old to new point index lookup. A removed point takes its kept partner's new index; the partner is
    # its group's minimum, so it has a lower index and is already assigned.
    old_to_new_idx = [-1] * ugrid.point_count
    next_idx = 0
    for idx in range(len(locs)):
        if idx in remove_lookup:
            old_to_new_idx[idx] = old_to_new_idx[remove_lookup[idx]]
        else:
            old_to_new_idx[idx] = next_idx
            next_idx += 1

    # Rebuild the cell stream: [cell_type, num_pts, pt_0, ..., pt_n-1] per cell.
    cs = ugrid.cellstream
    new_cs = []
    pos = 0
    while pos < len(cs):
        cell_type = cs[pos]
        num_pts = cs[pos + 1]
        start = pos + 2
        end = start + num_pts
        cell_pts = cs[start:end]
        pos = end
        new_cell_pts = [old_to_new_idx[idx] for idx in cell_pts]
        if any(idx in remove_idx for idx in cell_pts):
            # Merging points may collapse the cell; drop it unless its area is still positive.
            cell_locs = [new_locs[idx] for idx in new_cell_pts]
            if not geom.polygon_area_2d(cell_locs) > 0.0:
                continue
        new_cs.extend([cell_type, num_pts] + new_cell_pts)

    return UGrid(new_locs, new_cs)


def in_box(loc, box):
    """Check whether a location lies inside, or on the edge of, an axis-aligned box.

    Only the first two coordinates of ``loc`` are used.

    Args:
        loc (:obj:`list`): Location to test as (x, y, ...)
        box (:obj:`list`): Box as [x_min, y_min, x_max, y_max]

    Returns:
        bool: True unless the location is strictly outside the box on either axis
    """
    x, y = loc[0], loc[1]
    x_min, y_min, x_max, y_max = box[0], box[1], box[2], box[3]
    outside = x < x_min or x > x_max or y < y_min or y > y_max
    return not outside
