"""MeshQualityReporter class."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
from enum import IntEnum
import logging
import math

# 2. Third party modules
from sortedcontainers import SortedDict

# 3. Aquaveo modules
from xms.constraint.contour.contour_sorter import ContourSorter
from xms.constraint.contour.contourer import Contourer
from xms.grid.geometry import geometry
from xms.grid.ugrid import UGrid

# 4. Local modules
from xms.srh.file_io.report import plots
from xms.srh.file_io.report import report_util
from xms.srh.file_io.report.point import Point


class QualityMeasure(IntEnum):
    """The different ARR plot quality measures.

    Values identify the plot types requested via the ``methods`` argument of
    ``MeshQualityReporter.get_html_plot_files``. In the triangle formulas noted below, edge lengths are
    scaled so the longest edge (opposite the largest angle) has length 1.
    """
    Q_alpha_min = 0  # Minimum-angle measure (3 * smallest angle / pi for triangles)
    Q_Ll = 1  # Shortest edge length divided by longest edge length
    Q_ALS = 2  # Area relative to the sum of squared edge lengths (the only one MeshQualityReporter evaluates)
    Q_Rr = 3  # Inner to outer circle radius ratio (2 * r / ro)
    Q_Lr = 4  # Inner radius relative to the longest edge (2 * sqrt(3) * r)
    Q_Lh = 5  # Minimum height relative to the longest edge (2 * h_min / sqrt(3))
    Q_all = 6  # Not a quality measure. This one means use all measures.


class MeshQualityReporter:
    """Generates a mesh quality report and the ARR triangle plot.

    Each triangle/quad cell is given a quality value and a location inside the ARR triangle
    (0, 0), (50, 0), (50, 28.87). Only the ``_max_plot_pts`` cells with the lowest quality values are kept
    for the plot.
    """
    _SQRT3 = math.sqrt(3.0)  # Square root of 3, shared by the quality formulas

    def __init__(self):
        """Initializes the class."""
        self._ugrid = None  # The xms.grid.ugrid.ugrid.UGrid
        self._ugrid_locs = []  # The xyz point locations in the UGrid. May get transformed.

        self._max_quality_in_plot = 100.0  # 100.0 is just a high initial number. Max will end up lower.
        self._max_plot_pts = 100  # Maximum number of points in the plot
        self._triangle = {'x': [0.0, 50.0, 50.0, 0.0], 'y': [0.0, 0.0, 28.87, 0.0]}  # Main triangle xy polygon
        self._sqrt3 = self._SQRT3  # Square root of 3 (kept as an instance attribute for existing callers)

    def get_html_plot_files(self, ugrid, projection, methods, mesh_name, report_dir):  # pragma: no cover
        """Returns list of 6 html filepaths containing the ARR plots.

        Files are in this order: ['Q(alpha_min)', 'Q(Ll)', 'Q(ALS)', 'Q(Rr)', 'Q(Lr)', 'Q(Lh)']
        See mqARRPlotHelper::CreateARRPlotMesh.

        Args:
            ugrid (:obj:`xms.grid.ugrid.ugrid.UGrid`): The UGrid.
            projection (:obj:`xms.data_objects.parameters.Projection`): The UGrid's projection.
            methods (:obj:`list[QualityMeasure]`): List of the plot types to include.
            mesh_name (:obj:`str`): Name of the mesh (used to create the filename).
            report_dir (:obj:`str`): Path to directory where report files are created.

        Returns:
            filepaths (:obj:`list[str]`): Filepaths to html file containing the plot. Empty if the points
            could not be computed.
        """
        rv, plot_points = self.compute_plot_points(ugrid, projection)
        if not rv:
            return []

        # Contour lines keyed by method
        contour_dict = {method: self.get_contours(method) for method in methods}
        return plots.plot_arr_triangle(plot_points, self._triangle, contour_dict, methods, mesh_name, report_dir)

    def compute_plot_points(self, ugrid, projection):
        """Computes the points in the ARR triangle.

        See mqARRPlotHelper::CreateARRPlotMesh.

        Args:
            ugrid (:obj:`xms.grid.ugrid.ugrid.UGrid`): The UGrid.
            projection (:obj:`xms.data_objects.parameters.Projection`): The UGrid's projection.

        Returns:
            (:obj:`tuple`): tuple containing:

                rv (:obj:`bool`): True on success, False if geographic points could not be transformed.

                The plot data points as a list of xy pairs (None when rv is False).

        Raises:
            RuntimeError: If there is no UGrid or it has no point locations.
        """
        if not ugrid or ugrid.locations is None or len(ugrid.locations) == 0:  # pragma: no cover
            raise RuntimeError('No UGrid to compute the mesh quality from.')

        self._ugrid = ugrid
        # Start each computation fresh so the reporter can be reused.
        self._max_quality_in_plot = 100.0

        # Transform the points if in geographic
        if projection and projection.coordinate_system == 'GEOGRAPHIC':
            geographic_wkt = projection.well_known_text
            utm_wkt = report_util.utm_wkt_from_geographic(ugrid.locations[0], geographic_wkt)
            rv, self._ugrid_locs = report_util.transform_points(ugrid.locations, geographic_wkt, utm_wkt)
            if not rv:  # pragma: no cover
                logger = logging.getLogger("Summary Report")
                logger.warning('Could not get points in geographic coordinates.')
                return False, None
        else:
            self._ugrid_locs = self._ugrid.locations

        # Loop through the elements filling in the plot data.
        # Quadratic elements ignore the midsides.
        plot_pts = SortedDict()  # quality (float) -> PlotPt, ordered by quality
        for cell_idx in range(self._ugrid.cell_count):
            cell_type = self._ugrid.get_cell_type(cell_idx)
            if cell_type == UGrid.cell_type_enum.TRIANGLE:
                self._do_triangle(cell_idx, cell_type, plot_pts)
            elif cell_type == UGrid.cell_type_enum.QUAD:
                self._do_quad(cell_idx, cell_type, plot_pts)

        # x,y coordinates inside the ARR triangle of every kept element
        plot_data_points = [self._compute_element_arr_loc(point.q1, point.q2) for point in plot_pts.values()]
        return True, plot_data_points

    def get_contours(self, method):
        """Returns a list of contour line information.

        See mqEPlotARRPlot::DrawAnnotationContours. The ARR triangle is subdivided into a grid of small
        triangles, the quality is evaluated at their corners and contoured at 0.3, 0.45 and 0.6.

        Args:
            method (:obj:`QualityMeasure`): The quality measure to use: [Q(alpha_min), Q(Ll),
                Q(ALS), Q(Rr), Q(Lr), Q(Lh)]. Currently every method produces the Q(ALS) contours
                because _compute_arr_value ignores it.

        Returns:
           contours (:obj:`List[dict]`): x: list of x coordinates of line, y: list of y coordinates of line.
        """
        ndiv = 10
        pll = Point()
        plr = Point()
        pul = Point()
        pur = Point()
        p = [Point(x, y) for x, y in zip(self._triangle['x'], self._triangle['y'])]
        contour_values = [0.3, 0.45, 0.6]
        nodata_value = -999.0
        tolerance = 1e-5
        contourer = Contourer(contour_values, nodata_value, tolerance)
        # NOTE(review): the same Point objects are mutated on every iteration; this assumes Contourer copies
        # the coordinates rather than keeping references to them - confirm.
        for row in range(1, ndiv + 1):
            # compute the two y locations
            pcnt_btm = (row - 1) / ndiv
            pcnt_top = row / ndiv
            pll.y = plr.y = (1.0 - pcnt_btm) * p[0].y + pcnt_btm * p[2].y
            pul.y = pur.y = (1.0 - pcnt_top) * p[0].y + pcnt_top * p[2].y
            for col in range(1, ndiv + 1):
                # compute the four x locations
                pcnt_lft = (col - 1) / ndiv
                pcnt_rgt = col / ndiv
                pll.x = (1.0 - pcnt_btm) * ((1.0 - pcnt_lft) * p[0].x + pcnt_lft * p[1].x) + pcnt_btm * p[2].x
                plr.x = (1.0 - pcnt_btm) * ((1.0 - pcnt_rgt) * p[0].x + pcnt_rgt * p[1].x) + pcnt_btm * p[2].x
                pul.x = (1.0 - pcnt_top) * ((1.0 - pcnt_lft) * p[0].x + pcnt_lft * p[1].x) + pcnt_top * p[2].x
                pur.x = (1.0 - pcnt_top) * ((1.0 - pcnt_rgt) * p[0].x + pcnt_rgt * p[1].x) + pcnt_top * p[2].x
                # compute the four s (quality measure) values
                sll = self._compute_arr_value(pll.x, pll.y, method)
                slr = self._compute_arr_value(plr.x, plr.y, method)
                sul = self._compute_arr_value(pul.x, pul.y, method)
                sur = self._compute_arr_value(pur.x, pur.y, method)
                # lower left tri
                contourer.contour_3_points([pll, plr, pul], [sll, slr, sul])
                # upper right tri
                contourer.contour_3_points([plr, pur, pul], [slr, sur, sul])

        # Sort segments into continuous lines
        sorter = ContourSorter()
        contours = []
        segment_dict = contourer.get_contour_segments()
        for segments in segment_dict.values():
            ok, lines = sorter.sort_segments(segments, tolerance)
            if ok and lines:
                # lines = list of lines (line = list of points) but there'll always be only 1 line
                # since nothing is disjoint.
                contours.append({'x': [pt[0] for pt in lines[0]], 'y': [pt[1] for pt in lines[0]]})
        return contours

    @staticmethod
    def _als_quality(alpha, beta, gamma):
        """Computes the Q(ALS) quality (area relative to the sum of squared edge lengths) from triangle angles.

        Args:
            alpha (float): Largest interior angle, in radians.
            beta (float): Second largest interior angle, in radians.
            gamma (float): Smallest interior angle, in radians.

        Returns:
            (float): The quality (1.0 for an equilateral triangle). 0.0 for a degenerate triangle whose
            largest angle does not have a positive sine (e.g. all angles 0 from coincident points, or NaN).
        """
        sin_alpha = math.sin(alpha)
        if not sin_alpha > 0.0:  # also catches NaN; avoids dividing by zero below
            return 0.0
        sin_beta = math.sin(beta)
        sin_gamma = math.sin(gamma)
        # scaled lengths of the triangle edges (law of sines, assume l_alpha = 1)
        l_beta = sin_beta / sin_alpha
        l_gamma = sin_gamma / sin_alpha
        # scaled minimum height of the triangle
        h_min = l_gamma * sin_beta
        # scaled area
        area = h_min / 2
        return 4.0 * MeshQualityReporter._SQRT3 * area / (1 + l_beta * l_beta + l_gamma * l_gamma)

    @staticmethod
    def _compute_arr_value(x, y, method):
        """Computes the quality measure of a point in the ARR triangle.

        See mqEPlotARRPlot::ComputeARRValue.

        Args:
            x (float): X coordinate of point in triangle.
            y (float): Y coordinate of point in triangle.
            method (:obj:`QualityMeasure`): The requested quality measure. Currently ignored: only Q(ALS) is
                computed, whatever method is passed.

        Returns:
            (float): The Q(ALS) value of the triangle whose angles map to (x, y).
        """
        # Convert the point location to barycentric coordinates (inverse of _compute_element_arr_loc).
        gamma = y / 86.60254
        beta = (x - gamma * 50) / 100.0
        alpha = 1.0 - beta - gamma
        # The barycentric coordinates are angles as fractions of pi.
        # The other measures (not computed; angles in radians, l_alpha = 1) would be:
        #   q_alpha_min = 3 * gamma / pi, q_ll = l_gamma, q_rr = 2 * r / ro, q_lr = 2 * sqrt(3) * r,
        #   q_lh = 2 * h_min / sqrt(3), with ro = 0.5 / sin(alpha) and r = area / (0.5 * (1 + l_beta + l_gamma)).
        return MeshQualityReporter._als_quality(alpha * math.pi, beta * math.pi, gamma * math.pi)

    def _compute_tri_quality(self, alpha, beta, gamma):
        """Compute the triangle quality measure.

        Args:
            alpha (float): Largest interior angle in the triangle, in radians.
            beta (float): Second largest interior angle in the triangle, in radians.
            gamma (float): Smallest interior angle in the triangle, in radians.

        Returns:
            (float): The Q(ALS) quality measure (0.0 for degenerate triangles).
        """
        return self._als_quality(alpha, beta, gamma)

    def _qualifies(self, quality, plot_pts):
        """Returns True if an element with this quality belongs among the plotted points.

        Args:
            quality (float): The element quality.
            plot_pts (:obj:`SortedDict{float : PlotPt}`): Plot points.

        Returns:
            (bool): True if the plot isn't full yet or the quality is below the current maximum.
        """
        return quality < self._max_quality_in_plot or len(plot_pts) < self._max_plot_pts

    def _store_point(self, quality, pt, plot_pts):
        """Adds a point, drops any excess highest-quality points, and refreshes the plotted maximum.

        Args:
            quality (float): The element quality (the dict key).
            pt (:obj:`PlotPt`): The point to store.
            plot_pts (:obj:`SortedDict{float : PlotPt}`): Plot points.
        """
        # Elements with exactly the same quality share a key, so the later one replaces the earlier one.
        plot_pts[quality] = pt
        self._trim_points_to_max(self._max_plot_pts, plot_pts)
        # Read the maximum only after trimming so it reflects what is actually kept.
        if plot_pts:
            self._max_quality_in_plot = plot_pts.peekitem(-1)[0]

    def _do_triangle(self, cell_idx, cell_type, plot_pts):
        """For a triangle cell, computes the quality measures and stores it as a PlotPt.

        Args:
            cell_idx (:obj:`int`): The cell index
            cell_type: Type of cell
            plot_pts (:obj:`dict{double : PlotPt}`): Plot points
        """
        angles = self._get_tri_angles(cell_idx, cell_type)
        alpha = max(angles)
        gamma = min(angles)
        beta = math.pi - alpha - gamma
        quality = self._compute_tri_quality(alpha, beta, gamma)
        if self._qualifies(quality, plot_pts):
            # Plot coordinates are the two largest angles as fractions of pi
            self._store_point(quality, PlotPt(cell_idx, alpha / math.pi, beta / math.pi), plot_pts)

    def _do_quad(self, cell_idx, cell_type, plot_pts):
        """For a quad cell, computes the quality and stores it as a PlotPt.

        Args:
            cell_idx (:obj:`int`): The cell index.
            cell_type: Type of cell
            plot_pts (:obj:`dict{double : PlotPt}`): Plot points
        """
        quad_quality = _QuadQuality(cell_idx, self._ugrid, self._ugrid_locs)
        quality = quad_quality.compute_quality()
        if self._qualifies(quality, plot_pts):
            # Only compute the plot location for quads that will actually be plotted
            x, y = quad_quality.compute_plot_point()
            self._store_point(quality, PlotPt(cell_idx, x, y), plot_pts)

    @staticmethod
    def _compute_element_arr_loc(alpha, beta):
        """Returns the ARR plot location of an element from its angle measures.

        See mqARRPlotHelper::ComputeElementARRLoc.
        The angles should be between 0-1, that is, Rad/PI.
        The max & min of the ARR are constant (0,0) to (50, 28.87)

        Args:
            alpha (float): The largest interior angle in the tri (fraction of pi).
            beta (float): The 2nd largest interior angle in the tri (fraction of pi).

        Returns:
            (tuple): The (x, y) location in the plot.
        """
        # compute the third angle (the angles are already sorted)
        gamma = 1.0 - alpha - beta
        # map the angles into the ARR subtri from 0,0 to 50, 28.87
        # for this triangle, x = alpha*0.0 + beta*100.0 + gamma*50.0
        #                    y = alpha*0.0 + beta*  0.0 + gamma*86.60254;
        x = beta * 100.0 + gamma * 50.0
        y = gamma * 86.60254
        return x, y

    @staticmethod
    def _trim_points_to_max(max_plot_pts, plot_pts):
        """Removes the highest-quality points until at most max_plot_pts remain.

        Args:
            max_plot_pts (:obj:`int`): Maximum number of points to include in the plot.
            plot_pts (:obj:`SortedDict{float : PlotPt}`): Plot points, ordered by quality.
        """
        while plot_pts and len(plot_pts) > max_plot_pts:
            # Drop the point with the highest quality (last key)
            plot_pts.popitem(-1)

    def _get_tri_angles(self, cell_idx, cell_type):
        """Returns the three angles of the cell as a tuple.

        Args:
            cell_idx (:obj:`int`): Cell index.
            cell_type: Type of cell

        Returns:
            (tuple): The three interior angles (as returned by geometry.angle_between_edges_2d) for a
            triangle, or an empty tuple for any other cell type.
        """
        ugrid_xyz = self._ugrid_locs  # shorter so things fit on one line below
        points = self._ugrid.get_cell_points(cell_idx)
        if cell_type == UGrid.cell_type_enum.TRIANGLE:
            angle1 = geometry.angle_between_edges_2d(ugrid_xyz[points[2]], ugrid_xyz[points[1]], ugrid_xyz[points[0]])
            angle2 = geometry.angle_between_edges_2d(ugrid_xyz[points[0]], ugrid_xyz[points[2]], ugrid_xyz[points[1]])
            angle3 = geometry.angle_between_edges_2d(ugrid_xyz[points[1]], ugrid_xyz[points[0]], ugrid_xyz[points[2]])
            return angle1, angle2, angle3
        else:  # pragma: no cover
            return ()


class _QuadQuality:
    """Computes the quality of a quad cell."""
    def __init__(self, cell_idx, ugrid, ug_locs):
        """Initializes the class.

        Args:
            cell_idx (int): Cell index.
            ugrid (UGrid): The UGrid.
            ug_locs (list of Point): The UGrid locations.
        """
        self._cell_idx = cell_idx
        self._pts = ugrid.get_cell_points(cell_idx)
        self._p0 = ug_locs[self._pts[0]]
        self._p1 = ug_locs[self._pts[1]]
        self._p2 = ug_locs[self._pts[2]]
        self._p3 = ug_locs[self._pts[3]]
        self._angles = []
        self._alpha = 0.0
        self._beta = 0.0
        self._gamma = 0.0
        self._delta = 0.0
        self._max_len = 0.0
        self._lmin_ratio = 0.0
        self._h_min_ratio = 0.0

    def _compute_angles(self):
        """Computes the angles between cell edges."""
        self._angles = []
        self._angles.append(geometry.angle_between_edges_2d(self._p1, self._p0, self._p3))
        self._angles.append(geometry.angle_between_edges_2d(self._p2, self._p1, self._p0))
        self._angles.append(geometry.angle_between_edges_2d(self._p3, self._p2, self._p1))
        self._angles.append(geometry.angle_between_edges_2d(self._p0, self._p3, self._p2))
        sorted_angles = list(self._angles)
        sorted_angles.sort()
        self._alpha = sorted_angles[3]
        self._beta = sorted_angles[2]
        self._gamma = sorted_angles[1]
        self._delta = sorted_angles[0]

    def _compute_edge_length_ratio(self):
        """Computes the edge length ratios."""
        # compute the edge lengths and minimum ratio
        locs = [self._p0, self._p1, self._p2, self._p3, self._p0]
        edge_lengths = []
        for i in range(1, len(locs)):
            dx = locs[i][0] - locs[i - 1][0]
            dy = locs[i][1] - locs[i - 1][1]
            edge_lengths.append(math.sqrt(dx * dx + dy * dy))
        self._max_len = max(edge_lengths)
        min_len = min(edge_lengths)
        lmin_ratio = min_len / self._max_len
        if lmin_ratio > 0.2:
            lmin_ratio = 0.95 + 0.05 * (lmin_ratio - 0.2) / 0.8
        elif lmin_ratio > 0.05:
            lmin_ratio = 0.3 + 0.65 * (lmin_ratio - 0.05) / 0.15
        else:
            lmin_ratio = 0.3 * lmin_ratio / 0.15
        self._lmin_ratio = lmin_ratio

    def _compute_heights_and_ratio(self):
        """Computes the heights and minimum ratio."""
        h0 = min(
            geometry.distance_to_line_2d(self._p2, self._p3, self._p0),
            geometry.distance_to_line_2d(self._p2, self._p3, self._p1)
        )
        h1 = min(
            geometry.distance_to_line_2d(self._p3, self._p0, self._p1),
            geometry.distance_to_line_2d(self._p3, self._p0, self._p2)
        )
        h2 = min(
            geometry.distance_to_line_2d(self._p0, self._p1, self._p2),
            geometry.distance_to_line_2d(self._p0, self._p1, self._p3)
        )
        h3 = min(
            geometry.distance_to_line_2d(self._p1, self._p2, self._p3),
            geometry.distance_to_line_2d(self._p1, self._p2, self._p0)
        )
        h_min = min(h0, h1, h2, h3)
        h_min_ratio = h_min / self._max_len
        if h_min_ratio > 0.2:
            h_min_ratio = 0.95 + 0.05 * (h_min_ratio - 0.2) / 0.8
        elif h_min_ratio > 0.05:
            h_min_ratio = 0.3 + 0.65 * (h_min_ratio - 0.05) / 0.15
        else:
            h_min_ratio = 0.3 * h_min_ratio / 0.15
        self._h_min_ratio = h_min_ratio

    def compute_quality(self):
        """Computes the quality of the quad cell.

        Returns:
            (float): The quality of the quad cell.
        """
        self._compute_angles()
        self._compute_edge_length_ratio()
        self._compute_heights_and_ratio()
        q_alpha_min = 2.0 * self._delta / math.pi
        q_als = q_alpha_min * self._lmin_ratio
        return q_als

    def compute_plot_point(self):
        """Computes the plot point.

        Returns:
            (tuple): The plot point as a tuple of x,y coordinates.
        """
        # compute the ARR point for the triangle
        beta_factor = 0.5
        gamma_factor = 5.0
        ave_ratio = 0.5 * (self._lmin_ratio + self._h_min_ratio)
        tmp_alpha_val = self._alpha / math.pi
        tmp_beta_val = (self._beta + self._gamma) / 2.0 / math.pi
        tmp_gamma_val = self._delta / math.pi * ave_ratio
        quad_variability = tmp_alpha_val - tmp_gamma_val
        tmp_beta_val /= (1.0 + quad_variability * beta_factor)
        tmp_gamma_val /= (1.0 + quad_variability * gamma_factor)
        # Normalize
        mag = tmp_alpha_val + tmp_beta_val + tmp_gamma_val
        tmp_alpha_val /= mag
        tmp_beta_val /= mag
        x = tmp_alpha_val
        y = tmp_beta_val
        return x, y


class PlotPt:
    """One element kept for the ARR plot: its cell index and the two values that place it in the plot."""
    def __init__(self, elem_id=None, q1=None, q2=None):
        """Stores the element id and its two plot values.

        Args:
            elem_id: Index of the cell this point represents.
            q1: First plot value (the largest angle as a fraction of pi, for triangles).
            q2: Second plot value (the second largest angle as a fraction of pi, for triangles).
        """
        self.elem_id, self.q1, self.q2 = elem_id, q1, q2
