"""RelaxUGridPoints class."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
import time

# 2. Third party modules
import numba
from numba import jit, njit
import numpy as np
import pandas as pd

# 3. Aquaveo modules
from xms.constraint.ugrid_builder import UGridBuilder
from xms.grid.triangulate.tin import Tin
from xms.grid.ugrid.ugrid import UGrid
from xms.interp.interpolate.interp_linear import InterpLinear

# 4. Local modules
from xms.tool.algorithms.ugrids.cell_2d_quality_metrics import triangle_quality

# Signed triangle areas below this value are treated as degenerate/inverted: such triangles always use the
# barycenter as their center and make a proposed set of relaxed point locations invalid.
AREA_TOL = 3.0e-16


@jit(nopython=True)
def compute_triangle_quality(idx, pts, tris):
    """Evaluates the quality measures of a single triangle.

    Args:
        idx (int): Index of the triangle (the triangle number, not an offset into ``tris``).
        pts (ndarray): The point locations.
        tris (ndarray): Flat triangle connectivity, three point indices per triangle.

    Returns:
        (q_alpha_min, q_Ll, q_ALS, q_Rr, q_Lr, q_Lh): The first six values returned by ``triangle_quality``.
    """
    first = 3 * idx  # offset of this triangle's first corner in the flat connectivity
    corner_a = pts[tris[first]]
    corner_b = pts[tris[first + 1]]
    corner_c = pts[tris[first + 2]]
    measures = triangle_quality(corner_a, corner_b, corner_c)
    return measures[:6]


@njit(parallel=True)
def compute_triangle_quality_parallel(pts, tris):
    """Computes the quality measures of every triangle.

    Args:
        pts (ndarray): The point locations.
        tris (ndarray): Flat triangle connectivity, three point indices per triangle.

    Returns:
        (ndarray): Array with one row per triangle and six columns holding the values returned by
        ``compute_triangle_quality``.
    """
    num_tris = tris.shape[0] // 3
    quality = np.zeros((num_tris, 6), dtype=np.float64)
    # each row is written by exactly one iteration, so the loop is safe to parallelize
    for tri in numba.prange(num_tris):
        quality[tri] = compute_triangle_quality(tri, pts, tris)
    return quality


@jit(nopython=True)
def circum_circle_center(pt1, pt2, pt3):
    """Computes the center of the circle passing through the three corners of a triangle.

    The center is the intersection of the perpendicular bisectors of edges pt1-pt2 and pt3-pt2. Each bisector
    is written as ``a * x + b * y = c`` and the resulting 2x2 system is solved with Cramer's rule. The
    determinant is not checked, so the three points must not be collinear.

    Args:
        pt1 (ndarray): The first point.
        pt2 (ndarray): The second point.
        pt3 (ndarray): The third point.

    Returns:
        (x, y): The circum circle center.
    """
    x1, y1 = pt1[0], pt1[1]
    x2, y2 = pt2[0], pt2[1]
    x3, y3 = pt3[0], pt3[1]

    # perpendicular bisector of pt1-pt2
    a1 = x1 - x2
    b1 = y1 - y2
    c1 = a1 * (x1 + x2) / 2.0 + b1 * (y1 + y2) / 2.0

    # perpendicular bisector of pt3-pt2
    a2 = x3 - x2
    b2 = y3 - y2
    c2 = a2 * (x3 + x2) / 2.0 + b2 * (y3 + y2) / 2.0

    denom = a1 * b2 - b1 * a2
    return (c1 * b2 - c2 * b1) / denom, (a1 * c2 - a2 * c1) / denom


@jit(nopython=True)
def compute_centroid_and_area(tri_idx, pts, tris, locked_pts, lloyd=False):
    """Computes the center and the signed area of one triangle.

    The center is the circum circle center only when ``lloyd`` is True, the signed area is not below
    ``AREA_TOL`` and none of the three corners is locked. In every other case it is the barycenter.

    Args:
        tri_idx (int): The triangle index.
        pts (ndarray): The points.
        tris (ndarray): The triangles (flat array, three point indices per triangle).
        locked_pts (ndarray): The points that are locked (non-zero means locked).
        lloyd (bool): If True, use Lloyd's method for circum centers.

    Returns:
        (x, y, area): The triangle center and its signed area (positive for counter-clockwise corners).
    """
    first = tri_idx * 3
    idx0, idx1, idx2 = tris[first], tris[first + 1], tris[first + 2]
    p0, p1, p2 = pts[idx0], pts[idx1], pts[idx2]
    x1, y1 = p0[0], p0[1]
    x2, y2 = p1[0], p1[1]
    x3, y3 = p2[0], p2[1]

    # shoelace formula without the absolute value, so the sign encodes the corner orientation
    area = 0.5 * (x1 * (y2 - y3) + x2 * (y3 - y1) + x3 * (y1 - y2))

    # NOTE(review): an earlier comment here said the barycenter is used when 2 corners are locked, but the
    # code has always switched to the barycenter as soon as any corner is locked - confirm the intent.
    any_locked = locked_pts[idx0] != 0 or locked_pts[idx1] != 0 or locked_pts[idx2] != 0
    if lloyd and not (area < AREA_TOL) and not any_locked:
        center_x, center_y = circum_circle_center(p0, p1, p2)
        return center_x, center_y, area
    return (x1 + x2 + x3) / 3.0, (y1 + y2 + y3) / 3.0, area


@njit(parallel=True)
def compute_centroids_and_area_parallel(pts, tris, locked_pts, lloyd=False):
    """Computes the center and signed area of every triangle.

    Args:
        pts (ndarray): The points.
        tris (ndarray): The triangles (flat array, three point indices per triangle).
        locked_pts (ndarray): The points that are locked.
        lloyd (bool): If True, use Lloyd's method for circum centers.

    Returns:
        (ndarray): One row per triangle holding (center x, center y, signed area).
    """
    num_tris = tris.shape[0] // 3
    center_area = np.zeros((num_tris, 3), dtype=np.float64)
    # rows are independent of each other, one triangle per iteration
    for tri in numba.prange(num_tris):
        center_area[tri] = compute_centroid_and_area(tri, pts, tris, locked_pts, lloyd)
    return center_area


@jit(nopython=True)
def area_relax_new_loc(pt_idx, pts, adj_tri_lookup, adj_tris, tri_center_area):
    """Computes the new location of a point using the area method.

    The new x, y is the area-weighted average of the centers of the triangles connected to the point. When
    the summed area is not positive the point keeps its current x, y. The z value is never changed.

    Args:
        pt_idx (int): The point index in the pts array.
        pts (ndarray): The points.
        adj_tri_lookup (ndarray): Look up array (start, count) for the triangles connected to a point.
        adj_tris (ndarray): The triangles connected to a point.
        tri_center_area (ndarray): The centroids and area of the triangles.

    Returns:
        (ndarray): new location (x, y, z)
    """
    first, count = adj_tri_lookup[pt_idx]
    weighted_x = 0.0
    weighted_y = 0.0
    total_area = 0.0
    for k in range(first, first + count):
        tri = adj_tris[k]
        tri_area = tri_center_area[tri][2]
        weighted_x += tri_area * tri_center_area[tri][0]
        weighted_y += tri_area * tri_center_area[tri][1]
        total_area += tri_area

    relaxed = np.zeros(3, dtype=np.float64)
    relaxed[2] = pts[pt_idx][2]
    if total_area > 0.0:
        relaxed[0] = weighted_x / total_area
        relaxed[1] = weighted_y / total_area
    else:
        # nothing usable to average, leave the point where it is
        relaxed[0] = pts[pt_idx][0]
        relaxed[1] = pts[pt_idx][1]
    return relaxed


@njit(parallel=True)
def area_relax_parallel(locked_pts, pts, adj_tri_lookup, adj_tris, tri_center_area):
    """Computes the new location of every point using the area method.

    Args:
        locked_pts (ndarray): The locked nodes (a value of 1 means the point is copied unchanged).
        pts (ndarray): The points.
        adj_tri_lookup (ndarray): Look up array for the triangles connected to a point.
        adj_tris (ndarray): The triangles connected to a point.
        tri_center_area (ndarray): The centroids and area of the triangles.

    Returns:
        (ndarray): new locations
    """
    num_pts = pts.shape[0]
    relaxed = np.zeros(pts.shape, dtype=np.float64)
    for pt in numba.prange(num_pts):
        if locked_pts[pt] == 1:
            relaxed[pt] = pts[pt]
        else:
            relaxed[pt] = area_relax_new_loc(pt, pts, adj_tri_lookup, adj_tris, tri_center_area)
    return relaxed


@jit(nopython=True)
def angle_relax_new_loc(pt_idx, pts, tris, adj_tri_lookup, adj_tris):
    """Computes the new location of a point using the angle method.

    With ``n`` triangles around the point the target angle at the point is ``2 * pi / n``. For every adjacent
    triangle the opposite edge p1-p2 defines a circle from which that edge is seen under the target angle. The
    point is projected radially onto that circle (on its own side of the edge) and the projections of all
    adjacent triangles are averaged. The z value is never changed.

    A point without adjacent triangles keeps its location. Triangles whose opposite edge has zero length, or
    for which the point sits exactly on the circle center, are skipped because no direction can be computed.

    Args:
        pt_idx (int): Index of the point
        pts (ndarray): The points.
        tris (ndarray): The triangles (flat array, three point indices per triangle).
        adj_tri_lookup (ndarray): Look up array (start, count) for the triangles connected to a point.
        adj_tris (ndarray): The triangles connected to a point.

    Returns:
        (ndarray): new location
    """
    new_loc = np.zeros(3, dtype=np.float64)
    new_loc[0], new_loc[1], new_loc[2] = pts[pt_idx][0], pts[pt_idx][1], pts[pt_idx][2]
    start, num_adj = adj_tri_lookup[pt_idx]
    if num_adj < 1:  # no triangles around the point: the target angle is undefined
        return new_loc

    target_angle = 2.0 * np.pi / num_adj
    # loop invariants of the isosceles construction
    sin_half = np.sin(target_angle / 2.0)
    cos_half = np.cos(target_angle / 2.0)

    sum_x = 0.0
    sum_y = 0.0
    num_used = 0
    for i in range(num_adj):
        tri_idx = adj_tris[start + i] * 3
        tri_pt_idx = tris[tri_idx:tri_idx + 3]
        # rotate the corners so idx0 is the point being moved and idx1-idx2 is the opposite edge
        if tri_pt_idx[0] == pt_idx:
            idx0, idx1, idx2 = tri_pt_idx[0], tri_pt_idx[1], tri_pt_idx[2]
        elif tri_pt_idx[1] == pt_idx:
            idx0, idx1, idx2 = tri_pt_idx[1], tri_pt_idx[2], tri_pt_idx[0]
        else:
            idx0, idx1, idx2 = tri_pt_idx[2], tri_pt_idx[0], tri_pt_idx[1]
        # set up points
        pt_x, pt_y = pts[idx0][0], pts[idx0][1]
        p1_x, p1_y = pts[idx1][0], pts[idx1][1]
        p2_x, p2_y = pts[idx2][0], pts[idx2][1]
        # length and midpoint of the opposite edge
        dx = p1_x - p2_x
        dy = p1_y - p2_y
        c = np.sqrt(dx * dx + dy * dy)
        if c == 0.0:  # degenerate edge, no perpendicular direction
            continue
        # the construction is built on the true edge midpoint; p1 + (p1 - p2) / 2 lies outside the edge
        sx = 0.5 * (p1_x + p2_x)
        sy = 0.5 * (p1_y + p2_y)
        # isosceles triangle over the edge with the target angle at its apex:
        # a = side length, d = height, radius = circum radius
        a = (c / 2.0) / sin_half
        d = a * cos_half
        radius = (a * a) / (2.0 * d)
        # unit vector along the edge; (dy, -dx) and (-dy, dx) are the two unit normals
        dx /= c
        dy /= c
        # the apex candidates on both sides of the edge, keep the side the point is on
        dx3, dy3 = sx + d * dy - pt_x, sy - d * dx - pt_y
        da = dx3 * dx3 + dy3 * dy3
        dx3, dy3 = sx - d * dy - pt_x, sy + d * dx - pt_y
        db = dx3 * dx3 + dy3 * dy3
        if da < db:
            centroid_x = sx + (d - radius) * dy
            centroid_y = sy - (d - radius) * dx
        else:
            centroid_x = sx - (d - radius) * dy
            centroid_y = sy + (d - radius) * dx
        # project the point radially onto the circle
        dx = pt_x - centroid_x
        dy = pt_y - centroid_y
        mag = np.sqrt(dx * dx + dy * dy)
        if mag == 0.0:  # point is on the circle center, projection direction is undefined
            continue
        sum_x += centroid_x + radius * dx / mag
        sum_y += centroid_y + radius * dy / mag
        num_used += 1

    if num_used > 0:
        new_loc[0] = sum_x / num_used
        new_loc[1] = sum_y / num_used
    return new_loc


@njit(parallel=True)
def angle_relax_parallel(locked_pts, pts, tris, adj_tri_lookup, adj_tris):
    """Computes the new location of every point using the angle method.

    Args:
        locked_pts (ndarray): The locked nodes (a value of 1 means the point is copied unchanged).
        pts (ndarray): The points.
        tris (ndarray): The triangles.
        adj_tri_lookup (ndarray): Look up array for the triangles connected to a point.
        adj_tris (ndarray): The triangles connected to a point.

    Returns:
        (ndarray): new locations
    """
    num_pts = pts.shape[0]
    relaxed = np.zeros(pts.shape, dtype=np.float64)
    for pt in numba.prange(num_pts):
        if locked_pts[pt] == 1:
            relaxed[pt] = pts[pt]
        else:
            relaxed[pt] = angle_relax_new_loc(pt, pts, tris, adj_tri_lookup, adj_tris)
    return relaxed


@jit(nopython=True)
def pts_adjacent_pt(pt_idx, adj_tri_lookup, adj_tris, tris):
    """Collects the corner points of all triangles connected to a point.

    The result is not filtered: it contains ``pt_idx`` itself once per triangle and shared neighbors appear
    more than once.

    Args:
        pt_idx (int): The point index in the pts array.
        adj_tri_lookup (ndarray): Look up array (start, count) for the triangles connected to a point.
        adj_tris (ndarray): The triangles connected to a point.
        tris (ndarray): The triangles (flat array, three point indices per triangle).

    Returns:
        (ndarray): The corner point indices, three per connected triangle.
    """
    first, count = adj_tri_lookup[pt_idx]
    corners = np.zeros(count * 3, dtype=np.int64)
    for n in range(count):
        src = adj_tris[first + n] * 3
        dst = n * 3
        corners[dst] = tris[src]
        corners[dst + 1] = tris[src + 1]
        corners[dst + 2] = tris[src + 2]
    return corners


@jit(nopython=True)
def spring_relax_new_loc(pt_idx, pts, adj_pt_lookup, adj_pts, pt_sizes):
    """Computes the new point location using the spring method.

    Every connected point acts as a spring whose rest length is the average of the target sizes at both ends.
    The spring forces are summed, scaled by ``1.7 / number of connected points`` and added to the x, y of the
    point. The z value is never changed.

    A point without connected points keeps its location, and a connected point at exactly the same x, y is
    skipped because the direction of its spring is undefined.

    Args:
        pt_idx (int): The point index in the pts array.
        pts (ndarray): The points.
        adj_pt_lookup (ndarray): Look up array (start, count) for the points connected to a point.
        adj_pts (ndarray): The pts connected to a point.
        pt_sizes (ndarray): The target triangle edge size at a point

    Returns:
        (ndarray): new location
    """
    p0 = pts[pt_idx]
    new_loc = np.zeros(3, dtype=np.float64)
    new_loc[0], new_loc[1], new_loc[2] = p0[0], p0[1], p0[2]
    start, num_adj = adj_pt_lookup[pt_idx]
    if num_adj < 1:  # isolated point: no springs, and 1.7 / num_adj would divide by zero
        return new_loc

    spring_stiffness = 1.0
    num_pts_factor = 1.7 / num_adj
    size0 = pt_sizes[pt_idx]
    fx = 0.0
    fy = 0.0
    for i in range(num_adj):
        adj_pt_idx = adj_pts[start + i]
        p1 = pts[adj_pt_idx]
        target_dist = 0.5 * (size0 + pt_sizes[adj_pt_idx])
        dx = p0[0] - p1[0]
        dy = p0[1] - p1[1]
        dist = np.sqrt(dx * dx + dy * dy)
        if dist == 0.0:  # coincident neighbor, spring direction is undefined
            continue
        # positive force pushes the point away from the neighbor (spring is compressed)
        force = spring_stiffness * (target_dist - dist)
        fx += force * dx / dist
        fy += force * dy / dist

    new_loc[0] = p0[0] + fx * num_pts_factor
    new_loc[1] = p0[1] + fy * num_pts_factor
    return new_loc


@njit(parallel=True)
def spring_relax_parallel(locked_pts, pts, adj_pts_lookup, adj_pts, pt_sizes):
    """Computes the new location of every point using the spring method.

    Args:
        locked_pts (ndarray): The locked nodes (a value of 1 means the point is copied unchanged).
        pts (ndarray): The points.
        adj_pts_lookup (ndarray): Look up array for the points connected to a point.
        adj_pts (ndarray): The pts connected to a point.
        pt_sizes (ndarray): The target triangle edge size at a point

    Returns:
        (ndarray): new locations
    """
    num_pts = pts.shape[0]
    relaxed = np.zeros(pts.shape, dtype=np.float64)
    for pt in numba.prange(num_pts):
        if locked_pts[pt] == 1:
            relaxed[pt] = pts[pt]
        else:
            relaxed[pt] = spring_relax_new_loc(pt, pts, adj_pts_lookup, adj_pts, pt_sizes)
    return relaxed


@jit(nopython=True)
def check_relaxed_new_pts_parallel(locked_pts, pts, new_pts, tris, converge_d2):
    """Checks if new points are in valid locations and computes the max squared distance moved.

    While any triangle built from the new locations has a signed area below ``AREA_TOL``, every unlocked point
    is moved half way back toward its original location. The search gives up and returns a copy of the original
    points with a squared distance of 0.0 when the remaining movement drops below ``converge_d2``, or when
    halving no longer changes any coordinate (the triangles can not be made valid by backing off).

    Args:
        locked_pts (ndarray): The locked nodes.
        pts (ndarray): The points.
        new_pts (ndarray): The new locations.
        tris (ndarray): The triangles.
        converge_d2 (float): Distance squared to converge.

    Returns:
        (new_tin_pts, dist2): new locations, max squared distance moved by an unlocked point
    """
    new_tin_pts = np.copy(new_pts)
    # compute the squared distance moved for all the unlocked points
    dist2 = np.zeros(pts.shape[0], dtype=np.float64)
    for i in numba.prange(pts.shape[0]):
        if locked_pts[i] == 1:
            continue
        dx = pts[i][0] - new_tin_pts[i][0]
        dy = pts[i][1] - new_tin_pts[i][1]
        dist2[i] = dx * dx + dy * dy
    max_d2 = 0.0
    if dist2.shape[0] > 0:  # max() of an empty array raises
        max_d2 = dist2.max()

    # make sure the new point locations do not create invalid triangles
    # if bad triangles are created then we iteratively move the points 1/2 way back to the original location
    while True:
        centroids = compute_centroids_and_area_parallel(new_tin_pts, tris, locked_pts)
        areas = centroids[:, 2]
        if not np.any(areas < AREA_TOL):  # if all triangles have valid area then we are done
            break

        # check if the max distance moved is less than the convergence distance
        if max_d2 < converge_d2:
            new_tin_pts = np.copy(pts)
            max_d2 = 0.0
            break

        # halving every displacement quarters the squared distance
        max_d2 *= 0.25
        num_moved = 0
        for i in numba.prange(new_tin_pts.shape[0]):  # move the points 1/2 way back to the original location
            if locked_pts[i] == 1:
                continue
            half_x = 0.5 * (new_tin_pts[i][0] + pts[i][0])
            half_y = 0.5 * (new_tin_pts[i][1] + pts[i][1])
            if half_x != new_tin_pts[i][0] or half_y != new_tin_pts[i][1]:
                num_moved += 1
            new_tin_pts[i][0] = half_x
            new_tin_pts[i][1] = half_y

        # nothing changed, so the next pass would test the very same invalid triangles again; without this
        # exit the loop never ends when converge_d2 is 0.0
        if num_moved == 0:
            new_tin_pts = np.copy(pts)
            max_d2 = 0.0
            break

    return new_tin_pts, max_d2


class RelaxUGridPoints:
    """Relaxes ugrid points to create higher quality cells.

    Only points of triangle cells are moved. The triangles are extracted into a TIN, the TIN boundary points
    and any user locked points stay fixed, and the remaining points are moved iteratively with the selected
    method. Cells that do not have three nodes are copied to the output grid unchanged. The result is stored
    in ``out_co_grid`` after ``relax()`` has been called.
    """

    def __init__(self, logger, co_grid, relax_type, num_iterations, converge_dist, factor, locked_pts, pt_sizes=None,
                 optimize_triangles=False):
        """Initializes the class.

        Args:
            logger: The logger that outputs to the user
            co_grid (): the tin
            relax_type (str): 'Area', 'Angle', 'Spring', 'Lloyd'
            num_iterations (int): Number of relaxation iterations to perform
            converge_dist (float): Stop iterating if the max distance a point moves is less than this value
            factor (float): Factor used if we are in lat lon coordinates. ``converge_dist`` is multiplied by
                it and reported distances are divided by it.
            locked_pts (np.array): The points that should not be moved. May be None. The array is copied, the
                caller's array is never modified.
            pt_sizes (np.array): The size function value at each point
            optimize_triangles (bool): If True, optimize the triangles before relaxing the points.
        """
        self.logger = logger
        self.out_co_grid = None
        self._co_grid = co_grid
        self._tin = None
        self._relax_type = relax_type
        self._num_iter = num_iterations
        self._factor = 1.0 / factor  # converts internal distances back to the units of converge_dist
        dist = converge_dist * factor
        self._converge_dist = dist
        self._converge_dist2 = dist * dist
        self._locked_points = locked_pts
        self._pt_sizes = pt_sizes
        self._optimize_triangles = optimize_triangles
        self._lloyd = relax_type == 'Lloyd'

        self._pts = None
        self._tris = None
        self._adj_tri_lookup = None
        self._adj_tris = None
        self._centroids = None
        self._adj_pt_lookup = None
        self._adj_pts = None
        self._new_pts = None
        self._max_dist = None
        self._converged = False
        self._interp = None  # size function interpolator, only created for the 'Spring' method
        self._new_cell_stream = []
        self._tri_type = None
        self._stats = {
            'Interation': [0],
            'Maximum distance moved': [0.0],
            'Quality Alpha Min': [],
            'Quality L1': [],
            'Quality ALS': [],
            'Quality Rr': [],
            'Quality Lr': [],
            'Quality Lh': []
        }

    def relax(self):
        """Relaxes the points in the grid and stores the new grid in ``out_co_grid``."""
        self._build_tin()
        self._triangle_quality()
        self._calc_locked_pts()
        if self._relax_type == 'Spring':
            self._calc_adjacent_points()

        for i in range(self._num_iter):
            self._update_for_iteration(i)
            self.logger.info(f'Iteration {i + 1}')
            self._stats['Interation'].append(i + 1)
            time.sleep(0.1)  # allow gui to show the next iteration
            self._compute_relaxed_points()
            self._check_relaxed_new_points()
            self._update_points()
            if self._converged:
                break

        self._triangle_quality(final=True)
        self._create_output_grid()

    def _build_tin(self):
        """Builds a TIN from the triangle cells of the input co_grid.

        The cell stream is read as consecutive records of ``cell type, number of nodes, node indices...``.
        Cells with three nodes become TIN triangles, every other cell is copied to the output cell stream.
        """
        self.logger.info('Creating TIN from UGrid...')
        ug = self._co_grid.ugrid
        self._tri_type = ug.cell_type_enum.TRIANGLE
        self._pts = ug.locations
        cs = ug.cellstream
        tris = []
        pos = 0
        stream_len = len(cs)
        while pos < stream_len:
            num_nodes = cs[pos + 1]  # cs[pos] is the cell type, it is not needed here
            first_node = pos + 2
            next_cell = first_node + num_nodes
            if num_nodes == 3:
                tris.extend(cs[first_node:next_cell])
            else:
                self._new_cell_stream.extend(cs[pos:next_cell])
            pos = next_cell

        self._tris = np.asarray(tris)
        self._tin = Tin(self._pts, tris)
        self._tin.build_triangles_adjacent_to_points()
        self._build_adjacent_tris()
        if self._relax_type == 'Spring':
            # built once from the original geometry so sizes can be re-sampled as the points move
            self._interp = InterpLinear(self._pts, self._tris, self._pt_sizes)

    @staticmethod
    def _flatten_adjacency(groups):
        """Packs a list of per-point index collections into two flat numpy arrays.

        Args:
            groups (list): One collection of indices for each point.

        Returns:
            (lookup, flat): ``lookup[i]`` is (start, count) of the indices of point i inside ``flat``.
        """
        lookup = []
        offset = 0
        for group in groups:
            count = len(group)
            lookup.append((offset, count))
            offset += count
        flat = [item for group in groups for item in group]
        return np.asarray(lookup), np.asarray(flat)

    def _build_adjacent_tris(self):
        """Builds the adjacent triangles to each point as numpy arrays."""
        self._adj_tri_lookup, self._adj_tris = self._flatten_adjacency(self._tin.triangles_adjacent_to_points)

    def _calc_locked_pts(self):
        """Calculates the locked nodes: the user locked points plus all TIN boundary points."""
        if self._locked_points is None:
            locked = np.zeros(self._pts.shape[0], dtype=np.int8)
        else:
            locked = np.array(self._locked_points)  # copy so the caller's array is not modified

        # an explicit integer dtype keeps an empty boundary list usable as an index array
        bp = np.asarray(self._tin.boundary_points, dtype=np.int64)
        locked[bp] = 1
        self._locked_points = locked

    def _compute_centroids_and_area(self):
        """Compute the centroids and areas of the triangles."""
        self._centroids = compute_centroids_and_area_parallel(self._pts, self._tris, self._locked_points, self._lloyd)

    def _calc_adjacent_points(self):
        """Compute the adjacent points to each point."""
        adj_pts = []
        for idx in range(len(self._pts)):
            corners = pts_adjacent_pt(idx, self._adj_tri_lookup, self._adj_tris, self._tris)
            adj_pts.append(set(corners) - {idx})  # unique neighbors without the point itself
        self._adj_pt_lookup, self._adj_pts = self._flatten_adjacency(adj_pts)

    def _compute_relaxed_points(self):
        """Compute the new locations of the relaxed points.

        Raises:
            ValueError: If the relax type is not one of 'Area', 'Lloyd', 'Angle' or 'Spring'.
        """
        if self._relax_type in ['Area', 'Lloyd']:
            self._new_pts = area_relax_parallel(self._locked_points, self._pts, self._adj_tri_lookup,
                                                self._adj_tris, self._centroids)
        elif self._relax_type == 'Angle':
            self._new_pts = angle_relax_parallel(self._locked_points, self._pts, self._tris, self._adj_tri_lookup,
                                                 self._adj_tris)
        elif self._relax_type == 'Spring':
            self._new_pts = spring_relax_parallel(self._locked_points, self._pts, self._adj_pt_lookup,
                                                  self._adj_pts, self._pt_sizes)
        else:
            # fail here with a clear message instead of passing None on to the compiled check
            raise ValueError(f'Unknown relax type: {self._relax_type!r}')

    def _check_relaxed_new_points(self):
        """Check the new relaxed point locations to make sure they don't create bad triangles."""
        pts, d2 = check_relaxed_new_pts_parallel(self._locked_points, self._pts, self._new_pts, self._tris,
                                                 self._converge_dist2)
        self._new_pts = pts
        self._max_dist = np.sqrt(d2)
        if d2 == 0.0:
            self.logger.info('Early termination, unable to move pts without creating invalid triangles.')
            # nothing moved, so another iteration would repeat the same computation; stop even when the
            # convergence distance is 0.0
            self._converged = True
        else:
            self.logger.info(f'Maximum distance moved: {self._max_dist * self._factor}')

    def _update_points(self):
        """Update the TIN points and flag convergence when the movement is below the convergence distance."""
        self._tin.points = self._new_pts
        self._pts = self._new_pts
        if self._max_dist < self._converge_dist:
            self.logger.info('Early termination, maximum distance moved is less than convergence distance')
            self._converged = True

    def _update_for_iteration(self, iter):
        """Update the TIN and other variables for the next iteration.

        Args:
            iter (int): The iteration number.
        """
        if self._optimize_triangles and self._tin.optimize_triangulation():
            # the triangulation changed, so everything derived from the connectivity must be rebuilt
            self._tris = np.asarray(self._tin.triangles)
            self._tin.build_triangles_adjacent_to_points()
            self._build_adjacent_tris()
            if self._relax_type == 'Spring':
                self._calc_adjacent_points()

        if self._relax_type in ['Area', 'Lloyd']:
            self._compute_centroids_and_area()
        elif self._relax_type == 'Spring' and iter > 0:
            # the points moved in the previous iteration, re-sample the target sizes at the new locations
            self._update_sizes()

    def _triangle_quality(self, final=False):
        """Compute and report the triangle quality.

        Every call appends (min, median, mean, std) of each quality measure to the statistics. The final call
        also logs a table comparing the first and the last statistics of each measure.

        Args:
            final (bool): True if this is the final quality report.
        """
        quality = compute_triangle_quality_parallel(self._pts, self._tris)
        q_strs = ['Quality Alpha Min', 'Quality L1', 'Quality ALS', 'Quality Rr', 'Quality Lr', 'Quality Lh']
        for j, ss in enumerate(q_strs):
            qq = quality[:, j]
            self._stats[ss].append((qq.min(), np.median(qq), qq.mean(), qq.std()))
        if not final:
            return

        # rows alternate between the first and the last statistics of each measure
        rows = []
        for ss in q_strs:
            rows.append(self._stats[ss][0])
            rows.append(self._stats[ss][-1])
        df_dict = {
            'Quality': ['Alpha Min', 'Final Alpha Min', 'L1', 'Final L1', 'ALS', 'Final ALS', 'Rr', 'Final Rr',
                        'Lr', 'Final Lr', 'Lh', 'Final Lh'],
            'Min': [row[0] for row in rows],
            'Median': [row[1] for row in rows],
            'Mean': [row[2] for row in rows],
            'Std. Dev.': [row[3] for row in rows],
        }
        df = pd.DataFrame(df_dict)
        self.logger.info(f'Quality statistics:\n{df}')

    def _update_sizes(self):
        """Interpolate new sizes to the points given the original size function."""
        self._pt_sizes = self._interp.interpolate_to_points(self._pts)

    def _create_output_grid(self):
        """Creates the new output grid from the untouched cells, the triangles and the relaxed points."""
        for first in range(0, len(self._tris), 3):
            self._new_cell_stream.extend([self._tri_type, 3] + list(self._tris[first:first + 3]))

        ug = UGrid(self._pts, self._new_cell_stream)
        b = UGridBuilder()
        b.set_is_2d()
        b.set_ugrid(ug)
        self.out_co_grid = b.build_grid()
