"""Mesh2dFromUGrid2dTool class."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
import uuid

# 2. Third party modules

# 3. Aquaveo modules
from xms.constraint.ugrid_builder import UGridBuilder
from xms.grid.geometry.geometry import on_line_2d
from xms.grid.triangulate import Tin
from xms.grid.ugrid import UGrid
import xms.mesher
from xms.mesher import meshing

# 4. Local modules


class MeshFromUGrid:
    """Tool to convert 2D UGrids to SMS 2D mesh module geometries.

    Two conversion modes are supported (selected by the ``source_opt`` argument of ``convert``):

    * ``SOURCE_OPT_CENTROIDS``: the cell centroids of the input UGrid become the output points,
      which are then triangulated with a TIN.
    * ``SOURCE_OPT_POINTS``: the input UGrid points are reused and each input cell is copied as a
      triangle or quad, or split into triangles when requested.
    """
    SOURCE_OPT_CENTROIDS = 0
    SOURCE_OPT_POINTS = 1

    def __init__(self):
        """Initializes the class."""
        self._source_opt = self.SOURCE_OPT_CENTROIDS
        self._input_cogrid = None
        self._input_ugrid = None
        self._output_mesh = None
        self._output_mesh_uuid = ''
        self._new_cellstream = None
        # List of (output cell index, input cell index) pairs; only populated by SOURCE_OPT_POINTS.
        self._new_cell_idx = None
        self._split_collinear = False
        self._triangles_only = True
        self.logger = None  # Assigned by convert(); the code calls its info() method.

    def _triangulate_output(self, points):
        """Triangulate the output mesh points with a TIN.

        Args:
            points (list): The x,y,z point locations to triangulate.

        Returns:
            list[int]: UGrid cellstream with one triangle cell per TIN triangle.
        """
        tin = Tin(points=points)
        tin.triangulate()
        tin_tris = tin.triangles  # Flat sequence of point indices, three per triangle
        cellstream = []
        for idx in range(0, len(tin_tris), 3):  # 5 = XMU_TRIANGLE
            cellstream.extend([5, 3, tin_tris[idx], tin_tris[idx + 1], tin_tris[idx + 2]])
        return cellstream

    def _build_output_cogrid(self, points, cellstream):
        """Build the output CoGrid (Mesh2D) and assign it a new UUID.

        Args:
            points (list): The x,y,z point locations of the output mesh.
            cellstream (list[int]): UGrid cellstream of the output mesh cells.
        """
        self.logger.info('Building 2D Mesh...')
        ug = UGrid(points, cellstream)
        co_builder = UGridBuilder()
        co_builder.set_is_2d()
        co_builder.set_ugrid(ug)
        self._output_mesh = co_builder.build_grid()
        self._output_mesh_uuid = str(uuid.uuid4())
        self._output_mesh.uuid = self._output_mesh_uuid

    def _has_collinear_pts(self, locs):
        """Determine whether a cell with exactly four points should be split due to collinear points.

        For each corner, ``on_line_2d`` is called with the corner's two neighbors followed by the corner
        itself (presumably testing whether the corner lies on the line between its neighbors).

        Args:
            locs (list): The four x,y,z point locations of the cell, in cell order.

        Returns:
            bool: True if any corner check reports a collinear point.
        """
        # (neighbor, neighbor, corner) index triples, checked in the same order as before.
        corner_checks = ((0, 2, 1), (1, 3, 2), (2, 0, 3), (3, 1, 0))
        return any(on_line_2d(locs[a], locs[b], locs[mid]) for a, b, mid in corner_checks)

    def _append_to_cellstream(self, cell, old_cell_idx):
        """Appends a cell to the new cellstream and records which input cell it came from.

        Args:
            cell (list): Cellstream entries defining the cell (type, point count, point indices).
            old_cell_idx (int): Index of the originating cell in the input UGrid.
        """
        self._new_cellstream.extend(cell)
        cell_idx = len(self._new_cell_idx)
        self._new_cell_idx.append((cell_idx, old_cell_idx))

    def _should_split_cell(self, points, pts):
        """Decide whether an input cell must be split into triangles.

        Args:
            points: Point locations of the input UGrid.
            pts (list[int]): Point indices of the cell.

        Returns:
            bool: True if the cell should be split.
        """
        if len(pts) <= 3:
            return False
        if self._triangles_only or len(pts) > 4:
            return True
        if self._split_collinear:
            return self._has_collinear_pts([points[pt] for pt in pts])
        return False

    def _append_split_cell(self, points, pts, tri, old_cell_idx):
        """Split a polygonal input cell into triangles using the poly mesher and append them.

        Args:
            points: Point locations of the input UGrid.
            pts (list[int]): Point indices of the cell to split.
            tri: Cell type value used for triangles.
            old_cell_idx (int): Index of the cell in the input UGrid.
        """
        locs = [points[pt] for pt in pts]
        # Map mesher output points back to input point indices by exact x,y match.
        pt_lookup = {(loc[0], loc[1]): pt for loc, pt in zip(locs, pts)}
        locs.reverse()  # poly mesher wants clockwise polygons
        poly_input = meshing.PolyInput(outside_polygon=locs)
        poly_input.generate_interior_points = False
        small_ug = xms.mesher.generate_mesh(polygon_inputs=[poly_input])
        small_ug_pt_idx = [pt_lookup[(loc[0], loc[1])] for loc in small_ug.locations]
        for cell_idx in range(small_ug.cell_count):
            small_cell_pts = small_ug.get_cell_points(cell_idx)
            self._append_to_cellstream(
                [tri, 3] + [small_ug_pt_idx[small_cell_pts[k]] for k in range(3)], old_cell_idx
            )

    def _cellstream_from_points(self, points):
        """Build the output cellstream by copying or splitting each input cell.

        Args:
            points: Point locations of the input UGrid.

        Returns:
            list[int]: The output cellstream. The cell mapping is stored in ``self._new_cell_idx``.
        """
        tri = self._input_ugrid.cell_type_enum.TRIANGLE
        quad = self._input_ugrid.cell_type_enum.QUAD
        self._new_cellstream = []
        self._new_cell_idx = []
        for i in range(self._input_ugrid.cell_count):
            pts = list(self._input_ugrid.get_cell_points(i))
            if self._should_split_cell(points, pts):
                self._append_split_cell(points, pts, tri, i)
            elif len(pts) < 4:
                self._append_to_cellstream([tri, 3] + pts, i)
            else:  # Exactly four points; cells with more points are always split above.
                self._append_to_cellstream([quad, 4] + pts, i)
        return self._new_cellstream

    def _centroid_points(self):
        """Return the input cell centroids as output points, with Z set from cell elevations if present.

        Returns:
            list[list]: One [x, y, z] list per input cell.
        """
        points = [list(self._input_ugrid.get_cell_centroid(i)[1]) for i in range(self._input_ugrid.cell_count)]
        cell_elevations = self._input_cogrid.cell_elevations
        if cell_elevations is not None:
            self.logger.info('Assigning cell elevations to output points...')
            for pt, elev in zip(points, cell_elevations):
                pt[2] = elev
        return points

    def _create_ugrid(self):
        """Creates the output mesh from the input UGrid according to the source option."""
        # Reset per-run state so a reused instance never reports the cell mapping of a previous run.
        self._new_cellstream = None
        self._new_cell_idx = None
        if self._source_opt == self.SOURCE_OPT_POINTS:
            self.logger.info('Creating 2D Mesh from input UGrid points...')
            points = self._input_ugrid.locations
            cellstream = self._cellstream_from_points(points)
        else:  # self._source_opt == self.SOURCE_OPT_CENTROIDS
            self.logger.info('Creating 2D Mesh from input UGrid cell centers...')
            points = self._centroid_points()
            self.logger.info('Triangulating source points...')
            cellstream = self._triangulate_output(points)
        self._build_output_cogrid(points, cellstream)
        self.logger.info('UGrid successfully created.')

    def convert(self, source_opt, input_ugrid, logger, input_cogrid=None, tris_only=False, split_collinear=False):
        """Run the tool.

        Args:
            source_opt (int): Source option (SOURCE_OPT_CENTROIDS or SOURCE_OPT_POINTS).
            input_ugrid (UGrid): Ugrid to convert.
            logger (logger): Logger for logging.
            input_cogrid: Optional CoGrid wrapping ``input_ugrid``. If None, one is built from ``input_ugrid``.
                Its ``cell_elevations`` are used for the point Z values when converting from centroids.
            tris_only (bool): If True, the result should only have triangles (SOURCE_OPT_POINTS only).
            split_collinear (bool): If True, split four-point cells that have collinear points
                (SOURCE_OPT_POINTS only).

        Returns:
            tuple: (output_mesh, new_cell_idx), where output_mesh is the built CoGrid and new_cell_idx is
                a list of (output cell index, input cell index) pairs for SOURCE_OPT_POINTS, or None otherwise.
        """
        self._source_opt = source_opt
        self._input_ugrid = input_ugrid

        if input_cogrid is not None:
            self._input_cogrid = input_cogrid
        else:
            co_builder = UGridBuilder()
            co_builder.set_is_2d()
            co_builder.set_ugrid(self._input_ugrid)
            self._input_cogrid = co_builder.build_grid()

        self._triangles_only = tris_only
        self._split_collinear = split_collinear
        self.logger = logger

        self._create_ugrid()

        return self._output_mesh, self._new_cell_idx
