"""GeometryGradient class."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
import logging
import math

# 2. Third party modules
import numpy as np

# 3. Aquaveo modules
from xms.constraint import Grid
from xms.constraint.ugrid_activity import CellToPointActivityCalculator
from xms.datasets import vectors as xmdv
from xms.datasets.dataset_reader import DatasetReader
from xms.datasets.dataset_writer import DatasetWriter

# 4. Local modules
from xms.tool.algorithms.mesh_2d.mesh_from_ugrid import MeshFromUGrid
from xms.tool.utilities.dataset_tool import set_builder_activity_flags


# Zero tolerance used by GeometryGradient._calculate_plane_coefficients: it decides when the plane
# coefficients get normalized (small `c`) and, scaled by the triangle's largest coordinate delta,
# when a triangle is treated as degenerate.
# NOTE(review): presumably mirrors the XM_ZERO_TOL constant used elsewhere in XMS - confirm.
DEFAULT_TOLERANCE = 0.000001  # XM_ZERO_TOL


class GeometryGradient:
    """Algorithm to compute geometry gradient datasets.

    For every time step of the input dataset, a plane is fit through each triangle of the (triangulated) input
    geometry using the dataset value as the Z coordinate. The plane coefficients of the triangles adjacent to a point
    are averaged and converted to a gradient vector at that point. The vectors can be written as a vector dataset,
    a magnitude dataset, and/or a direction dataset.
    """
    def __init__(self, gradient_vec_bool: bool, gradient_mag_bool: bool, gradient_dir_bool: bool,
                 vector_dataset_name: str, magnitude_dataset_name: str, direction_dataset_name: str,
                 input_dataset_grid: Grid, dataset_reader: DatasetReader, logger: logging.Logger):
        """
        Initializes the class.

        Args:
            gradient_vec_bool: Whether or not to generate the vector dataset.
            gradient_mag_bool: Whether or not to generate the magnitude dataset.
            gradient_dir_bool: Whether or not to generate the direction dataset.
            vector_dataset_name: The name of the output vector dataset.
            magnitude_dataset_name: The name of the output magnitude dataset.
            direction_dataset_name: The name of the output direction dataset.
            input_dataset_grid: The input Grid.
            dataset_reader: The DatasetReader for the input dataset.
            logger: The logger used to log output to the user.
        """
        # Caller supplied options
        self.gradient_vec_bool = gradient_vec_bool
        self.gradient_mag_bool = gradient_mag_bool
        self.gradient_dir_bool = gradient_dir_bool
        self.vector_dataset_name = vector_dataset_name
        self.magnitude_dataset_name = magnitude_dataset_name
        self.direction_dataset_name = direction_dataset_name
        self.input_dataset_grid = input_dataset_grid
        self.logger = logger

        self._dataset_reader = dataset_reader
        # Output builders remain None unless the matching option is enabled (see _setup_output_dataset_builders)
        self._gradient_builder = None
        self._gradient_mag_builder = None
        self._gradient_dir_builder = None
        # Triangulated geometry (set by _triangulate_grid)
        self._dset_grid = None
        # Per time step cache of plane coefficients, indexed by cell
        self._cell_plane_vals = []

        # Tin will be created from input dataset's geometry
        self._points = None
        # Output dataset values filled time step at a time
        self._ts_data = None
        self._vx = []
        self._vy = []

    def geometry_gradient(self) -> tuple[DatasetWriter | None, DatasetWriter | None, DatasetWriter | None]:
        """
        Runs the geometry gradient algorithm which computes geometry gradient datasets.

        Returns:
            A tuple where the entries are DatasetWriter for vector dataset, DatasetWriter for magnitude dataset, and
            DatasetWriter for direction dataset. An entry is None if that output was not requested.

        Raises:
            RuntimeError: If the geometry of the input dataset is unavailable.
        """
        self._triangulate_grid()

        # Extract the component, time step at a time
        self._setup_output_dataset_builders()

        # Setup activity calculator
        reader = self._dataset_reader
        if reader.activity and reader.values:
            if reader.values.shape != reader.activity.shape:
                # Nodal dataset values with cell activity
                reader.activity_calculator = CellToPointActivityCalculator(self._dset_grid)

        time_count = len(reader.times)
        num_points = len(self._points)
        for tsidx in range(time_count):
            self.logger.info(f'Processing time step {tsidx + 1} of {time_count}...')
            self._ts_data, activity = reader.timestep_with_activity(tsidx)
            set_builder_activity_flags(activity, [self._gradient_builder, self._gradient_mag_builder,
                                                  self._gradient_dir_builder])

            self._vx = np.zeros(num_points)
            self._vy = np.zeros(num_points)

            # Plane coefficients depend on this time step's values, so the cache is reset every time step
            self._cell_plane_vals = [None] * self._dset_grid.cell_count
            for point_idx in range(num_points):  # Loop through each dataset location
                self._gradient_at_point(point_idx)

            ts_time = reader.times[tsidx]
            if self._gradient_builder is not None:
                vector_data = np.vstack((self._vx, self._vy)).T
                self._gradient_builder.append_timestep(ts_time, vector_data, None)
            if self._gradient_mag_builder is not None:
                mags = xmdv.vx_vy_to_magnitude(vx=self._vx, vy=self._vy)
                self._gradient_mag_builder.append_timestep(ts_time, mags, None)
            if self._gradient_dir_builder is not None:
                angles = xmdv.vx_vy_to_direction(self._vx, self._vy)
                self._gradient_dir_builder.append_timestep(ts_time, angles, None)

        self._add_output_datasets()

        return self._gradient_builder, self._gradient_mag_builder, self._gradient_dir_builder

    def _triangulate_grid(self) -> None:
        """Retrieve the input dataset's geometry and triangulate it.

        Raises:
            RuntimeError: If the input grid is missing or has no UGrid.
        """
        # Get the input dataset's geometry from XMS.
        cogrid = self.input_dataset_grid
        if cogrid:
            self._dset_grid = cogrid.ugrid
        if self._dset_grid is None:
            raise RuntimeError("Unable to read dataset input's geometry.")

        # preserve triangles and ear cut anything that is not a triangle
        tri_type = self._dset_grid.cell_type_enum.TRIANGLE
        if not cogrid.check_all_cells_are_of_type(tri_type):
            convert = MeshFromUGrid()
            mesh, _ = convert.convert(source_opt=convert.SOURCE_OPT_POINTS, input_ugrid=self._dset_grid,
                                      logger=self.logger, tris_only=True, split_collinear=True)
            self._dset_grid = mesh.ugrid

        self._points = self._dset_grid.locations

    def _gradient_at_point(self, point_idx: int) -> None:
        """Compute the gradient vector components at a point.

        The plane coefficients of every adjacent triangle with a non-zero ``c`` coefficient are averaged, and the
        result is stored in ``self._vx[point_idx]`` and ``self._vy[point_idx]``. The components are set to 0.0 if no
        adjacent triangle is usable.

        Args:
            point_idx (int): Index of the point to compute gradient vector for
        """
        sum_a = 0.0
        sum_b = 0.0
        sum_c = 0.0
        count = 0
        for adjacent_tri in self._dset_grid.get_point_adjacent_cells(point_idx):
            coeffs = self._cell_plane_vals[adjacent_tri]
            if coeffs is None:  # First visit of this triangle for the current time step
                pts = self._dset_grid.get_cell_points(adjacent_tri)[:3]
                # Copy so the grid locations are not modified when Z is overridden below.
                tri_points = [self._points[pt].copy() for pt in pts]
                # Override Z-value with scalar value for this location at this timestep.
                for tri_point, pt in zip(tri_points, pts):
                    tri_point[2] = self._ts_data[pt]
                coeffs = self._calculate_plane_coefficients(tri_points)
                self._cell_plane_vals[adjacent_tri] = coeffs
            coeff_a, coeff_b, coeff_c = coeffs
            if coeff_c != 0.0:  # Skip triangles whose plane is vertical or degenerate
                sum_a += coeff_a
                sum_b += coeff_b
                sum_c += coeff_c
                count += 1

        if count == 0 or sum_c == 0.0:
            self._vx[point_idx] = 0.0
            self._vy[point_idx] = 0.0
            return

        # find average normal
        a_prime = sum_a / count
        b_prime = sum_b / count
        c_prime = sum_c / count

        # find directional derivative: for the plane a*x + b*y + c*z + d = 0, dz/dx = -a/c and dz/dy = -b/c
        self._vx[point_idx] = -a_prime / c_prime
        self._vy[point_idx] = -b_prime / c_prime

    def _calculate_plane_coefficients(self, tri_points):
        """Compute the planar coefficients of a triangle.

        The coefficients are the components of the normal of the plane through the three points (``coeff_c`` is twice
        the signed area of the triangle in the XY plane). They are normalized to unit length when ``coeff_c`` is
        below DEFAULT_TOLERANCE.

        Args:
            tri_points (Sequence[Sequence[float]]): 2D array of the triangle point locations e.g.:
                [[x1, y1, z1], [x2, y2, z2], [x3, y3, z3]]

        Returns:
            tuple(float): The first three planar coefficients, or (0.0, 0.0, 0.0) if the triangle is degenerate
        """
        (x1, y1, z1), (x2, y2, z2), (x3, y3, z3) = (point[:3] for point in tri_points[:3])
        dx12 = x1 - x2
        dy12 = y1 - y2
        dz12 = z1 - z2
        dx23 = x2 - x3
        dy23 = y2 - y3
        dz23 = z2 - z3
        dx31 = x3 - x1
        dy31 = y3 - y1
        dz31 = z3 - z1
        coeff_a = (y1 * dz23 + y2 * dz31 + y3 * dz12)
        coeff_b = (z1 * dx23 + z2 * dx31 + z3 * dx12)
        coeff_c = (x1 * dy23 + x2 * dy31 + x3 * dy12)
        # NOTE(review): this is a signed comparison, so triangles with a negative c (clockwise in XY) are always
        # normalized while counter-clockwise ones usually are not - confirm whether abs(coeff_c) was intended.
        if coeff_c < DEFAULT_TOLERANCE:  # normalize if c is small
            mag = math.sqrt(coeff_a**2 + coeff_b**2 + coeff_c**2)
            if mag < DEFAULT_TOLERANCE:  # Compute a tighter tolerance
                tol_scale = max(
                    abs(dx23), abs(dx31), abs(dx12),
                    abs(dy23), abs(dy31), abs(dy12),
                    abs(dz23), abs(dz31), abs(dz12)
                )
                # "<=" (not "<") so three coincident points (mag == 0 and tol_scale == 0) are reported as
                # degenerate instead of falling through to a division by zero.
                if mag <= tol_scale * DEFAULT_TOLERANCE:
                    return 0.0, 0.0, 0.0
            coeff_a /= mag
            coeff_b /= mag
            coeff_c /= mag
        return coeff_a, coeff_b, coeff_c

    def _create_builder(self, dataset_name: str, num_components: int) -> DatasetWriter:
        """Create an output dataset builder that copies the input dataset's metadata.

        Args:
            dataset_name: The name of the output dataset.
            num_components: Number of components per value (1 = scalar, 2 = vector).

        Returns:
            The new DatasetWriter.
        """
        return DatasetWriter(
            name=dataset_name,
            geom_uuid=self._dataset_reader.geom_uuid,
            num_components=num_components,
            ref_time=self._dataset_reader.ref_time,
            time_units=self._dataset_reader.time_units,
            null_value=self._dataset_reader.null_value,
        )

    def _setup_output_dataset_builders(self) -> None:
        """Set up dataset builders for selected tool outputs."""
        # Create a place for the output dataset file
        if self.gradient_vec_bool:
            self._gradient_builder = self._create_builder(self.vector_dataset_name, 2)
        if self.gradient_mag_bool:
            self._gradient_mag_builder = self._create_builder(self.magnitude_dataset_name, 1)
        if self.gradient_dir_bool:
            self._gradient_dir_builder = self._create_builder(self.direction_dataset_name, 1)

    def _add_output_datasets(self) -> None:
        """Add datasets created by the algorithm to be sent back to XMS."""
        outputs = (
            (self._gradient_builder, 'Writing output gradient dataset to XMDF file...'),
            (self._gradient_mag_builder, 'Writing output gradient magnitude dataset to XMDF file...'),
            (self._gradient_dir_builder, 'Writing output gradient direction dataset to XMDF file...'),
        )
        for builder, message in outputs:
            if builder is not None:  # Only the requested outputs have a builder
                self.logger.info(message)
                builder.appending_finished()
