"""ExtrudeUGridTool class."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
from typing import List

# 2. Third party modules

# 3. Aquaveo modules
from xms.constraint.ugrid_extrude import extrude_grid
from xms.tool_core import Argument, IoDirection, Tool

# 4. Local modules


class ExtrudeUGridTool(Tool):
    """Extrude a 2D UGrid into a new 3D UGrid.

    The argument list is laid out as::

        [input_grid, number_of_layers, layer_thickness_1, ..., layer_thickness_N, output_grid]

    There is one thickness argument initially; ``enable_arguments`` grows or shrinks the
    run of thickness arguments so that N follows the value of ``number_of_layers``.
    """
    # Indices of the fixed arguments at the front of the argument list.
    ARG_INPUT_GRID = 0
    ARG_NUMBER_OF_LAYERS = 1
    # Index of the first thickness argument (the thickness run starts here).
    ARG_FIRST_THICKNESS = 2
    # Indices counted from the end of the list, because the number of thickness arguments varies.
    ARG_LAST_THICKNESS = -2
    ARG_OUTPUT_GRID = -1
    # Number of arguments that are not thickness arguments (input grid, layer count, output grid).
    ARG_COUNT_LESS_THICKNESS = 3

    def __init__(self):
        """Initializes the class."""
        super().__init__(name='Extrude UGrid')
        self._input_cogrid = None
        self._ugrid = None
        self._default_dset = None

    def initial_arguments(self):
        """Get initial arguments for tool.

        Must override.

        Returns:
            (list): A list of the initial tool arguments: the input grid, the number of layers
            (default 1), a single thickness argument for layer 1, and the output grid.
        """
        arguments = [
            self.grid_argument(name='input_grid', description='Input grid'),
            self.integer_argument(name='number_of_layers', description='Number of layers', value=1, min_value=1),
            self._float_argument_for_thickness(0),
            self.grid_argument(name='output_grid', description='Output grid', io_direction=IoDirection.OUTPUT),
        ]
        return arguments

    def enable_arguments(self, arguments):
        """Called to show/hide arguments, change argument values and add new arguments.

        Keeps the number of thickness arguments equal to the requested number of layers by
        deleting trailing thickness arguments or inserting new ones just before the output grid.
        The list is modified in place.

        Args:
            arguments(list): The tool arguments.
        """
        layer_count = arguments[self.ARG_NUMBER_OF_LAYERS].value
        if layer_count is None or layer_count < 0:
            # No usable layer count. A negative count would move the start of the deletion slice
            # below ARG_FIRST_THICKNESS and remove the number-of-layers argument itself, so leave
            # the list untouched until the count is usable again.
            # NOTE(review): min_value=1 is declared on the argument, but nothing in this class shows
            # it is enforced before this hook runs -- confirm against the framework.
            return

        argument_count = len(arguments)
        current_thickness_count = argument_count - self.ARG_COUNT_LESS_THICKNESS
        if current_thickness_count > layer_count:
            # remove excess thickness arguments (everything after the last wanted thickness
            # up to, but not including, the output grid)
            first = self.ARG_FIRST_THICKNESS + layer_count
            last = argument_count + self.ARG_LAST_THICKNESS + 1
            del arguments[first:last]
        elif current_thickness_count < layer_count:
            # add missing thickness arguments, continuing the numbering after the existing ones
            for layer_idx in range(current_thickness_count, layer_count):
                thickness_argument = self._float_argument_for_thickness(layer_idx)
                # inserting at index -1 places the new argument directly before the output grid
                arguments.insert(self.ARG_LAST_THICKNESS + 1, thickness_argument)

    def validate_arguments(self, arguments):
        """Called to determine if arguments are valid.

        Args:
            arguments (list): The tool arguments.

        Returns:
            (dict): Dictionary of errors for arguments.
        """
        errors = {}

        # Validate that the input grid can be read and is not already 3D
        self._validate_input_grid(errors, arguments[self.ARG_INPUT_GRID])

        return errors

    def _validate_input_grid(self, errors, argument):
        """Validate grid can be read and is not already 3D.

        Args:
            errors (dict): Dictionary of errors keyed by argument name. An entry is added
                for this argument when validation fails.
            argument (GridArgument): The grid argument.
        """
        key = argument.name
        input_ugrid = self.get_input_grid(argument.text_value)
        if input_ugrid is None:
            errors[key] = 'Could not read grid.'
        elif input_ugrid.check_all_cells_3d():
            errors[key] = 'Input grid is already 3D.'

    def validate_from_history(self, arguments):
        """Called to determine if arguments are valid from history.

        The history arguments are valid when the fixed leading and trailing arguments are
        equivalent to the initial arguments and there is exactly one equivalent thickness
        argument per layer.

        Args:
            arguments (list): The tool arguments.

        Returns:
            (bool): True if no errors, False otherwise.
        """
        initial_arguments = self.initial_arguments()
        # history arguments must be at least as long as initial arguments
        if len(arguments) < len(initial_arguments):
            return False

        # check arguments before thickness
        # must do this first to make sure number of layers argument is valid
        for i in range(self.ARG_FIRST_THICKNESS):
            if not arguments[i].equivalent_to(initial_arguments[i]):
                return False

        # figure out how many thickness arguments there should be
        layer_count = arguments[self.ARG_NUMBER_OF_LAYERS].value
        argument_count = len(arguments)
        current_thickness_count = argument_count - self.ARG_COUNT_LESS_THICKNESS
        expected_thickness_count = layer_count

        # length of thickness from number of layers must match number calculated from the length of the history
        if current_thickness_count != expected_thickness_count:
            return False

        # check thickness arguments
        for i in range(layer_count):
            thickness_argument = self._float_argument_for_thickness(i)
            if not arguments[self.ARG_FIRST_THICKNESS + i].equivalent_to(thickness_argument):
                return False

        # check arguments after thickness (indexed from the end of both lists)
        after_count = -self.ARG_LAST_THICKNESS - 1
        for i in range(after_count):
            from_end_idx = -1 - i
            if not arguments[from_end_idx].equivalent_to(initial_arguments[from_end_idx]):
                return False
        return True

    def run(self, arguments: List[Argument]):
        """Override to run the tool.

        Reads the input grid, collects one thickness per layer, extrudes the grid and
        stores the result as the output grid.

        Args:
            arguments (list): The tool arguments.
        """
        input_grid = self.get_input_grid(arguments[self.ARG_INPUT_GRID].value)
        layer_count = arguments[self.ARG_NUMBER_OF_LAYERS].value
        # thickness arguments are contiguous, starting at ARG_FIRST_THICKNESS
        thicknesses = [arguments[self.ARG_FIRST_THICKNESS + i].value for i in range(layer_count)]
        output_grid = extrude_grid(input_grid, thicknesses)
        self.set_output_grid(output_grid, arguments[self.ARG_OUTPUT_GRID])

    def _float_argument_for_thickness(self, i) -> Argument:
        """Build a float argument for given thickness index.

        Args:
            i (int): The zero-based thickness index; the name and description use ``i + 1``.

        Returns:
            (Argument): The argument, with a default value of 0.0 and a minimum of 0.0.
        """
        return self.float_argument(name=f'layer_thickness_{i + 1}',
                                   description=f'Layer {i + 1} thickness', value=0.0, min_value=0.0)
