"""RefineUGridByErrorTool class."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
import math

# 2. Third party modules
import numpy as np

# 3. Aquaveo modules
from xms.constraint import UGridBuilder
from xms.grid.ugrid import UGrid as XmUGrid
from xms.tool_core import ALLOW_ONLY_POINT_MAPPED, ALLOW_ONLY_SCALARS, IoDirection, Tool
from xms.tool_core.tool import equivalent_arguments

# 4. Local modules
from xms.tool.algorithms.mesh_2d.mesh_from_ugrid import MeshFromUGrid

# Positions of the tool arguments in the list returned by ``initial_arguments``.
ARG_INPUT_GRID, ARG_INPUT_LND, ARG_OUTPUT_GRID = range(3)

# Locked-node dataset values with a magnitude above this are treated as "locked".
DEFAULT_TOLERANCE = 1.0e-6  # XM_ZERO_TOL


class RefineUGridTool(Tool):
    """Tool to refine a ugrid."""

    def __init__(self, name='Refine UGrid'):
        """Set up the tool with empty inputs and empty refinement bookkeeping.

        Args:
            name (str): Name passed on to the base tool class.
        """
        super().__init__(name=name)

        # Behavior switches read by this class: the geometry label used in user-facing text and
        # whether the output is built as a UGrid.
        self._geom_txt = 'UGrid'
        self._force_ugrid = True

        # Inputs, filled in while validating the arguments.
        self._input_cogrid = self._ugrid = self._lnd = None
        self._file_count = 0

        # Locked nodes state, filled in by validation and run.
        self._using_lnd = False
        self._locked_nodes = []

        # Bookkeeping built up while splitting cells (keyed by old cell id or by edge end points).
        self._updated_cellstreams = {}
        self._updated_num_cells = {}
        self._adjacent_cellstreams = {}
        self._split_edges_to_midpoint = {}

    def initial_arguments(self):
        """Build the default argument list: input grid, optional locked nodes dataset, output name.

        Must override.

        Returns:
            (list): A list of the initial tool arguments.
        """
        geom_txt = self._geom_txt

        input_grid = self.grid_argument(name='grid', description=geom_txt.capitalize())
        locked_nodes = self.dataset_argument(
            name='lnd',
            description='Locked nodes dataset',
            filters=[ALLOW_ONLY_SCALARS, ALLOW_ONLY_POINT_MAPPED],
            optional=True,
        )
        output_name = self.string_argument(
            name='refined_grid',
            description=f'Output {geom_txt} name',
            value='refined',
            optional=False,
        )
        # The order must match ARG_INPUT_GRID, ARG_INPUT_LND and ARG_OUTPUT_GRID.
        return [input_grid, locked_nodes, output_name]

    def validate_arguments(self, arguments):
        """Validate the input grid and the optional locked nodes dataset.

        Stores the grid, its ugrid and the locked nodes dataset on the instance for use by ``run``.

        Args:
            arguments (list): The tool arguments.

        Returns:
            (dict): Dictionary of errors for arguments.
        """
        errors = {}

        # Validation can run more than once on the same instance. Reset the flag so an earlier pass
        # that found a locked nodes dataset cannot leave it set after the dataset is deselected
        # (run would otherwise read values from a dataset that is None).
        self._using_lnd = False

        # Validate the input grid
        self._validate_input_grid(errors, arguments[ARG_INPUT_GRID])

        # A None result is treated as "no locked nodes dataset"; otherwise it must be on the input grid.
        self._lnd = self._validate_input_dataset(arguments[ARG_INPUT_LND], errors)
        if self._lnd is not None:
            if self._lnd.geom_uuid != self._input_cogrid.uuid:
                errors[arguments[ARG_INPUT_LND].name] = f'Selected LND is not on selected {self._geom_txt}.'
            else:
                self._using_lnd = True

        return errors

    def validate_from_history(self, arguments):
        """Called to determine if arguments are valid from history.

        Args:
            arguments (list): The tool arguments.

        Returns:
            (bool): True if no errors, False otherwise.
        """
        # Make sure there are 6 or more arguments
        default_arguments = self.initial_arguments()
        default_length = len(default_arguments)
        if len(arguments) < default_length:
            return False
        if not equivalent_arguments(arguments[0:default_length], default_arguments):
            return False
        for i in range(len(default_arguments), len(arguments)):
            if arguments[i].io_direction != IoDirection.INPUT:
                return False
            if arguments[i].type != 'raster':
                return False
        return True

    def _validate_input_grid(self, errors, argument):
        """Validate grid is specified and 2D.

        Args:
            errors (dict): Dictionary of errors keyed by argument name.
            argument (GridArgument): The grid argument.
        """
        self._input_cogrid = self.get_input_grid(argument.text_value)
        if not self._input_cogrid.check_all_cells_2d():
            errors[argument.name] = f'{self._geom_txt.capitalize()} must be 2D.'
        else:
            self._ugrid = self._input_cogrid.ugrid

    def is_cell_unlocked(self, cell_id):
        """Called for ugrid patch.

        Args:
            ug (UGrid): the grid
            cell_id (int): ids of the cell

        Returns:
            (bool): is cell unlocked
        """
        node_ids = list(self._ugrid.get_cell_points(cell_id))
        cell_unlocked = True
        if self._using_lnd is True:
            for node_id in node_ids:
                if abs(self._locked_nodes[node_id]) > DEFAULT_TOLERANCE:
                    cell_unlocked = False
                    break
        return cell_unlocked

    def get_split_status(self, ug, id, all_edges):
        """Called for ugrid patch.

        Args:
            ug (UGrid): the grid
            id (int): cell id
            all_edges (list): list of edges in the cell

        Returns:
            (bool): True if it is completely unlocked
            (bool): True if an edge will be split
            (list): list of bools for if edge will split
        """
        split_status = []
        has_a_split = False
        all_unlocked = self.is_cell_unlocked(id)
        if all_unlocked is False:
            for edge in all_edges:
                adj_ids = ug.get_edge_adjacent_cells(edge)
                split = False
                for adj_id in adj_ids:
                    if adj_id == id:
                        continue
                    split = self.is_cell_unlocked(adj_id)
                if split is True:
                    has_a_split = True
                split_status.append(split)
            return all_unlocked, has_a_split, split_status

        return True, True, [True] * len(all_edges)

    def do_cell_split(self, ug, locations, id):
        """Split the specified cells and create a new ugrid.

        Args:
            ug (UGrid): the grid to modify.
            locations (list): locations of grid points
            id (int): cell id

        """
        all_edges = ug.get_cell_edges(id)
        all_unlocked, has_a_split, split_status = self.get_split_status(ug, id, all_edges)

        if has_a_split is False:
            # nothing to do here
            return

        if len(all_edges) == 3:
            return self._do_tri_split(ug, id, all_edges, all_unlocked, split_status, locations)
        else:
            return self._do_poly_split(ug, id, all_edges, all_unlocked, split_status, locations)

    def _do_tri_split(self, ug, id, all_edges, all_unlocked, split_status, locations):
        """Split the triangle into two triangles.

        Args:
            ug (UGrid): the grid to modify.
            id (int): cell id
            all_edges (list): list of edges
            all_unlocked (bool): are all edges unlocked?
            split_status (list): list of bools if edge is split
            locations (list): locations of grid points

        """
        new_cellstream = []
        num_new_cells = 1
        if all_unlocked:
            mid_pts = []
            for edge in all_edges:
                pt1 = edge[0]
                pt2 = edge[len(edge) - 1]
                mid_pt_id = self.get_edge_midpoint(ug, pt1, pt2, locations)
                mid_pts.extend([mid_pt_id])

            for i, edge in enumerate(all_edges):
                new_edges = []

                prev_edge_index = self.advance_index(i, 3, False)
                new_edges.extend([mid_pts[i], mid_pts[prev_edge_index]])

                new_edges.extend([edge[0]])

                new_cellstream.extend([5, len(new_edges)])  # polygon, number of sides
                new_cellstream.extend(new_edges)

            # create the triangle in the middle
            new_cellstream.extend([5, 3, mid_pts[0], mid_pts[1], mid_pts[2]])
            num_new_cells = 4
        else:
            # which edge is the split one?
            split_idx = split_status.index(True)
            split_edge = all_edges[split_idx]
            new_pt = self.get_edge_midpoint(ug, split_edge[0], split_edge[1], locations)
            next_idx = self.advance_index(split_idx, 3, True)
            next_edge = all_edges[next_idx]

            # add the two new triangles to a cell stream
            new_cellstream = [5, 3, split_edge[0], new_pt, next_edge[1],
                              5, 3, new_pt, split_edge[1], next_edge[1]]
            num_new_cells = 2

        self._updated_cellstreams[id] = new_cellstream
        self._updated_num_cells[id] = num_new_cells

    def _do_poly_split(self, ug, id, all_edges, all_unlocked, split_status, locations):
        """Split the specified cells and create a new ugrid.

        Args:
            ug (UGrid): the grid to modify.
            id (int): cell id
            all_edges (list): list of edges
            all_unlocked (bool): are all edges unlocked?
            split_status (list): list of bools if edge is split
            locations (list): locations of grid points

        """
        # create the centroid point
        centroid = ug.get_cell_centroid(id)[1]
        centroid = (centroid[0], centroid[1], math.fsum([locations[edge[0]][2] for edge in all_edges]) / len(all_edges))
        centroid_id = len(locations)
        locations.append(centroid)

        # go edge by edge and see if we are splitting
        new_cellstream = []
        new_num_cells = 0
        for i, edge in enumerate(all_edges):
            new_edges = []
            split = split_status[i]
            next_i = self.advance_index(i, len(all_edges), True)
            next_split = split_status[next_i]
            next_edge = all_edges[next_i]
            prev_i = self.advance_index(i, len(all_edges), False)
            prev_split = split_status[prev_i]
            if split:
                split_pt_id = self.get_edge_midpoint(ug, edge[0], edge[1], locations)
                new_edges = [centroid_id, split_pt_id, edge[1]]
                if next_split:
                    next_split_pt_id = self.get_edge_midpoint(ug, next_edge[0], next_edge[1], locations)
                    new_edges.append(next_split_pt_id)
                else:
                    new_edges.append(next_edge[1])
                new_cellstream.extend([7, len(new_edges)])
                new_cellstream.extend(new_edges)
                new_num_cells += 1
            elif i < len(all_edges) - 1 and prev_split is False:
                new_edges = [centroid_id, edge[0], edge[1]]
                if next_split:
                    next_split_pt_id = self.get_edge_midpoint(ug, next_edge[0], next_edge[1], locations)
                    new_edges.append(next_split_pt_id)
                new_cellstream.extend([7, len(new_edges)])
                new_cellstream.extend(new_edges)
                new_num_cells += 1

        self._updated_cellstreams[id] = new_cellstream
        self._updated_num_cells[id] = new_num_cells

    def split_cells(self, ug):
        """Split the cells of the grid and build a new ugrid from the result.

        Every cell is offered to ``do_cell_split``. Cells that were split are replaced in the
        cellstream by their recorded replacement cells; all other cells are copied unchanged.

        Args:
            ug (UGrid): the grid to modify.

        Returns:
            (UGrid): the new grid
        """
        locations = list(ug.locations)

        # Split cells; new points are appended to locations and replacements recorded on the instance.
        for cell_id in range(ug.cell_count):
            self.do_cell_split(ug, locations, cell_id)

        # Walk the old cellstream cell by cell and assemble the new one.
        cell_total = ug.cell_count
        source_stream = np.asarray(ug.cellstream)
        refined_stream = []
        start = 0
        for cell_id in range(cell_total):
            # Each cell is stored as: type code, point count, then that many point ids.
            stop = start + source_stream[start + 1] + 2
            if cell_id in self._updated_cellstreams:
                refined_stream.extend(self._updated_cellstreams[cell_id])
            else:
                refined_stream.extend(source_stream[start:stop])
            start = stop

        return XmUGrid(locations, refined_stream)

    def advance_index(self, cur_index, num_items, forward):
        """Called to determine if arguments are valid.

        Args:
            cur_index (int): index to be advanced.
            num_items (int): number of items.
            forward (bool): true = move forward, false = move back

        Returns:
            (int): new index
        """
        new_index = -1
        if forward is True:
            if cur_index == num_items - 1:
                new_index = 0
            else:
                new_index = cur_index + 1
        else:
            if cur_index == 0:
                new_index = num_items - 1
            else:
                new_index = cur_index - 1

        return new_index

    def get_edge_midpoint(self, ug, pt1, pt2, locations):
        """Split the specified cells and create a new ugrid.

        Args:
            ug (UGrid): the grid to modify.
            pt1 (int): first point on edge
            pt2 (int): second point on edge
            locations (list): locations of grid points

        Returns:
            (int): point id for the new point
        """
        # has a midpoint already been created?
        if (pt2, pt1) in self._split_edges_to_midpoint.keys():
            return self._split_edges_to_midpoint[(pt2, pt1)]

        end_pts = [pt1, pt2]
        end_locs = ug.get_points_locations(end_pts)
        xy_mid = [((end_locs[0][0] + end_locs[1][0]) / 2.0), ((end_locs[0][1] + end_locs[1][1]) / 2.0),
                  ((end_locs[1][2] + end_locs[1][2]) / 2.0)]
        locations.append(xy_mid)
        mid_pt_id = len(locations) - 1
        self._split_edges_to_midpoint[(pt1, pt2)] = mid_pt_id

        return mid_pt_id

    def run(self, arguments):
        """Refine the validated input grid and hand the result to the output argument.

        Args:
            arguments (list): The tool arguments.
        """
        # Get the lnd data: the first entry of the dataset values (presumably the first time step).
        if self._using_lnd:
            first_values = self._lnd.values[0]
            self._locked_nodes = list(first_values)

        refined = self.split_cells(self._ugrid)
        self._ugrid = refined

        self.logger.info('Refining completed.')

        if self._force_ugrid is True:
            # Wrap the refined ugrid as a 2D constrained grid.
            builder = UGridBuilder()
            builder.set_is_2d()
            builder.set_ugrid(refined)
            output = builder.build_grid()
        else:
            # Otherwise convert the refined ugrid to a mesh; the second return value is not needed.
            converter = MeshFromUGrid()
            output, _ = converter.convert(source_opt=converter.SOURCE_OPT_POINTS, input_ugrid=refined,
                                          logger=self.logger, tris_only=False, split_collinear=True)

        self.set_output_grid(output, arguments[ARG_OUTPUT_GRID], None, force_ugrid=self._force_ugrid)
