"""AdvectiveTimeStepTool class."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules

# 2. Third party modules

# 3. Aquaveo modules
from xms.tool_core import ALLOW_ONLY_POINT_MAPPED, ALLOW_ONLY_VECTORS, IoDirection, Tool

# 4. Local modules
from xms.tool.algorithms.datasets.advective_time_step import advective_time_step


# Positional indices into the tool's argument list, in the same order initial_arguments() builds it.
ARG_INPUT_DATASET, ARG_INPUT_COURANT_NUMBER, ARG_OUTPUT_DATASET = range(3)


class AdvectiveTimeStepTool(Tool):
    """Tool to compute advective time step dataset.

    The input is a point-mapped vector dataset whose geometry must contain at least one cell. The
    computation itself is delegated to ``advective_time_step``. The resulting builder is sent back
    to XMS as the tool's output dataset.
    """

    def __init__(self):
        """Initializes the class."""
        super().__init__(name='Advective Time Step')
        # Input dataset returned by _validate_input_dataset() in validate_arguments().
        self._input_dset = None
        # Builder returned by advective_time_step() in run(); sent to XMS when not None.
        self._builder = None
        # NOTE(review): not read anywhere in this class; kept in case other code relies on it.
        self._dataset_vals = None
        # Grid associated with the input dataset, looked up in validate_arguments() and used by run().
        self._dset_grid = None
        # Courant number parsed from the tool arguments in run().
        self._courant_number = 1.0

    def initial_arguments(self):
        """Get initial arguments for tool.

        Must override.

        Returns:
            (list): A list of the initial tool arguments, ordered to match the ARG_* index constants.
        """
        arguments = [
            self.dataset_argument(name='in_dataset', description='Input dataset',
                                  filters=[ALLOW_ONLY_VECTORS, ALLOW_ONLY_POINT_MAPPED]),
            self.float_argument(name='courant_number', description='Use courant number', value=1.0,
                                min_value=0.0),
            self.dataset_argument(name='output_dataset', description='Advective time step dataset',
                                  value='', io_direction=IoDirection.OUTPUT),
        ]
        self.enable_arguments(arguments)
        return arguments

    def validate_arguments(self, arguments):
        """Called to determine if arguments are valid.

        Also caches the input dataset and its grid on the instance for use by run().

        Args:
            arguments (list): The tool arguments.

        Returns:
            (dict): Dictionary of errors for arguments, keyed by argument name.
        """
        errors = {}
        # Drop any grid cached by an earlier validation so an invalid input cannot leave stale state behind.
        self._dset_grid = None
        dataset_arg = arguments[ARG_INPUT_DATASET]
        self._input_dset = self._validate_input_dataset(dataset_arg, errors)
        if self._input_dset is not None:
            self._dset_grid = self.get_input_dataset_grid(dataset_arg.value)
            # A missing grid is reported the same way as a grid without cells.
            has_cells = self._dset_grid is not None and self._dset_grid.ugrid.cell_count > 0
            if not has_cells:
                errors[dataset_arg.name] = 'Grid has no cells.'
        return errors

    def enable_arguments(self, arguments):
        """Called to show/hide arguments, change argument values and add new arguments.

        This tool has no dynamic arguments, so there is nothing to do here.

        Args:
            arguments(list): The tool arguments.
        """

    def run(self, arguments):
        """Override to run the tool.

        Relies on validate_arguments() having cached the input dataset and its grid.

        Args:
            arguments (list): The tool arguments.
        """
        self._courant_number = float(arguments[ARG_INPUT_COURANT_NUMBER].text_value)

        # The grid was already checked for cells in validate_arguments().
        dset_ugrid = self._dset_grid.ugrid

        # Calculate the advective time step
        output_name = arguments[ARG_OUTPUT_DATASET].text_value
        self._builder = advective_time_step(output_name, self._input_dset, dset_ugrid,
                                            self._courant_number, self.logger)

        # Write out the dataset
        self._add_output_datasets()

    def _add_output_datasets(self):
        """Add datasets created by the tool to be sent back to XMS.

        Does nothing beyond logging when run() produced no builder.
        """
        self.logger.info('Adding output dataset...')
        if self._builder is not None:
            # Send the dataset back to XMS
            self.set_output_dataset(self._builder)
