"""ExporUGridTool class."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
import os
from typing import List

# 2. Third party modules

# 3. Aquaveo modules
from xms.gdal.utilities import gdal_utils as gu
from xms.grid.ugrid import UGrid
from xms.tool_core import ALLOW_ONLY_POINT_MAPPED, ALLOW_ONLY_SCALARS, Argument, IoDirection, Tool

# 4. Local modules
from xms.tool.algorithms.ugrids.relax_ugrid_points import RelaxUGridPoints


class RelaxUGridPointsTool(Tool):
    """Tool to relax the points in a ugrid.

    ``validate_arguments`` loads the input grid and checks it:
        - all cells must be 2D;
        - the grid must contain triangles.

    It also caches the optional locked-points dataset and the size dataset, which
    'Spring' relaxation requires. ``run`` then delegates the relaxation to
    ``RelaxUGridPoints`` and writes the relaxed grid as the tool output.
    """

    # Indices into the argument list built by initial_arguments().
    ARG_INPUT_GRID = 0
    ARG_INPUT_RELAXTYPE = 1
    ARG_INPUT_NUM_ITERATIONS = 2
    ARG_INPUT_CONVERGENCE_DISTANCE = 3
    ARG_INPUT_LOCKED_PTS_DATASET = 4
    ARG_INPUT_SIZE_DATASET = 5
    ARG_INPUT_OPTIMIZE_TRIANGLES = 6
    ARG_OUTPUT_UGRID = 7

    def __init__(self, name='Relax UGrid Points'):
        """Initializes the class.

        Args:
            name (str): Display name of the tool.
        """
        super().__init__(name=name)
        # Set by validate_arguments() and consumed by run().
        self._input_cogrid = None
        # Choices for the 'relax_type' argument; the first entry is the default.
        self._relax_types = ['Area', 'Angle', 'Spring', 'Lloyd']
        # Cached values of the optional locked-points / size datasets (None when not selected).
        self._locked_pts = None
        self._pt_sizes = None

        # Passed as ``force_ugrid`` to set_output_grid() in run().
        self._force_ugrid = True
        # Geometry name used in argument descriptions and error messages.
        self._geom_txt = 'UGrid'

    def initial_arguments(self):
        """Get initial arguments for tool.

        Must override.

        Returns:
            (list): A list of the initial tool arguments.
        """
        arguments = [
            # NOTE(review): str.capitalize() lowercases the remaining letters, so 'UGrid' is shown as 'Ugrid'.
            self.grid_argument(name='input_ugrid', description=self._geom_txt.capitalize(),
                               io_direction=IoDirection.INPUT),
            self.string_argument(name='relax_type', description='Relaxation method', value=self._relax_types[0],
                                 choices=self._relax_types),
            self.integer_argument(name='num_iterations', description='Number of iterations', value=5, min_value=1),
            self.float_argument(name='converge_dist', description='Convergence distance (ft or m)', value=0.1,
                                min_value=1.0e-9),
            self.dataset_argument(name='lock_dataset', description='Locked nodes dataset (optional)',
                                  filters=[ALLOW_ONLY_SCALARS, ALLOW_ONLY_POINT_MAPPED], optional=True),
            self.dataset_argument(name='size_dataset', description='Spring relax size dataset',
                                  filters=[ALLOW_ONLY_SCALARS, ALLOW_ONLY_POINT_MAPPED], optional=True),
            self.bool_argument(name='optimize_triangles', description='Optimize triangulation', value=False),
            self.grid_argument(name='relaxed_grid', description=f'Relaxed output {self._geom_txt} name',
                               io_direction=IoDirection.OUTPUT, optional=True)
        ]
        self.enable_arguments(arguments)
        return arguments

    @staticmethod
    def _has_min_triangles(ugrid, min_count):
        """Return True if ``ugrid`` has at least ``min_count`` triangle cells.

        Counting stops as soon as ``min_count`` triangles have been found, so large grids
        are not scanned in full.

        Args:
            ugrid: Grid exposing ``cell_count`` and ``get_cell_type(idx)``.
            min_count (int): Number of triangles required.

        Returns:
            (bool): Whether the grid contains at least ``min_count`` triangles.
        """
        triangle = UGrid.cell_type_enum.TRIANGLE
        found = 0
        for cell_idx in range(ugrid.cell_count):
            if ugrid.get_cell_type(cell_idx) == triangle:
                found += 1
                if found >= min_count:
                    return True
        return found >= min_count

    def validate_arguments(self, arguments):
        """Called to determine if arguments are valid.

        Also caches the input grid and the optional locked-point / size dataset values on
        the tool so that run() can use them.

        Args:
            arguments (list): The tool arguments.

        Returns:
            (dict): Dictionary of errors for arguments, keyed by argument name.
        """
        errors = {}
        # Clear values cached by a previous validation. Otherwise a dataset that has since
        # been deselected would still be passed to the relaxer by run().
        self._locked_pts = None
        self._pt_sizes = None

        grid_arg = arguments[self.ARG_INPUT_GRID]
        self._input_cogrid = self.get_input_grid(grid_arg.text_value)
        if self._input_cogrid is None:
            errors[grid_arg.name] = f'Unable to load {self._geom_txt}. Aborting'
            return errors
        if not self._input_cogrid.check_all_cells_2d():
            # Trailing space lets the triangle message below be appended to this one.
            errors[grid_arg.name] = f'{self._geom_txt} must have only 2D cells. '
        # NOTE(review): at least 2 triangles are required even though the message says
        # "does not have any triangles" -- confirm the intended threshold.
        if not self._has_min_triangles(self._input_cogrid.ugrid, 2):
            msg = errors.get(grid_arg.name, '')
            msg += (f'{self._geom_txt} does not have any triangles. '
                    'This tool relaxes points connected to triangles.')
            errors[grid_arg.name] = msg

        # Only values[0] of each dataset is used (presumably the first time step).
        locked_ds = self.get_input_dataset(arguments[self.ARG_INPUT_LOCKED_PTS_DATASET].text_value)
        if locked_ds is not None:
            self._locked_pts = locked_ds.values[0]
        size_ds = self.get_input_dataset(arguments[self.ARG_INPUT_SIZE_DATASET].text_value)
        if size_ds is not None:
            self._pt_sizes = size_ds.values[0]

        relax_type = arguments[self.ARG_INPUT_RELAXTYPE].value
        if size_ds is None and relax_type == 'Spring':
            errors[arguments[self.ARG_INPUT_SIZE_DATASET].name] = 'Spring relaxation requires a size dataset.'
        return errors

    def enable_arguments(self, arguments: List[Argument]):
        """Called to show/hide arguments, change argument values and add new arguments.

        Args:
            arguments(list): The tool arguments.
        """
        grid_arg = arguments[self.ARG_INPUT_GRID]
        arguments[self.ARG_INPUT_LOCKED_PTS_DATASET].enable_for_grid(grid_arg)
        arguments[self.ARG_INPUT_SIZE_DATASET].enable_for_grid(grid_arg)
        # The size dataset is only shown for 'Spring', the one method that requires it.
        relax_method = arguments[self.ARG_INPUT_RELAXTYPE].value
        arguments[self.ARG_INPUT_SIZE_DATASET].show = relax_method == 'Spring'

    def run(self, arguments):
        """Override to run the tool.

        Uses the grid and dataset values cached by validate_arguments().

        Args:
            arguments (list): The tool arguments.
        """
        factor = 1.0
        if gu.valid_wkt(self.default_wkt):
            # look in GdalUtility.cpp - LaUnit imUnitFromProjectionWKT(const std::string wkt)
            sr = gu.wkt_to_sr(self.default_wkt)
            if sr.IsGeographic() and sr.GetAngularUnitsName().upper() in ['DEGREE', 'DS']:
                # rough conversion from meters to degrees (~111 km per degree)
                factor = 1.0 / 111000.0

        relax_method = arguments[self.ARG_INPUT_RELAXTYPE].value
        num_iter = arguments[self.ARG_INPUT_NUM_ITERATIONS].value
        converge_dist = arguments[self.ARG_INPUT_CONVERGENCE_DISTANCE].value
        opt = arguments[self.ARG_INPUT_OPTIMIZE_TRIANGLES].value
        relaxer = RelaxUGridPoints(self.logger, self._input_cogrid, relax_method, num_iter, converge_dist, factor,
                                   self._locked_pts, self._pt_sizes, opt)
        relaxer.relax()
        if relaxer.out_co_grid is None:
            # NOTE(review): nothing is reported here; presumably RelaxUGridPoints logs its own failure -- confirm.
            return

        output_arg = arguments[self.ARG_OUTPUT_UGRID]
        if not output_arg.value:
            # Default name: 'Relaxed_' + input grid name without its extension.
            # (The original indexed [-1], which yields the extension, not the name.)
            input_name = os.path.basename(arguments[self.ARG_INPUT_GRID].text_value)
            output_arg.value = f'Relaxed_{os.path.splitext(input_name)[0]}'
        self.set_output_grid(relaxer.out_co_grid, output_arg, force_ugrid=self._force_ugrid)
