"""CellQualityTool class."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
import os
from sys import float_info
import uuid

# 2. Third party modules
import numpy as np

# 3. Aquaveo modules
from xms.constraint import UnconstrainedGrid
from xms.datasets.dataset_writer import DatasetWriter
from xms.tool_core import IoDirection, Tool

# 4. Local modules
import xms.tool.algorithms.ugrids.cell_2d_quality_metrics as metrics


# Floating-point limits and tolerance used by the Verdict-style metric calculations.
DOUBLE_MAX = float_info.max
# NOTE: float_info.min is the smallest *positive* normalized double, not the most negative value.
DOUBLE_MIN = float_info.min
DEFAULT_PRECISION = 1.0e-9

# Keyword flag names for each quality metric, in metric-index order
# (alpha_min, lmax_lmin, als, rr, lr, lh, condition, shear).
DATASET_KEYWORDS = [
    f'quality_{suffix}'
    for suffix in ('alpha_min', 'lmax_lmin', 'als', 'rr', 'lr', 'lh', 'condition', 'shear')
]


class CellQualityTool(Tool):
    """Tool to calculate quality metrics of UGrid cells.

    For each cell of a 2D input grid, the tool evaluates up to eight quality
    metrics: six XMS metrics plus the Verdict condition and shear metrics.
    Triangles use ``triangle_quality`` and quadrilaterals use ``quad_quality``.
    Other cell shapes keep the XM_QUALITY_UNSUPPORTED_GEOMETRY marker value.
    One cell-based dataset is written for every metric whose output dataset
    name is non-empty.
    """
    ARG_INPUT_GRID = 0
    ARG_OUTPUT_GRID = 1
    ARG_QUALITY_ALPHA_MIN = 2    # Minimum interior angle (XMS)
    ARG_QUALITY_LMAX_LMIN = 3    # Edge length ratio (XMS)
    ARG_QUALITY_ALS = 4          # Area/total edge length squared ratio (XMS)
    ARG_QUALITY_RR = 5           # Inner/outer radius ratio (XMS)
    ARG_QUALITY_LR = 6           # Inner radius/maximum length ratio (XMS)
    ARG_QUALITY_LH = 7           # Minimum height/maximum length ratio (XMS)
    ARG_QUALITY_CONDITION = 8    # Condition (Verdict)
    ARG_QUALITY_SHEAR = 9        # Shear (Verdict)

    # Precision required for the Grid; this is currently not set, but may
    # potentially become one sometime in the future.
    GRID_PRECISION = DEFAULT_PRECISION

    # Default output dataset names, indexed by the XM_QUALITY_* metric constants
    # of cell_2d_quality_metrics (see initial_arguments).
    DEFAULT_DATASET_NAMES = ['q_alpha_min', 'q_Ll', 'q_ALS', 'q_Rr', 'q_Lr', 'q_Lh', 'q_condition', 'q_shear']

    def __init__(self):
        """Initialize the class."""
        super().__init__(name='Cell Quality')
        self._args = None
        self._input_cogrid = None
        self._out_grid = None
        self._ug_name = ''
        # Copy so this instance never shares (or mutates) the class-level default list.
        self._dataset_names = list(self.DEFAULT_DATASET_NAMES)
        self._datasets = []

    def initial_arguments(self):
        """Get initial arguments for tool.

        Must override.

        Returns:
            (list): A list of the initial tool arguments.
        """
        arguments = [
            self.grid_argument(name='input_grid', description='Input grid'),
            self.grid_argument(name='output_grid', description='Output grid name', io_direction=IoDirection.OUTPUT),
            self.dataset_argument(name='q_alpha_min', description='Alpha Min dataset name',
                                  optional=True, io_direction=IoDirection.OUTPUT,
                                  value=self.DEFAULT_DATASET_NAMES[metrics.XM_QUALITY_ALPHA_MIN]),
            self.dataset_argument(name='q_ll', description='Lmax/Lmin dataset name',
                                  optional=True, io_direction=IoDirection.OUTPUT,
                                  value=self.DEFAULT_DATASET_NAMES[metrics.XM_QUALITY_LMAX_LMIN]),
            self.dataset_argument(name='q_als', description='ALS dataset name',
                                  optional=True, io_direction=IoDirection.OUTPUT,
                                  value=self.DEFAULT_DATASET_NAMES[metrics.XM_QUALITY_ALS]),
            self.dataset_argument(name='q_rr', description='Outer/Inner Radius Ratio dataset name',
                                  optional=True, io_direction=IoDirection.OUTPUT,
                                  value=self.DEFAULT_DATASET_NAMES[metrics.XM_QUALITY_RR]),
            self.dataset_argument(name='q_lr', description='Lmax/Inner Radius Ratio dataset name',
                                  optional=True, io_direction=IoDirection.OUTPUT,
                                  value=self.DEFAULT_DATASET_NAMES[metrics.XM_QUALITY_LR]),
            self.dataset_argument(name='q_lh', description='Hmin/Lmax Ratio dataset name',
                                  optional=True, io_direction=IoDirection.OUTPUT,
                                  value=self.DEFAULT_DATASET_NAMES[metrics.XM_QUALITY_LH]),
            self.dataset_argument(name='q_condition', description='Condition dataset name',
                                  optional=True, io_direction=IoDirection.OUTPUT,
                                  value=self.DEFAULT_DATASET_NAMES[metrics.XM_QUALITY_CONDITION]),
            self.dataset_argument(name='q_shear', description='Shear dataset name',
                                  optional=True, io_direction=IoDirection.OUTPUT,
                                  value=self.DEFAULT_DATASET_NAMES[metrics.XM_QUALITY_SHEAR]),
        ]
        return arguments

    def validate_arguments(self, arguments):
        """Determine if arguments are valid.

        Also records the requested output dataset names for use by ``run``.

        Args:
            arguments (list): The tool arguments.

        Returns:
            (dict): Dictionary of errors for arguments.
        """
        errors = {}

        # Make sure the input grid is specified, readable, and made of 2D cells.
        self._validate_input_grid(errors, arguments[self.ARG_INPUT_GRID])

        # Make sure output name specified
        if arguments[self.ARG_OUTPUT_GRID].text_value == '':
            errors[arguments[self.ARG_OUTPUT_GRID].name] = 'Grid name not specified.'

        # Requested dataset names in metric-index order; an empty name means that
        # metric is neither requested nor written.
        self._dataset_names = [
            arguments[self.ARG_QUALITY_ALPHA_MIN].text_value,
            arguments[self.ARG_QUALITY_LMAX_LMIN].text_value,
            arguments[self.ARG_QUALITY_ALS].text_value,
            arguments[self.ARG_QUALITY_RR].text_value,
            arguments[self.ARG_QUALITY_LR].text_value,
            arguments[self.ARG_QUALITY_LH].text_value,
            arguments[self.ARG_QUALITY_CONDITION].text_value,
            arguments[self.ARG_QUALITY_SHEAR].text_value,
        ]

        return errors

    def _validate_input_grid(self, errors, argument):
        """Validate grid is specified and 2D.

        On success the grid is cached in ``self._input_cogrid`` and its name in ``self._ug_name``.

        Args:
            errors (dict): Dictionary of errors keyed by argument name.
            argument (GridArgument): The grid argument.
        """
        if argument.text_value == '':
            errors[argument.name] = 'Grid name not specified.'
            return

        self._ug_name = argument.text_value
        self._input_cogrid = self.get_input_grid(self._ug_name)

        if not self._input_cogrid:
            errors[argument.name] = 'Could not read grid.'
            return

        if not self._input_cogrid.check_all_cells_2d():
            errors[argument.name] = 'Grid cells must all be 2D.'

    def run(self, arguments):
        """Override to run the tool.

        Args:
            arguments (list): The tool arguments.
        """
        self._args = arguments
        # Start from an empty list so datasets from a previous run are not output again.
        self._datasets = []
        self._get_inputs()
        self._create_output_grid()
        self._calculate_quality_metrics()
        self._set_outputs()

    def _set_outputs(self):
        """Set the output grid and every generated quality dataset as tool outputs."""
        self.set_output_grid(self._out_grid, self._args[self.ARG_OUTPUT_GRID])
        for ds in self._datasets:
            self.set_output_dataset(ds)

    def _get_inputs(self):
        """Use the input grid's base name as the output grid name when none was given."""
        user_input = self._args[self.ARG_OUTPUT_GRID].text_value
        self._args[self.ARG_OUTPUT_GRID].value = user_input if user_input else os.path.basename(self._ug_name)

    def _create_output_grid(self):
        """Create the output grid from the input UGrid with a freshly generated uuid."""
        self._out_grid = UnconstrainedGrid(ugrid=self._input_cogrid.ugrid)
        self._out_grid.uuid = str(uuid.uuid4())

    def _get_wanted_metrics(self):
        """Build the keyword flags telling the metric functions which metrics to compute.

        Returns:
            (dict): ``{'selected_metrics': False}`` when exactly the six XMS metrics
            (everything except condition and shear) are requested. Otherwise, one
            ``True`` flag per requested metric keyword plus ``'selected_metrics': True``.
            NOTE(review): the meaning of ``selected_metrics=False`` is defined in
            cell_2d_quality_metrics; presumably it selects that module's default set.
        """
        wanted_metrics = {DATASET_KEYWORDS[metric]: True for metric, name in enumerate(self._dataset_names) if name}
        if len(wanted_metrics) == 6 and not (wanted_metrics.keys() & {'quality_condition', 'quality_shear'}):
            return {'selected_metrics': False}
        wanted_metrics['selected_metrics'] = True
        return wanted_metrics

    @staticmethod
    def _cell_quality_metrics(vertices, wanted_metrics):
        """Compute the quality metrics of a single cell.

        Args:
            vertices: The cell's point locations, as returned by ``ugrid.get_cell_locations``.
            wanted_metrics (dict): Keyword flags passed through to the metric function.

        Returns:
            (list): ``metrics.XM_QUALITY_END`` values indexed by the XM_QUALITY_* constants.
            Cells that are neither triangles nor quadrilaterals keep
            XM_QUALITY_UNSUPPORTED_GEOMETRY in every slot.
        """
        # Indexed by the XM_QUALITY_* constants, from ALPHA_MIN through SHEAR.
        cell_quality_metrics = [metrics.XM_QUALITY_UNSUPPORTED_GEOMETRY] * metrics.XM_QUALITY_END
        if len(vertices) == 3:
            quality_function = metrics.triangle_quality
        elif len(vertices) == 4:
            quality_function = metrics.quad_quality
        else:
            return cell_quality_metrics

        (cell_quality_metrics[metrics.XM_QUALITY_ALPHA_MIN],
         cell_quality_metrics[metrics.XM_QUALITY_LMAX_LMIN],
         cell_quality_metrics[metrics.XM_QUALITY_ALS],
         cell_quality_metrics[metrics.XM_QUALITY_RR],
         cell_quality_metrics[metrics.XM_QUALITY_LR],
         cell_quality_metrics[metrics.XM_QUALITY_LH],
         cell_quality_metrics[metrics.XM_QUALITY_CONDITION],
         cell_quality_metrics[metrics.XM_QUALITY_SHEAR], ) = quality_function(*vertices, **wanted_metrics)
        return cell_quality_metrics

    def _calculate_quality_metrics(self):
        """Compute the requested metrics for every cell and build the output datasets.

        Creates one cell-based dataset (single timestep at time 0.0) per metric with a
        non-empty dataset name. Values equal to XM_QUALITY_UNSUPPORTED_GEOMETRY or
        XM_QUALITY_UNDEFINED get XM_ACTIVITY_INVALID activity; all others get
        XM_ACTIVITY_VALID.
        """
        ugrid = self._out_grid.ugrid
        wanted_metrics = self._get_wanted_metrics()
        # Hoisted out of the loop: marker values that mean a metric could not be computed.
        invalid_values = {metrics.XM_QUALITY_UNSUPPORTED_GEOMETRY, metrics.XM_QUALITY_UNDEFINED}

        # One row of metric values (and matching activities) per cell.
        per_cell_quality_metrics = []
        per_cell_activities = []
        for cell_id in range(ugrid.cell_count):
            cell_quality_metrics = self._cell_quality_metrics(ugrid.get_cell_locations(cell_id), wanted_metrics)
            per_cell_quality_metrics.append(cell_quality_metrics)
            per_cell_activities.append([
                metrics.XM_ACTIVITY_INVALID if m in invalid_values else metrics.XM_ACTIVITY_VALID
                for m in cell_quality_metrics
            ])

        # Reshape explicitly so an empty grid yields a (0, n) array rather than a 1-D
        # one, which would make the column slicing below raise IndexError.
        shape = (len(per_cell_quality_metrics), metrics.XM_QUALITY_END)
        quality_values = np.array(per_cell_quality_metrics, np.float64).reshape(shape)
        activity_values = np.array(per_cell_activities, np.float64).reshape(shape)

        for metric, dataset_name in enumerate(self._dataset_names):
            if not dataset_name:
                continue
            ds_builder = DatasetWriter(
                name=dataset_name,
                geom_uuid=self._out_grid.uuid,
                location='cells',
            )
            ds_builder.append_timestep(0.0, quality_values[:, metric], activity_values[:, metric])
            ds_builder.appending_finished()
            self._datasets.append(ds_builder)
