"""Code for creating lines (arcs) from points."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules

# 2. Third party modules
from rtree import index

# 3. Aquaveo modules

# 4. Local modules


class LinesFromPoints:
    """Tool to string points into arcs based on nearest neighbor information.

    Two points are joined only when each one is among the other's ``num_nearest_neighbors``
    nearest neighbors (a "common" or mutual neighbor).
    """

    def __init__(self, points, num_nearest_neighbors, connect_isolated_pts, enforce_max_dist, max_dist):
        """Initializes the class.

        Args:
            points (sequence): the points; each item is indexable with x at [0] and y at [1].
            num_nearest_neighbors (int): how many nearest neighbors of each point are considered (>= 1).
            connect_isolated_pts (bool): if True, points not placed in any string are joined to their
                nearest neighbor by a two-point string.
            enforce_max_dist (bool): if True, no connection longer than ``max_dist`` is made.
            max_dist (float | None): maximum connection length used when ``enforce_max_dist`` is True;
                ``None`` means no limit.

        Raises:
            ValueError: if ``num_nearest_neighbors`` is less than 1.
        """
        if num_nearest_neighbors < 1:
            raise ValueError('num_nearest_neighbors must be at least 1.')
        self._used_pts = None  # set of point indices already placed in a string
        self._neighbor = None  # per point: indices of its nearest neighbors, nearest first
        self._d2 = None  # per point: squared distances matching ``_neighbor``
        self._order = None  # sorted [d2 to 1st neighbor, d2 to 2nd neighbor (0.0 if none), point index]
        self._points = points
        self._nn = num_nearest_neighbors
        self._connect_isolated_pts = connect_isolated_pts
        self._enforce_max_dist = enforce_max_dist
        # compared against squared distances, so square once here rather than taking square roots later
        self._max_dist2 = float('inf') if max_dist is None else max_dist * max_dist

    def _common_neighbor(self, loc_idx):
        """Find the first unused neighbor of a point that also lists that point as a neighbor.

        Args:
            loc_idx (int): point index.

        Returns:
            (tuple): ``(neighbor_index, squared_distance)``, or ``(-1, 0.0)`` when no unused
            neighbor of ``loc_idx`` has ``loc_idx`` among its own neighbors.
        """
        for n, d2 in zip(self._neighbor[loc_idx], self._d2[loc_idx]):
            if n not in self._used_pts and loc_idx in self._neighbor[n]:
                return n, d2
        return -1, 0.0

    def _add_pts_to_string(self, pt_str):
        """Grow a string of points from its ends until neither end has a usable common neighbor.

        Each step extends the end whose common neighbor is closer (the end of the list wins ties).
        To extend the start, the list is reversed in place and the neighbor appended. Every added
        point is recorded in ``_used_pts``.

        Args:
            pt_str (list): list of point indices in the string; modified in place.
        """
        while True:
            start_neigh, start_d2 = self._common_neighbor(pt_str[0])
            end_neigh, end_d2 = self._common_neighbor(pt_str[-1])

            # enforce the maximum distance check if it is on
            if self._enforce_max_dist:
                if start_neigh != -1 and start_d2 > self._max_dist2:
                    start_neigh = -1
                if end_neigh != -1 and end_d2 > self._max_dist2:
                    end_neigh = -1

            if start_neigh == -1 and end_neigh == -1:
                return

            # grow from the start when only the start can grow, or when its neighbor is strictly closer
            if end_neigh == -1 or (start_neigh != -1 and start_d2 < end_d2):
                pt_str.reverse()
                end_neigh = start_neigh
            pt_str.append(end_neigh)
            self._used_pts.add(end_neigh)

    def lines_from_points(self):
        """Grabs the points and strings them together.

        Returns:
            (list): A list of the line strings; each string is a list of point indices.
        """
        # fewer than two points: nothing to connect (and the neighbor query would fail)
        if len(self._points) < 2:
            return []
        self._setup_point_info()
        self._used_pts = set()

        # build arcs (strings of points), seeding from the points closest to their neighbors
        all_pt_strs = []
        for *_, pt_idx in self._order:
            if pt_idx in self._used_pts:
                continue

            # find a starting segment
            n, d2 = self._common_neighbor(pt_idx)
            if n == -1 or (self._enforce_max_dist and d2 > self._max_dist2):
                continue

            pt_str = [pt_idx, n]
            self._used_pts.update(pt_str)

            # now add neighbor points to the current arc
            self._add_pts_to_string(pt_str)

            # add the current arc to the list of arcs
            all_pt_strs.append(pt_str)

        # connect stand-alone points if enabled
        if self._connect_isolated_pts:
            # find the unused (stand-alone) points; when enforcing a max distance, skip points
            # whose nearest neighbor is too far away to connect to
            unused_pt_idx = [
                pt_idx
                for d1, _, pt_idx in self._order
                if pt_idx not in self._used_pts and not (self._enforce_max_dist and d1 > self._max_dist2)
            ]
            self._used_pts.update(unused_pt_idx)
            # each stand-alone point becomes a two-point string to its nearest neighbor
            all_pt_strs.extend([idx, self._neighbor[idx][0]] for idx in unused_pt_idx)

        return all_pt_strs

    def _setup_point_info(self):
        """Find each point's nearest neighbors, then compute neighbor distances and visiting order."""
        pts = self._points
        # a point cannot have more neighbors than there are other points
        self._nn = min(self._nn, len(pts) - 1)

        def generator_func():
            for i, p in enumerate(pts):
                box = (p[0], p[1], p[0], p[1])
                yield i, box, box

        # create rtree of the input points
        point_rtree = index.Index(generator_func())

        self._neighbor = []
        for i, p in enumerate(pts):
            # With coincident points the query point need not come back first, and rtree may return
            # extra results on distance ties. So drop the point itself explicitly and keep exactly
            # nn neighbors.
            found = [j for j in point_rtree.nearest((p[0], p[1]), self._nn + 1) if j != i]
            self._neighbor.append(found[:self._nn])

        self._compute_neighbor_distances()

    def _compute_neighbor_distances(self):
        """Fill ``_d2`` (squared distance to each neighbor) and ``_order`` (points sorted by those distances)."""
        pts = self._points
        self._d2 = []
        for p, neighbors in zip(pts, self._neighbor):
            row = []
            for j in neighbors:
                q = pts[j]
                dx = p[0] - q[0]
                dy = p[1] - q[1]
                row.append(dx * dx + dy * dy)
            self._d2.append(row)

        # ascending by distance to the nearest neighbor, then to the second nearest (0.0 when only one
        # neighbor is considered); the point index breaks any remaining ties
        self._order = [[row[0], row[1] if len(row) > 1 else 0.0, i] for i, row in enumerate(self._d2)]
        self._order.sort()
