"""UGrid2dMerger class."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
import logging
import math
import sys

# 2. Third party modules
from matplotlib import path
import numpy as np
from rtree import index
from shapely.geometry import LineString, MultiPolygon, Point, Polygon
from shapely.ops import nearest_points

# 3. Aquaveo modules
from xms.constraint.ugrid_builder import UGridBuilder
from xms.extractor.ugrid_2d_polyline_data_extractor import UGrid2dPolylineDataExtractor
from xms.grid.ugrid import UGrid
from xms.interp.interpolate.interp_linear import InterpLinear
from xms.mesher import generate_mesh
from xms.mesher import meshing

# 4. Local modules
from xms.tool.algorithms.coverage.grid_cell_to_polygon_coverage_builder import GridCellToPolygonCoverageBuilder
from xms.tool.algorithms.geometry import geometry

# Classification of a grid_2 point relative to the grid_1 boundary polygon.
PT_OUT_POLY = -1  # outside the (buffered) grid_1 polygon
PT_IN_POLY = 1  # inside the (buffered) grid_1 polygon
PT_MATCH_POLY_PT = 2  # coincides (within the point tolerance) with a grid_1 boundary point


def xy_distance(p1, p2):
    """Return the planar (x, y) distance between two points; any z component is ignored.

    Args:
        p1 (iterable): indexable point whose first two items are x and y
        p2 (iterable): indexable point whose first two items are x and y

    Returns:
        (float): the Euclidean distance in the xy plane
    """
    delta_x, delta_y = p1[0] - p2[0], p1[1] - p2[1]
    return math.sqrt(delta_x * delta_x + delta_y * delta_y)


class PolygonOperator:
    """Class to perform set operations on polygons.

    For each grid_2 polygon, the part lying outside grid_1 is computed (polygon difference) and its vertices are
    snapped to grid_1 boundary points or retained grid_2 points. Results accumulate in ``polys_to_mesh``,
    ``polys_to_mesh_dicts`` and ``polys_to_mesh_interior`` (one entry per output polygon, in the same order).
    """

    def __init__(self, g1_boundary, g1_boundary_idxs, g1_boundary_rtree, ug2_locs, ug2_pt_classify, polys,
                 poly_buffer, pt_tol, logger):
        """Initializes the class.

        Args:
            g1_boundary (list): x,y,z coords of grid_1 boundary
            g1_boundary_idxs (list): index of each point in grid_1
            g1_boundary_rtree (rtree): spatial index
            ug2_locs (list): x,y,z coords of grid_2 points
            ug2_pt_classify (list): classification of grid_2 points relative to grid_1 boundary
            polys (list(list)): list of lists of grid_2 point indexes that define polygons; only the first item of
                each entry (a list of grid_2 point indexes) is used, as the polygon's ring
            poly_buffer (float): the distance to offset the boundary from grid_1
            pt_tol (float): tolerance for comparing if points match locations
            logger (logging.logger): logger
        """
        self._logger = logger if logger is not None else logging.getLogger('xms.tool')
        self._polys = polys
        self._g1_boundary = g1_boundary
        self._g1_poly = Polygon(g1_boundary)
        self._g1_poly_inset = self._calc_inset_polygon(poly_buffer)
        self._g1_poly_mp_path = path.Path([(p[0], p[1]) for p in self._g1_poly_inset.exterior.coords])
        # boundary xy -> position in g1_boundary (the last, presumably repeated closing, point is excluded)
        self._g1_poly_dict = {(p[0], p[1]): i for i, p in enumerate(g1_boundary[:-1])}
        # boundary xy -> grid_1 point index
        self._g1_poly_dict_save = {k: g1_boundary_idxs[v] for k, v in self._g1_poly_dict.items()}
        self._g1_rtree = g1_boundary_rtree
        self._ug2_locs = ug2_locs
        self._ug2_pt_classify = ug2_pt_classify
        self._poly_buffer = poly_buffer
        self._pt_tol = pt_tol
        self._g2_poly_idxs = None
        self._g2_poly_pts = None
        self._g2_poly = None
        self._g2_poly_dict = None
        self._cur_poly = None
        self._cur_locs = None
        self._final_locs = []
        self._diff_polys = []
        self.polys_to_mesh = []
        self.polys_to_mesh_dicts = []
        self.polys_to_mesh_interior = []

    def _calc_inset_polygon(self, poly_buffer, max_halvings=64):
        """Returns the grid_1 polygon shrunk inward by ``poly_buffer`` as a single, non-empty polygon.

        A negative buffer can make a thin polygon vanish, or split it at a narrow neck into a MultiPolygon (which has
        no ``exterior``). In either case the inset distance is halved and the buffer retried. If no usable inset is
        found after ``max_halvings`` retries, the un-inset grid_1 polygon is returned so this cannot loop forever.

        Args:
            poly_buffer (float): the initial inset distance
            max_halvings (int): maximum number of times the inset distance is halved

        Returns:
            (shapely.geometry.Polygon): the inset polygon
        """
        def usable(geom):
            return isinstance(geom, Polygon) and not geom.is_empty

        inset = self._g1_poly.buffer(-poly_buffer, resolution=1)
        b = poly_buffer * 0.5
        for _ in range(max_halvings):
            if usable(inset):
                return inset
            inset = self._g1_poly.buffer(-b, resolution=1)
            b *= 0.5
        return inset if usable(inset) else self._g1_poly

    @staticmethod
    def _polygons_from(geom):
        """Returns the polygons contained in a shapely geometry (an empty list for empty or non-areal geometry)."""
        if geom is None or geom.is_empty:
            return []
        if isinstance(geom, Polygon):
            return [geom]
        if isinstance(geom, MultiPolygon):
            return list(geom.geoms)
        # e.g. a GeometryCollection: keep only its areal parts
        return [p for g in getattr(geom, 'geoms', []) for p in PolygonOperator._polygons_from(g)]

    def calc_overlap_polys(self):
        """Calculates all of the overlap polygons."""
        for p in self._polys:
            self._g2_poly_idxs = p
            self._calc_overlap()

    def _calc_overlap(self):
        """Finds the part of the current grid_2 polygon outside grid_1 and snaps each resulting polygon."""
        locs = self._ug2_locs
        idxs = self._g2_poly_idxs[0]
        classify = self._ug2_pt_classify
        self._g2_poly_pts = [(locs[idx][0], locs[idx][1], locs[idx][2]) for idx in idxs]

        self._adjust_g2_poly_points_classified_as_in_but_location_is_out_of_g1_poly()
        self._g2_poly = Polygon(self._g2_poly_pts)

        # Only grid_2 points that are neither inside grid_1 nor matched to a grid_1 boundary point are kept as
        # vertices the output polygons may use.
        in_classes = {PT_IN_POLY, PT_MATCH_POLY_PT}
        self._g2_poly_dict = {
            (p[0], p[1]): i for i, p in enumerate(self._g2_poly_pts) if classify[idxs[i]] not in in_classes
        }

        # Always rebuilt, so polygons from a previous grid_2 polygon are never processed again.
        self._diff_polys = self._polygons_from(self._g2_poly.difference(self._g1_poly))
        for p in self._diff_polys:
            self._cur_poly = p
            self._snap_diff_to_grid_pts()

    def _adjust_g2_poly_points_classified_as_in_but_location_is_out_of_g1_poly(self):
        """Adjust g2 points that are classified as in g1_poly but the location is out.

        Some g2 poly points may be outside of the g1_boundary but are classified as in because
        the user can specify a buffer distance for the g1_polygon. If we don't adjust the g2_poly
        then we won't get the expected intersection with the g1 polygon.
        Matched points (PT_MATCH_POLY_PT) are always moved. Moved points go to the nearest location
        on the inset grid_1 polygon, keeping their z.
        test_case_22_sub_2 and test_case_22_sub_4 have this problem
        """
        idxs = self._g2_poly_idxs[0]
        check_pts = [(p[0], p[1]) for p in self._g2_poly_pts]
        pt_in_poly = self._g1_poly_mp_path.contains_points(check_pts)
        for i, idx in enumerate(idxs):
            p = self._ug2_locs[idx]
            pt_class = self._ug2_pt_classify[idx]
            if pt_class == PT_MATCH_POLY_PT or (pt_class == PT_IN_POLY and not pt_in_poly[i]):
                sh_pt = Point(p[0], p[1])
                poly_pt, _ = nearest_points(self._g1_poly_inset, sh_pt)
                poly_pt = poly_pt.coords[0]
                self._g2_poly_pts[i] = (poly_pt[0], poly_pt[1], p[2])

    def _snap_diff_to_grid_pts(self):
        """Make sure all points in the output polygons are snapped to the grid points.

        Polygons without any vertex at a retained grid_2 point are skipped. Otherwise the exterior ring and the
        first interior ring (if any) are snapped and appended to the output lists.
        """
        self._cur_locs = list(self._cur_poly.exterior.coords)

        # if the locations do not have any g2 points then don't process this polygon
        if not any((pt[0], pt[1]) in self._g2_poly_dict for pt in self._cur_locs):
            return

        self._snap_locations()
        if len(self._final_locs) < 4:
            # After snapping the locations to the grid points we may have less than 4 points
            # so this polygon has collapsed to a line on the boundary so we just skip it.
            # test case 34 hits this condition
            return
        self.polys_to_mesh.append(self._final_locs)
        # store grid point idxs for the points
        g2_poly_dict_save = {k: self._g2_poly_idxs[0][v] for k, v in self._g2_poly_dict.items()}
        self.polys_to_mesh_dicts.append((self._g1_poly_dict_save, g2_poly_dict_save))

        # NOTE: only the first interior ring (hole) is carried over; any further holes are ignored.
        interior_locs = []
        if len(self._cur_poly.interiors) > 0:
            self._cur_locs = list(self._cur_poly.interiors[0].coords)
            self._snap_locations()
            if len(self._final_locs) > 3:
                interior_locs = self._final_locs
        self.polys_to_mesh_interior.append(interior_locs)

    def _snap_locations(self):
        """Snaps the locations to grid_1 point locations.

        Each point of ``self._cur_locs`` that is neither a grid_1 boundary point nor a retained grid_2 point is
        replaced by the nearest grid_1 boundary point, and consecutive duplicates produced by snapping are dropped.
        The result is stored in ``self._final_locs``. It is left empty if the snapped ring is degenerate (fewer
        than 4 coordinates or zero area). It is set to the unsnapped ``self._cur_locs`` if snapping produced an
        invalid polygon.
        """
        self._final_locs = []
        for pt in self._cur_locs:
            pt_to_append = pt
            # can't do this - see bug 13930
            p = (pt[0], pt[1])
            if p not in self._g1_poly_dict and p not in self._g2_poly_dict:
                # snap this point to the closest grid_1 boundary point
                idx = list(self._g1_rtree.nearest((p[0], p[1])))[0]
                pt_to_append = self._g1_boundary[idx]

            if self._final_locs:
                prev = self._final_locs[-1]
                if (pt_to_append[0], pt_to_append[1]) == (prev[0], prev[1]):
                    continue  # snapping produced a repeated point
            self._final_locs.append(pt_to_append)

        # A closed ring needs at least 4 coordinates; shapely raises ValueError when building a Polygon from fewer,
        # which happens when snapping collapses the ring onto a line or a point.
        if len(self._final_locs) < 4:
            self._final_locs = []
            return

        poly = Polygon(self._final_locs)
        if poly.area <= 0.0:
            self._final_locs = []
            return

        if not poly.is_valid:
            self._final_locs = self._cur_locs


class UGrid2dMerger:
    """Class to merge two 2d ugrids.

    grid_1 (the priority grid) is stripped of points not attached to any cell and split into edge-connected pieces.
    grid_2 is then merged into each piece in turn (see ``_UGrid2dMerger``); the output of one merge becomes the
    secondary grid of the next.
    """

    def __init__(self, grid_1, grid_2, pt_tol=1.0e-9, poly_buffer=None, logger=None, stitch_grids=False):
        """Initializes the class.

        Args:
            grid_1 (xms.constraint.UGrid_2d): the priority grid. This grid will be preserved.
            grid_2 (xms.constraint.UGrid_2d): the secondary grid. The parts of this grid that are outside of grid_1 will
                be appended to grid 1.
            pt_tol (float): tolerance for comparing if points match locations
            poly_buffer (float): the distance to offset the boundary from grid_1
            logger (logging.logger): logger
            stitch_grids (bool): merge non-overlapping grids with shared boundary points
        """
        self.grid_1 = grid_1
        self.grid_2 = grid_2
        self.pt_tol = pt_tol
        self.poly_buffer = poly_buffer
        self.log = logger  # the caller's logger (None if not given); passed through to _UGrid2dMerger
        self.stitch_grids = stitch_grids
        self.progress_interval = 500000
        self.out_ugrid = None
        self._ug1 = self.grid_1.ugrid
        self._ug1_locs = self._ug1.locations
        self._visited_cells = None
        self._connected_cells = None
        self._curr_cell = None
        self._grids = []
        self._merger = None
        self._created_elem_dict = False
        self._g2_elem_dict = dict()
        self._show_merge_cell_datasets_message = True
        self.logger = logger if logger is not None else logging.getLogger('xms.tool')

    def _enable_profile(self):
        """Enable profiling (currently a no-op; the cProfile code is commented out)."""
        pass
        # import cProfile
        # self.pr = cProfile.Profile()
        # self.pr.enable()

    def _profile_report(self):
        """Print the profile report (currently a no-op; the pstats code is commented out)."""
        pass
        # import pstats
        # import io
        # self.pr.disable()
        # s = io.StringIO()
        # sortby = 'cumulative'
        # ps = pstats.Stats(self.pr, stream=s).sort_stats(sortby)
        # ps.print_stats()
        # with open('c:/temp/profile.txt', 'w') as f:
        #     f.write(s.getvalue())

    def merge_grids(self):
        """Merge the two grids.

        Returns:
            (xms.grid.ugrid.UGrid): the merged grid (also stored in ``self.out_ugrid``); None if grid_1 produced no
            pieces to merge into.
        """
        self._created_elem_dict = False
        self._enable_profile()
        self._process_grid_1()
        grid_2 = self.grid_2
        out_grid = None
        for grid in self._grids:
            self._merger = _UGrid2dMerger(grid, grid_2, self.pt_tol, self.poly_buffer, self.log, self.stitch_grids)
            self._merger.progress_interval = self.progress_interval
            self._merger.merge_grids()
            out_grid = self._merger.out_ugrid
            # the merged result becomes the secondary grid for the next grid_1 piece
            b = UGridBuilder()
            b.set_is_2d()
            b.set_ugrid(out_grid)
            grid_2 = b.build_grid()
        self.out_ugrid = out_grid
        self._profile_report()
        return self.out_ugrid

    def can_merge_datasets(self):
        """Check if we can merge datasets.

        Returns:
            (bool): True when every output point comes from one of the input grids.
        """
        # don't try to merge dataset if we have created new points that are not in either of the
        # original input grids
        num_pts = max(self._merger.ug2_old_to_new_pt_idx, default=-1) + 1
        return num_pts == self.out_ugrid.point_count

    def merge_datasets(self, p_reader, s_reader, writer):
        """Merge the two datasets.

        Values for the priority grid are copied as-is; values for the points/cells appended after them are taken
        from the secondary dataset through the merger's old-to-new index map.

        NOTE(review): when grid_1 was split into several pieces, ``self._merger`` is the last piece's merger, whose
        map refers to the previous merge output rather than the original grid_2 - confirm this case is handled.

        Args:
            p_reader (DatasetReader): the priority grid dataset
            s_reader (DatasetReader): the secondary grid dataset
            writer (DatasetWriter): the dataset writer for the merged dataset
        Returns:
            (bool): dataset was merged

        """
        if p_reader.location == 'points':
            num_val = self.out_ugrid.point_count
            new_to_old_idx = np.full(num_val, -1, dtype=np.int32)
            for i, idx in enumerate(self._merger.ug2_old_to_new_pt_idx):
                # A negative index (presumably a grid_2 point with no counterpart in the merged grid) would wrap
                # around to the last entry and clobber that point's mapping.
                if idx >= 0:
                    new_to_old_idx[idx] = i
        elif self._merger.can_merge_cell_datasets:
            num_val = self.out_ugrid.cell_count
            new_to_old_idx = np.full(num_val, -1, dtype=np.int32)
            elem_dict = self._get_g2_elem_dict()
            for k, v in elem_dict.items():
                new_to_old_idx[k] = v
        else:
            if self._show_merge_cell_datasets_message:
                msg = 'Unable to merge cell datasets because new cells were created while merging the input grids.'
                self.logger.info(msg)
                self._show_merge_cell_datasets_message = False
            return False

        time_count = len(p_reader.times)
        for tsidx in range(time_count):
            data, activity = p_reader.timestep_with_activity(tsidx)
            start_idx = data.shape[0]  # values after this index come from the secondary dataset
            s_data, s_activity = s_reader.timestep_with_activity(tsidx)
            rows = [s_data[new_to_old_idx[i]] for i in range(start_idx, num_val)]
            # reshape keeps vector datasets 2-D even when nothing is appended, so concatenate does not fail
            append_data = np.asarray(rows, dtype=data.dtype).reshape((-1,) + data.shape[1:])
            data = np.concatenate((data, append_data))
            if activity is not None:
                act_rows = [s_activity[new_to_old_idx[i]] for i in range(start_idx, num_val)]
                append_act = np.asarray(act_rows, dtype=activity.dtype).reshape((-1,) + activity.shape[1:])
                activity = np.concatenate((activity, append_act))
            writer.append_timestep(p_reader.times[tsidx], data, activity)

        writer.appending_finished()
        return True

    def _get_g2_elem_dict(self):
        """Build (once) the map from merged-grid cell index to grid_2 cell index.

        A merged cell is mapped when it is the only cell adjacent to all the mapped points of a grid_2 cell, lies
        after the grid_1 cells, and has exactly the same centroid result as the grid_2 cell.

        Returns:
            (dict): merged cell index -> grid_2 cell index
        """
        if self._created_elem_dict is False:
            self._g2_elem_dict.clear()
            g2_cell = -1
            g2 = self.grid_2.ugrid
            g1_cell_count = self.grid_1.ugrid.cell_count
            cellstream = g2.cellstream
            cnt = 0
            while cnt < len(cellstream):
                # cellstream layout per cell: [cell_type, num_pts, pt_0, ..., pt_n-1]
                g2_cell += 1
                num_pts = cellstream[cnt + 1]
                start = cnt + 2
                end = start + num_pts
                g2_nodes = cellstream[start:end]
                cnt = end
                merged_nodes = [self._merger.ug2_old_to_new_pt_idx[node] for node in g2_nodes]

                # now we know what the nodes are in the new grid
                merged_cell = self.out_ugrid.get_points_adjacent_cells(merged_nodes)
                if len(merged_cell) == 1 and merged_cell[0] >= g1_cell_count:
                    if self.out_ugrid.get_cell_centroid(merged_cell[0]) == g2.get_cell_centroid(g2_cell):
                        self._g2_elem_dict[merged_cell[0]] = g2_cell
            self._created_elem_dict = True
        return self._g2_elem_dict

    def _process_grid_1(self):
        """Remove points with no cells from grid_1, then split it into edge-connected grids."""
        if self._ug1.cell_count < 1:
            msg = 'The priority grid must have cells. Aborting.'
            # NOTE(review): this raises only when a logger was passed in; otherwise the error is logged and
            # processing continues with a grid that has no cells - confirm this is intended.
            if self.log:
                raise RuntimeError(msg)
            else:
                self.logger.error(msg)
        self._remove_disjoint_points()
        self._split_disjoint_grid()

    def _remove_disjoint_points(self):
        """Remove points that are not connected to any cells (rebuilds ``grid_1`` if any are found)."""
        ug = self._ug1
        npts = ug.point_count
        pt_to_ncells = [len(ug.get_point_adjacent_cells(i)) for i in range(npts)]
        if 0 not in pt_to_ncells:
            return
        self.logger.warning('Disjoint points on the primary grid have been removed.')
        old_to_new_idx = [-1] * npts
        new_idx = 0
        locs = self._ug1_locs
        new_locs = []
        for i in range(npts):
            if pt_to_ncells[i] > 0:
                old_to_new_idx[i] = new_idx
                new_idx += 1
                new_locs.append(locs[i])

        self.logger.info('Removing disjoint points in primary grid.')
        cell_idx = -1
        cnt = 0
        cellstream = self._ug1.cellstream
        new_cs = []
        while cnt < len(cellstream):
            cell_idx += 1
            if (cell_idx + 1) % self.progress_interval == 0:
                msg = f'Disjoint point removal - 500,000 cells processed. Current index {cell_idx}.'
                self.logger.info(msg)
            cell_type = cellstream[cnt]
            num_pts = cellstream[cnt + 1]
            start = cnt + 2
            end = start + num_pts
            cell_pts = cellstream[start:end]
            cnt = end
            new_cs.extend([cell_type, num_pts] + [old_to_new_idx[idx] for idx in cell_pts])
        co_builder = UGridBuilder()
        co_builder.set_unconstrained()
        co_builder.set_ugrid(UGrid(new_locs, new_cs))
        self.grid_1 = co_builder.build_grid()
        self._ug1 = UGrid(new_locs, new_cs)
        self._ug1_locs = self._ug1.locations

    def _split_disjoint_grid(self):
        """If grid_1 has disjoint cells then split it into multiple grids (stored in ``self._grids``)."""
        # start fresh so repeated merge_grids calls do not accumulate pieces from an earlier run
        self._grids = []
        self._visited_cells = set()
        ncell = self._ug1.cell_count
        idx = 0
        while len(self._visited_cells) < ncell:
            while idx in self._visited_cells:
                idx += 1
            self._curr_cell = idx
            self._calc_connected_cells()
            if idx == 0 and len(self._connected_cells) == ncell:
                # every cell is reachable from the first one: grid_1 is not disjoint
                self._grids = [self.grid_1]
                break
            self._build_grid_from_connected_cells()

    def _calc_connected_cells(self):
        """Finds cells connected across edges to the current cell (breadth-first over 2D edge adjacency)."""
        cells = [self._curr_cell]
        self._visited_cells.add(self._curr_cell)
        idx = 0
        while idx < len(cells):
            cell_idx = cells[idx]
            idx += 1
            edges = self._ug1.get_cell_edges(cell_idx)
            for edge_idx in range(len(edges)):
                adj_cell = self._ug1.get_cell_2d_edge_adjacent_cell(cell_idx, edge_idx)
                if adj_cell > -1 and adj_cell not in self._visited_cells:
                    cells.append(adj_cell)
                    self._visited_cells.add(adj_cell)
        self._connected_cells = set(cells)

    def _build_grid_from_connected_cells(self):
        """Builds a new grid from the connected cells and appends it to ``self._grids``.

        Only the points used by ``self._connected_cells`` are kept, renumbered in ascending order of their original
        index; cells keep their original relative order.
        """
        pt_idxs = sorted({p for idx in self._connected_cells for p in self._ug1.get_cell_points(idx)})
        # old grid_1 point index -> index in the new grid
        old_to_new_idx = {old: new for new, old in enumerate(pt_idxs)}
        locs = self._ug1_locs
        new_locs = [locs[i] for i in pt_idxs]

        self.logger.info('Splitting disjoint primary grid.')
        cell_idx = -1
        cnt = 0
        cellstream = self._ug1.cellstream
        new_cs = []
        while cnt < len(cellstream):
            cell_idx += 1
            if (cell_idx + 1) % self.progress_interval == 0:
                msg = f'Splitting primary grid - 500,000 cells processed. Current index {cell_idx}.'
                self.logger.info(msg)
            cell_type = cellstream[cnt]
            num_pts = cellstream[cnt + 1]
            start = cnt + 2
            end = start + num_pts
            cnt = end

            if cell_idx not in self._connected_cells:
                continue
            cell_pts = cellstream[start:end]
            new_cs.extend([cell_type, num_pts] + [old_to_new_idx[idx] for idx in cell_pts])

        co_builder = UGridBuilder()
        co_builder.set_unconstrained()
        co_builder.set_ugrid(UGrid(new_locs, new_cs))
        self._grids.append(co_builder.build_grid())


class _UGrid2dMerger:
    """Class to merge two 2d ugrids."""

    def __init__(self, grid_1, grid_2, pt_tol=1.0e-9, poly_buffer=None, logger=None, stitch_grids=False):
        """Store the inputs and initialize the working state for one grid_1/grid_2 merge.

        Args:
            grid_1 (xms.constraint.UGrid_2d): the priority grid. This grid will be preserved.
            grid_2 (xms.constraint.UGrid_2d): the secondary grid. The parts of this grid that are outside of grid_1
                will be appended to grid 1.
            pt_tol (float): tolerance for comparing if points match locations
            poly_buffer (float): the distance to offset the boundary from grid_1 (derived from the grid_1 boundary
                when None)
            logger (logging.logger): logger; the 'xms.tool' logger is used when not given
            stitch_grids (bool): merge non-overlapping grids with shared boundary points
        """
        # --- inputs and settings ---
        self._stitch_grids = stitch_grids
        self._grid_1, self._grid_2 = grid_1, grid_2
        self._ug1 = grid_1.ugrid
        self._ug2 = grid_2.ugrid
        # editable copy of the grid_2 point locations (matched points are later moved onto grid_1 points)
        self._ug2_locs = [[loc[0], loc[1], loc[2]] for loc in self._ug2.locations]
        self._pt_tol = pt_tol
        self._poly_buffer = poly_buffer
        self._own_logger = not logger
        self.logger = logger if logger else logging.getLogger('xms.tool')
        self.progress_interval = 500000

        # --- grid_1 boundary (filled in by _check_grid_1_disjoint / _calc_matching_grid_points) ---
        self._g1_boundary = None
        self._g1_boundary_min_length = None
        self._g1_boundary_idxs = None
        self._g1_boundary_rtree = None
        self._g1_polygon = None
        self._g1_polygon_original = None

        # --- grid_2 point / cell classification ---
        self._ug2_pt_classify = None
        self.ug2_old_to_new_pt_idx = None
        self._match_pt_idx_g1_g2 = {}
        self._g2_overlap_cells = None
        self._g2_overlap_polys = None
        self._g2_shared_edge_cells_out_g1 = None
        self._g2_shared_edge_cells_in_g1 = None
        self._g2_remove_cells = []

        # --- polygons to mesh and output ---
        self._polys_to_mesh = []
        self._polys_to_mesh_interior = []
        self._polys_to_mesh_dicts = []
        self._out_locs = None
        self._out_cellstream = None
        self._nan_warning = False
        self.out_ugrid = None
        self.can_merge_cell_datasets = True

    def merge_grids(self):
        """Run the whole merge of grid_2 into grid_1 and store the result in ``self.out_ugrid``.

        Boundary polygon and matching points are always computed. Unless the grids are being stitched, the grid_2
        cells touching or crossing the grid_1 boundary are found, and the resulting overlap polygons are merged in
        after the two grids are combined.
        """
        self._check_grid_1_disjoint()
        self._calc_matching_grid_points()
        if not self._stitch_grids:
            overlap_steps = (
                self._classify_cells_from_matching_points,
                self._calc_grid_2_cells_intersected_by_grid_1_boundary,
                self._calc_grid_2_overlap_polygons,
            )
            for step in overlap_steps:
                step()
        self._merge_grid_1_and_grid_2()
        if not self._stitch_grids:
            self._merge_overlap_polys()
        self._remove_detached_g2_points()
        locations, cellstream = self._out_locs, self._out_cellstream
        self.out_ugrid = UGrid(locations, cellstream)

    def _check_grid_1_disjoint(self):
        """Build the grid_1 boundary polygon and its buffered version.

        Despite its name, this method does not abort on a disjoint grid_1. It takes the point indexes at
        ``find_polygons()[0][0][0]`` (presumably the outer loop of the first polygon) as the grid_1 boundary.
        It logs the bounding box, its diagonal and the shortest boundary edge. When ``poly_buffer`` was not given,
        it is set to 1% of that shortest edge. ``self._g1_polygon`` is then replaced by the boundary buffered
        outward by ``poly_buffer``; the unbuffered polygon is kept in ``self._g1_polygon_original``.
        """
        self.logger.info('Generating outer boundary polygon for priority ugrid/mesh.')
        builder = GridCellToPolygonCoverageBuilder(co_grid=self._grid_1,
                                                   dataset_values=[0] * self._ug1.cell_count,
                                                   wkt=None, coverage_name='temp',
                                                   logger=self.logger)
        boundary_idxs = builder.find_polygons()[0][0][0]
        self._g1_boundary_idxs = boundary_idxs
        g1_locs = self._ug1.locations
        self._g1_boundary = [(g1_locs[i][0], g1_locs[i][1], g1_locs[i][2]) for i in boundary_idxs]
        self._g1_polygon = Polygon(self._g1_boundary)
        self._g1_polygon_original = Polygon(self._g1_boundary)

        g1_box = self._g1_polygon.bounds
        diagonal = xy_distance(g1_box, g1_box[2:])
        self.logger.info(f'Bounding box of primary UGrid/Mesh: {g1_box}')
        self.logger.info(f'Diagonal distance of bounding box of primary UGrid/Mesh: {diagonal}')

        boundary = self._g1_boundary
        if len(boundary) > 1:
            # shortest edge between consecutive boundary points
            min_length = min(xy_distance(a, b) for a, b in zip(boundary, boundary[1:]))
        else:
            min_length = 1.0e-3
        self._g1_boundary_min_length = min_length
        self.logger.info(f'Minimum edge length on primary UGrid/Mesh boundary: {min_length}')

        # NOTE: a check rejecting poly_buffer values larger than min_length was previously disabled.
        if self._poly_buffer is None:
            self._poly_buffer = min_length * 0.01
        self._g1_polygon = self._g1_polygon.buffer(self._poly_buffer, resolution=1)

    def _calc_matching_grid_points(self):
        """Find the points that match in grid_2 with the grid_1 boundary.

        Fills ``self._ug2_pt_classify`` (one entry per grid_2 point):
            - PT_IN_POLY when the point is inside the buffered grid_1 polygon,
            - PT_MATCH_POLY_PT when it is the closest grid_2 point, within ``pt_tol``, to a grid_1 boundary point,
            - PT_OUT_POLY otherwise.
        Matched grid_2 points are moved onto the grid_1 boundary point location in ``self._ug2_locs``. Each match is
        recorded in ``self._match_pt_idx_g1_g2`` (grid_1 point index -> grid_2 point index). Also builds
        ``self._g1_boundary_rtree``, a spatial index over the grid_1 boundary points.
        """
        self.logger.info('Classifying point locations from the merge grid/mesh with the boundary of the '
                         'priority ugrid/mesh.')
        pt_in_poly = self._calc_g2_points_in_g1_polygon()
        # By default every grid_2 point (inside the bounding box below) is a candidate for boundary matching.
        pt_search_in_poly = [True] * len(pt_in_poly)
        if self._poly_buffer < self._pt_tol:
            # The buffer is smaller than the matching tolerance. Restrict candidates to points inside the buffered
            # polygon, or inside the unbuffered boundary grown by pt_tol.
            poly_for_pt_search = self._g1_polygon_original.buffer(self._pt_tol + 1.0e-8, resolution=1)
            pt_search_in_poly = self._calc_g2_points_in_g1_polygon(poly_for_pt_search)
        # one zero-size box per grid_1 boundary point; rtree ids are positions in self._g1_boundary
        boxes = [(p[0], p[1], p[0], p[1]) for p in self._g1_boundary]

        def generator_func():
            for j, b in enumerate(boxes):
                yield j, b, b

        self._g1_boundary_rtree = index.Index(generator_func())
        # bounding box of the buffered grid_1 polygon; points outside it stay PT_OUT_POLY
        box = self._g1_polygon.bounds

        locs = self._ug2_locs  # alias: matched points are moved in place below
        self._ug2_pt_classify = [PT_OUT_POLY] * self._ug2.point_count

        # grid_1 point index -> [(distance, grid_2 point index, grid_1 location), ...] for candidates within pt_tol
        g1_potential_matches = {}
        for i, p in enumerate(locs):
            if (i + 1) % self.progress_interval == 0:
                self.logger.info('Finding matching points - 500,000 points processed.')

            if p[0] < box[0] or p[1] < box[1] or p[0] > box[2] or p[1] > box[3]:
                continue
            if pt_in_poly[i] or pt_search_in_poly[i]:
                if pt_in_poly[i]:
                    self._ug2_pt_classify[i] = PT_IN_POLY
                # see if we match a boundary point
                idx = list(self._g1_boundary_rtree.nearest((p[0], p[1]), 1))[0]
                p_g1 = self._g1_boundary[idx]
                p_g1_idx = self._g1_boundary_idxs[idx]
                dist = xy_distance(p, p_g1)
                if dist < self._pt_tol:
                    my_list = g1_potential_matches.get(p_g1_idx, [])
                    my_list.append((dist, i, p_g1))
                    g1_potential_matches[p_g1_idx] = my_list
        # Each grid_1 boundary point keeps only its closest candidate. Tuples sort by distance first, then by
        # grid_2 index.
        for k, v in g1_potential_matches.items():
            p_g1_idx = k
            my_list = sorted(v)
            g2_idx = my_list[0][1]
            p_g1 = my_list[0][2]
            # snap the matched grid_2 point exactly onto the grid_1 point (x, y and z)
            locs[g2_idx] = [p_g1[0], p_g1[1], p_g1[2]]
            self._match_pt_idx_g1_g2[p_g1_idx] = g2_idx
            self._ug2_pt_classify[g2_idx] = PT_MATCH_POLY_PT
            if len(my_list) > 1:
                msg = f'{len(my_list)} points from secondary UGrid/Mesh within point tolerance of {p_g1} ' \
                      f'in primary UGrid/Mesh.'
                self.logger.warning(msg)

    def _calc_g2_points_in_g1_polygon(self, polygon=None):
        """Run a point-in-polygon test of every grid_2 point against a polygon.

        Args:
            polygon (shapely.geometry.Polygon): polygon to test against; the buffered grid_1 polygon when None.
                Only its exterior ring is used.

        Returns:
            The result of ``geometry.run_parallel_points_in_polygon``, one entry per grid_2 point.
        """
        if polygon is None:
            polygon = self._g1_polygon
        ring_xy = np.asarray([(vtx[0], vtx[1]) for vtx in polygon.exterior.coords])
        grid_2_xy = np.asarray([(loc[0], loc[1]) for loc in self._ug2_locs])
        return geometry.run_parallel_points_in_polygon(grid_2_xy, ring_xy)

    def _classify_cells_from_matching_points(self):
        """Find any edges that match between the grids and classify the adjacent cells as in or out.

        Only grid_2 cells touching a grid_2 point matched to a grid_1 boundary point are classified, each at most
        once, in two passes:

        1. cells adjacent to a grid_2 edge whose two end points match adjacent grid_1 points;
        2. any remaining cells adjacent to a matched grid_2 point.

        Results are stored in ``self._g2_shared_edge_cells_in_g1`` and ``self._g2_shared_edge_cells_out_g1``.
        """
        self.logger.info('Classifying cells from matching points.')
        cells_in = set()
        cells_out = set()
        processed_cells = set()
        self._classify_cells_on_matching_edges(cells_in, cells_out, processed_cells)
        self._classify_cells_at_matching_points(cells_in, cells_out, processed_cells)
        self._g2_shared_edge_cells_in_g1 = list(cells_in)
        self._g2_shared_edge_cells_out_g1 = list(cells_out)

    def _classify_cells_on_matching_edges(self, cells_in, cells_out, processed_cells):
        """Classifies grid_2 cells adjacent to an edge whose two end points match adjacent grid_1 points.

        A cell with unmatched points is "in" when any of them is classified PT_IN_POLY. A cell whose points all
        match is "in" when its centroid lies inside the buffered grid_1 polygon.

        Args:
            cells_in (set): receives cells classified as inside grid_1
            cells_out (set): receives cells classified as outside grid_1
            processed_cells (set): cells already classified; updated in place
        """
        mp_path = path.Path([(p[0], p[1]) for p in self._g1_polygon.exterior.coords])
        ug1_match_set = set(self._match_pt_idx_g1_g2.keys())
        ug2_match_set = set(self._match_pt_idx_g1_g2.values())
        for i, (g1_idx1, g2_idx1) in enumerate(self._match_pt_idx_g1_g2.items(), start=1):
            if i % self.progress_interval == 0:
                msg = (f'Classify cells from matching points - {self.progress_interval:,} points processed. '
                       f'Current idx {i}.')
                self.logger.info(msg)
            # grid_1 neighbours of this point that also have a matching grid_2 point
            g1_adj_pts = set(self._ug1.get_point_adjacent_points(g1_idx1)).intersection(ug1_match_set)
            for g1_adj_pt in g1_adj_pts:
                g2_idx2 = self._match_pt_idx_g1_g2[g1_adj_pt]
                edge = sorted([g2_idx1, g2_idx2])
                for g2_cell in self._ug2.get_edge_adjacent_cells(edge):
                    if g2_cell in processed_cells:
                        continue
                    processed_cells.add(g2_cell)
                    cell_pts = set(self._ug2.get_cell_points(g2_cell)) - ug2_match_set
                    if cell_pts:
                        is_in = PT_IN_POLY in {self._ug2_pt_classify[idx] for idx in cell_pts}
                    else:
                        _, center = self._ug2.get_cell_centroid(g2_cell)
                        is_in = mp_path.contains_point((center[0], center[1]))
                    if is_in:
                        cells_in.add(g2_cell)
                    else:
                        cells_out.add(g2_cell)

    def _classify_cells_at_matching_points(self, cells_in, cells_out, processed_cells):
        """Classifies the remaining grid_2 cells that touch a grid_2 point matched to a grid_1 boundary point.

        A cell with exactly one matched point and no PT_IN_POLY point is "out"; a cell whose points all match is
        "in". Any other cell is "in" when it comes within ``poly_buffer`` of a nearby grid_1 boundary point that has
        no matching grid_2 point.

        Args:
            cells_in (set): receives cells classified as inside grid_1
            cells_out (set): receives cells classified as outside grid_1
            processed_cells (set): cells already classified; updated in place
        """
        ug1_match_set = set(self._match_pt_idx_g1_g2.keys())
        ug2_match_set = set(self._match_pt_idx_g1_g2.values())
        ug1_locs = self._ug1.locations
        for i, g2_idx1 in enumerate(ug2_match_set, start=1):
            if i % self.progress_interval == 0:
                msg = (f'Classify cells from matching points - {self.progress_interval:,} points processed. '
                       f'Current index {i}.')
                self.logger.info(msg)
            for g2_cell in self._ug2.get_point_adjacent_cells(g2_idx1):
                if g2_cell in processed_cells:
                    continue
                processed_cells.add(g2_cell)
                pt_classes = [self._ug2_pt_classify[idx] for idx in self._ug2.get_cell_points(g2_cell)]
                n_point_match = pt_classes.count(PT_MATCH_POLY_PT)
                if n_point_match == 1 and PT_IN_POLY not in pt_classes:
                    cells_out.add(g2_cell)
                    continue
                if n_point_match == len(pt_classes):
                    cells_in.add(g2_cell)
                    continue
                _, cell_poly = self._ug2.get_cell_plan_view_polygon(g2_cell)
                cell_poly = Polygon([(p[0], p[1]) for p in cell_poly])
                # grid_1 boundary points near this cell that have no matching grid_2 point
                nearby = [self._g1_boundary_idxs[j] for j in self._g1_boundary_rtree.intersection(cell_poly.bounds)]
                unmatched = [idx for idx in nearby if idx not in ug1_match_set]
                touches = any(
                    cell_poly.intersects(Point(ug1_locs[idx][0], ug1_locs[idx][1]).buffer(self._poly_buffer))
                    for idx in unmatched
                )
                if touches:
                    cells_in.add(g2_cell)
                else:
                    cells_out.add(g2_cell)

    def _calc_grid_2_cells_intersected_by_grid_1_boundary(self):
        """Find the cells in grid_2 intersected by the grid_1 boundary.

        Populates ``self._g2_overlap_cells`` with the grid_2 cells crossed by the exterior ring of
        ``self._g1_polygon``. The shared-edge cells that were already classified are excluded. The
        method then appends to ``self._g2_remove_cells``:

        * every grid_2 cell that has at least one point classified as ``PT_IN_POLY``,
        * the shared-edge cells classified as inside grid_1,
        * every overlap cell.
        """
        self.logger.info('Finding merge ugrid/mesh cells that are intersected by priority ugrid/mesh boundary.')

        # First pass: let the polyline extractor find the cells crossed by the grid_1 boundary.
        ug_elevations = [pt[2] for pt in self._ug2_locs]
        ug_activity = [1] * len(ug_elevations)
        self.logger.info('Building "UGrid2dPolylineDataExtractor". This can take minutes for large grids.')
        polyline_extractor = UGrid2dPolylineDataExtractor(ugrid=self._ug2, scalar_location='points')
        polyline_extractor.set_grid_scalars(ug_elevations, ug_activity, 'points')
        boundary_pts = [(p[0], p[1], 0) for p in self._g1_polygon.exterior.coords]
        polyline_extractor.set_polyline(boundary_pts)
        _ = polyline_extractor.extract_data()
        # -1 is discarded together with the already classified shared-edge cells
        # (presumably the extractor reports -1 where no cell is found -- confirm)
        remove_set = set([-1] + self._g2_shared_edge_cells_in_g1 + self._g2_shared_edge_cells_out_g1)
        self._g2_overlap_cells = set(polyline_extractor.cell_indexes) - remove_set

        # Bulk load the bounding box (xmin, ymin, xmax, ymax) of every boundary segment into an RTree.
        pairs = [(boundary_pts[i - 1], boundary_pts[i]) for i in range(1, len(boundary_pts))]
        boxes = [
            (min(a[0], b[0]), min(a[1], b[1]), max(a[0], b[0]), max(a[1], b[1]))
            for a, b in pairs
        ]

        def generator_func():
            for i, b in enumerate(boxes):
                yield i, b, b

        g1_poly_line_rtree = index.Index(generator_func())

        # Second pass: the extractor misses cells sometimes. Grow outward from the known overlap cells
        # and test each neighboring cell against the boundary segments whose boxes it touches.
        segment_cache = {}  # segment index -> LineString, built lazily and reused across cells
        processed_points = set()
        processed_cells = set()
        overlap = list(self._g2_overlap_cells)
        adj_func = self._ug2.get_point_adjacent_cells
        overlap_idx = 0
        while overlap_idx < len(overlap):  # ``overlap`` grows while it is being walked
            cell_idx = overlap[overlap_idx]
            overlap_idx += 1
            cell_pts = set(self._ug2.get_cell_points(cell_idx)) - processed_points
            for pt_idx in cell_pts:
                adj_cells = set(adj_func(pt_idx)) - processed_cells - self._g2_overlap_cells - remove_set
                for adj in adj_cells:
                    _, cell_poly = self._ug2.get_cell_plan_view_polygon(adj)
                    cell_poly = Polygon([(loc[0], loc[1]) for loc in cell_poly])
                    for line_idx in list(g1_poly_line_rtree.intersection(cell_poly.bounds)):
                        segment = segment_cache.get(line_idx)
                        if segment is None:
                            segment = LineString(pairs[line_idx])
                            segment_cache[line_idx] = segment
                        if segment.intersects(cell_poly):
                            self._g2_overlap_cells.add(adj)
                            overlap.append(adj)
                            break
                processed_cells.update(adj_cells)
            processed_points.update(cell_pts)

        self.logger.info('Classifying merge ugrid/mesh cells as in or out of the priority ugrid/mesh.')
        cell_idx = -1
        cnt = 0
        cellstream = self._ug2.cellstream
        while cnt < len(cellstream):
            cell_idx += 1
            if (cell_idx + 1) % self.progress_interval == 0:
                self.logger.info(f'Classifying cells - 500,000 cells processed. Current index {cell_idx}.')
            # each cell is stored as [cell_type, num_pts, pt_0, ..., pt_{num_pts - 1}]
            num_pts = cellstream[cnt + 1]
            start = cnt + 2
            end = start + num_pts
            cell_pts = cellstream[start:end]
            cnt = end
            inside = {self._ug2_pt_classify[idx] for idx in cell_pts}
            if PT_IN_POLY in inside:
                self._g2_remove_cells.append(cell_idx)
        self._g2_remove_cells.extend(self._g2_shared_edge_cells_in_g1)
        self._g2_remove_cells.extend(self._g2_overlap_cells)

    def _calc_grid_2_overlap_polygons(self):
        """Create polygons from the overlapped cells in grid_2.

        Stores the grid_2 removal polygons in ``self._g2_overlap_polys``. When any exist, a
        ``PolygonOperator`` turns them into the polygons to re-mesh. Its ``polys_to_mesh``,
        ``polys_to_mesh_dicts`` and ``polys_to_mesh_interior`` results are copied onto this object.
        """
        self.logger.info('Calculating overlap polygons between merge ugrid/mesh and priority ugrid/mesh.')
        self._g2_overlap_polys = self._calc_grid_2_polygons()
        if self._g2_overlap_polys is None:
            return  # nothing in grid_2 is being removed, so there is nothing to re-mesh

        operator = PolygonOperator(
            self._g1_boundary,
            self._g1_boundary_idxs,
            self._g1_boundary_rtree,
            self._ug2_locs,
            self._ug2_pt_classify,
            self._g2_overlap_polys,
            self._poly_buffer,
            self._pt_tol,
            self.logger,
        )
        operator.calc_overlap_polys()
        self._polys_to_mesh = operator.polys_to_mesh
        self._polys_to_mesh_dicts = operator.polys_to_mesh_dicts
        self._polys_to_mesh_interior = operator.polys_to_mesh_interior

    def _calc_grid_2_polygons(self):
        """Get the polygons for the grid_2 cells that will be removed.

        Builds polygons around the grid_2 cells listed in ``self._g2_remove_cells``. When a polygon
        has holes, the grid_2 cells inside each hole are also marked for removal:

        * starting from the hole ring, points classified ``PT_OUT_POLY`` are flood-filled through their
          adjacent points;
        * each such point is reclassified ``PT_IN_POLY``;
        * its adjacent cells are added to ``self._g2_remove_cells``.

        The hole area is then part of the removed region, so the hole rings are dropped from the
        returned polygon.

        Returns:
            (list | None): the polygons from grid_2. Each is a list of rings of grid_2 point indexes,
            and its first ring is the outer boundary. None if no polygon was found.
        """
        self.logger.info('Calculating merge ugrid/mesh polygons for transition to priority ugrid/mesh.')
        remove_cells = set(self._g2_remove_cells)
        ds_vals = [1 if i in remove_cells else 0 for i in range(self._ug2.cell_count)]
        poly_builder = GridCellToPolygonCoverageBuilder(co_grid=self._grid_2,
                                                        dataset_values=ds_vals,
                                                        wkt=None, coverage_name='temp')
        out_polys = poly_builder.find_polygons()
        if 1 not in out_polys:
            return None

        g2_remove_cells = set(self._g2_remove_cells)
        polys = []
        for cur_poly in out_polys[1]:
            if len(cur_poly) > 1:  # there are holes in this polygon. Mark those cells to be deleted
                # the outer ring is pre-visited so the fill cannot leave the polygon
                visited_pts = set(cur_poly[0])
                for hole in cur_poly[1:]:
                    # walk a copy: appending to the ring itself would corrupt the polygon's geometry
                    pts_to_visit = list(hole)
                    pts_idx = 0
                    while pts_idx < len(pts_to_visit):
                        idx = pts_to_visit[pts_idx]
                        pts_idx += 1

                        if idx in visited_pts:
                            continue
                        visited_pts.add(idx)
                        # change the point classification so this point will be deleted
                        if self._ug2_pt_classify[idx] != PT_OUT_POLY:
                            continue
                        self._ug2_pt_classify[idx] = PT_IN_POLY
                        g2_remove_cells.update(self._ug2.get_point_adjacent_cells(idx))
                        # add neighbor points to visit
                        adj_pts = set(self._ug2.get_point_adjacent_points(idx)) - visited_pts
                        pts_to_visit.extend(adj for adj in adj_pts if self._ug2_pt_classify[adj] == PT_OUT_POLY)
                # the hole cells are now removed too, so only the outer ring bounds the removed area
                cur_poly = [cur_poly[0]]
            polys.append(cur_poly)
        self._g2_remove_cells = list(g2_remove_cells)
        return polys

    def _merge_grid_1_and_grid_2(self):
        """Merges grid_1 and grid_2 removing cells from grid_2.

        All grid_1 points and cells are copied into the output first. Each grid_2 point is then mapped
        to an output index and the map is stored in ``self.ug2_old_to_new_pt_idx``:

        * a point matched to a grid_1 point reuses that grid_1 index;
        * a point classified ``PT_OUT_POLY`` is appended as a new output point;
        * any other point maps to -1.

        A grid_2 cell is renumbered and appended unless it is in ``self._g2_remove_cells`` or uses a
        point that maps to -1.
        """
        self.logger.info('Merging priority ugrid/mesh and merge ugrid/mesh.')
        self._out_locs = [(loc[0], loc[1], loc[2]) for loc in self._ug1.locations]
        g2_locs = self._ug2_locs
        g2_to_g1 = {g2_idx: g1_idx for g1_idx, g2_idx in self._match_pt_idx_g1_g2.items()}
        point_count = self._ug2.point_count
        point_map = [-1] * point_count
        self.ug2_old_to_new_pt_idx = point_map
        cells_to_skip = set(self._g2_remove_cells)

        next_idx = len(self._out_locs)
        for g2_idx in range(point_count):
            if (g2_idx + 1) % self.progress_interval == 0:
                self.logger.info('Merging grids - 500,000 points processed.')
            if g2_idx in g2_to_g1:
                point_map[g2_idx] = g2_to_g1[g2_idx]
                continue
            if self._ug2_pt_classify[g2_idx] != PT_OUT_POLY:
                continue
            point_map[g2_idx] = next_idx
            next_idx += 1
            loc = g2_locs[g2_idx]
            self._out_locs.append((loc[0], loc[1], loc[2]))

        self._out_cellstream = list(self._ug1.cellstream)
        stream = self._ug2.cellstream
        pos = 0
        cell_idx = 0
        while pos < len(stream):
            if (cell_idx + 1) % self.progress_interval == 0:
                self.logger.info('Merging grids - 500,000 cells processed.')
            # each cell is stored as [cell_type, num_pts, pt_0, ..., pt_{num_pts - 1}]
            cell_type, num_pts = stream[pos], stream[pos + 1]
            first = pos + 2
            pos = first + num_pts
            mapped = [point_map[idx] for idx in stream[first:pos]]
            keep = cell_idx not in cells_to_skip and -1 not in mapped
            if keep:
                self._out_cellstream.extend([cell_type, num_pts] + mapped)
            cell_idx += 1

    def _merge_overlap_polys(self):
        """Mesh the overlap polygons and merge them into the output ugrid.

        First drops overlapping transition polygons. Each remaining polygon is then meshed:

        * its interior ring, when non-empty, is used as a hole;
        * elevations are interpolated linearly from the polygon's own vertices.

        The resulting ugrid is stitched into the output using the polygon's grid_1 and grid_2 point
        dictionaries.
        """
        self.logger.info('Merging overlap areas between input ugrids/meshes.')
        self._remove_intersecting_overlap_polys()
        for poly_idx, poly in enumerate(self._polys_to_mesh):
            outside = list(poly)
            hole = self._polys_to_mesh_interior[poly_idx]
            grid_dicts = self._polys_to_mesh_dicts[poly_idx]
            elevations = InterpLinear(points=outside)
            inside = [hole] if len(hole) > 0 else None
            poly_input = meshing.PolyInput(outside_polygon=outside, inside_polygons=inside,
                                           elev_function=elevations, bias=1.0)
            ugrid = generate_mesh(polygon_inputs=[poly_input])
            self._merge_poly_grid_with_output_grid(ugrid, grid_dicts[0], grid_dicts[1], outside)

    def _merge_poly_grid_with_output_grid(self, ugrid, grid_1_dict, grid_2_dict, poly_pts):
        """Merge the grid that came from an overlap polygon with the output grid.

        Args:
            ugrid (xms.grid.UGrid): the ugrid
            grid_1_dict (dict): (x, y) as key, index as value
            grid_2_dict (dict): (x, y) as key, index as value
            poly_pts (list): locations of the outside polygon that was meshed to make the input ugrid
        """
        self.can_merge_cell_datasets = False
        locs = ugrid.locations
        pt_idxs = [-1] * len(locs)
        # determine the point index for each point in the output grid
        poly_set = {(p[0], p[1]) for p in poly_pts}
        new_idx = len(self._out_locs)
        for i, pt in enumerate(locs):
            p = (pt[0], pt[1])
            if p in grid_1_dict:
                idx = grid_1_dict[p]
            elif p in grid_2_dict:
                idx = self.ug2_old_to_new_pt_idx[grid_2_dict[p]]
            elif p in poly_set:
                idx = -1
            else:
                idx = new_idx
                new_idx += 1
                self._out_locs.append((pt[0], pt[1], pt[2]))
            pt_idxs[i] = idx

        cellstream = ugrid.cellstream
        cell_idx = -1
        cnt = 0
        while cnt < len(cellstream):
            cell_idx += 1
            cell_type = cellstream[cnt]
            num_pts = cellstream[cnt + 1]
            start = cnt + 2
            end = start + num_pts
            cell_pts = cellstream[start:end]
            cnt = end
            old_to_new = [pt_idxs[idx] for idx in cell_pts]

            # skip this cell it is connected to a boundary point that is not part of grid1 or grid 2
            if -1 in old_to_new:
                continue

            cs = [cell_type, num_pts] + old_to_new
            self._out_cellstream.extend(cs)

    def _remove_intersecting_overlap_polys(self):
        """Remove any overlap polygons that intersect and log warning information.

        For each pair of transition polygons that overlap (and of which neither is already removed),
        one polygon is removed:

        * the comparison uses each polygon's smallest strictly positive vertex distance to the grid_1
          boundary polygon;
        * the polygon with the larger distance is removed;
        * on a tie, the first polygon of the pair is removed.

        ``self._polys_to_mesh_dicts`` and ``self._polys_to_mesh_interior`` are indexed in parallel
        with ``self._polys_to_mesh``. They are trimmed the same way so later index-based lookups stay
        aligned. A warning is logged for each removed polygon.
        """
        self.logger.info('Checking transition area polygons for intersections.')
        check_polys = [Polygon(poly) for poly in self._polys_to_mesh]
        poly_intersect = [
            (i, j)
            for i in range(len(check_polys))
            for j in range(i + 1, len(check_polys))
            if check_polys[i].overlaps(check_polys[j])
        ]

        g1_poly = Polygon(self._g1_boundary)

        def min_vertex_distance(poly):
            """Smallest strictly positive distance from a vertex of ``poly`` to the grid_1 polygon."""
            min_dist = sys.float_info.max
            for coord in poly.exterior.coords:
                dist = g1_poly.distance(Point((coord[0], coord[1])))
                if 0.0 < dist < min_dist:
                    min_dist = dist
            return min_dist

        remove_polys = set()
        for idx1, idx2 in poly_intersect:
            if idx1 in remove_polys or idx2 in remove_polys:
                continue
            # keep the polygon that comes closer to the grid_1 boundary
            if min_vertex_distance(check_polys[idx1]) < min_vertex_distance(check_polys[idx2]):
                remove_polys.add(idx2)
            else:
                remove_polys.add(idx1)

        # pop from the highest index down so the remaining indexes stay valid, and keep the
        # parallel per-polygon lists in sync with _polys_to_mesh
        for idx in sorted(remove_polys, reverse=True):
            self._polys_to_mesh.pop(idx)
            self._polys_to_mesh_dicts.pop(idx)
            self._polys_to_mesh_interior.pop(idx)
            msg = f'Unable to automatically re-mesh this polygon: {check_polys[idx]}. Index: {idx}.'
            self.logger.warning(msg)

    def _remove_detached_g2_points(self):
        """Remove any points from grid_2 that are not attached to cells."""
        self.logger.info('Checking disjoint points from merge ugrid/mesh.')
        remove_set = set(self._g2_remove_cells)

        check_pts = [i for i, idx in enumerate(self.ug2_old_to_new_pt_idx) if idx != -1]
        pt_to_remove = set()
        for i in check_pts:
            ug2_cells = set(self._ug2.get_point_adjacent_cells(i))
            if len(ug2_cells) > 0 and len(ug2_cells - remove_set) < 1:
                pt_to_remove.add(self.ug2_old_to_new_pt_idx[i])
        if len(pt_to_remove) < 1:
            return

        # make sure these points are not in the output cellstream
        cellstream = self._out_cellstream
        cell_idx = -1
        cnt = 0
        cell_pts = set()
        while cnt < len(cellstream):
            cell_idx += 1
            if (cell_idx + 1) % self.progress_interval == 0:
                self.logger.info('Disjoint points - 500,000 cells processed.')
            # cell_type = cellstream[cnt]
            num_pts = cellstream[cnt + 1]
            start = cnt + 2
            end = start + num_pts
            cnt = end
            cell_pts.update(cellstream[start:end])
        pt_to_remove = pt_to_remove - cell_pts
        if len(pt_to_remove) < 1:
            return

        new_locs = []
        out_old_to_new_pt_idx = [-1] * len(self._out_locs)
        for i in range(len(out_old_to_new_pt_idx)):
            if (i + 1) % self.progress_interval == 0:
                self.logger.info('Disjoint points - 500,000 points processed.')
            if i not in pt_to_remove:
                out_old_to_new_pt_idx[i] = len(new_locs)
                new_locs.append(self._out_locs[i])
        self._out_locs = new_locs

        new_cellstream = []
        cellstream = self._out_cellstream
        cell_idx = -1
        cnt = 0
        while cnt < len(cellstream):
            cell_idx += 1
            if (cell_idx + 1) % self.progress_interval == 0:
                self.logger.info('Disjoint points - 500,000 cells processed.')
            cell_type = cellstream[cnt]
            num_pts = cellstream[cnt + 1]
            start = cnt + 2
            end = start + num_pts
            cell_pts = cellstream[start:end]
            cnt = end
            new_cellstream.extend([cell_type, num_pts] + [out_old_to_new_pt_idx[idx] for idx in cell_pts])
        self._out_cellstream = new_cellstream
