"""GravityWavesCourantNumberTool class."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules

# 2. Third party modules

# 3. Aquaveo modules
from xms.tool_core import ALLOW_ONLY_POINT_MAPPED, ALLOW_ONLY_SCALARS, IoDirection, Tool

# 4. Local modules
from xms.tool.algorithms.datasets.coastal_dataset_calcs import CoastDatasetCalc

# Positions of each argument in the list built by GravityWavesCourantNumberTool.initial_arguments().
(
    ARG_INPUT_DATASET,  # 0: scalar, point-mapped input dataset
    ARG_INPUT_DATASET_IS_DEPTH,  # 1: flag - input dataset is depth
    ARG_INPUT_GRAVITY_CONST,  # 2: gravity constant
    ARG_INPUT_TIMESTEP_SECONDS,  # 3: time step (seconds)
    ARG_INPUT_MIN_DEPTH,  # 4: minimum depth
    ARG_OUTPUT_DATASET,  # 5: output gravity waves courant number dataset
) = range(6)


class GravityWavesCourantNumberTool(Tool):
    """Tool that writes a gravity waves courant number dataset computed from a scalar, point-mapped dataset."""

    def __init__(self):
        """Name the tool and reset the state shared between validation and execution."""
        super().__init__(name='Gravity Waves Courant Number')
        # Assigned by validate_arguments(); read by run() and the output builder setup.
        self._input_dset = None
        # Output dataset writer; created at the start of run().
        self._builder = None

    def initial_arguments(self):
        """Create the tool's arguments.

        Must override.

        Returns:
            (list): The initial tool arguments, ordered to match the ARG_* indices.
        """
        scalar_point_filter = [ALLOW_ONLY_SCALARS, ALLOW_ONLY_POINT_MAPPED]
        in_dataset_arg = self.dataset_argument(name='in_dataset', description='Input dataset',
                                               filters=scalar_point_filter)
        is_depth_arg = self.bool_argument(name='dataset_is_depth', description='Input dataset is depth', value=False)
        gravity_arg = self.float_argument(name='gravity_const', description='Gravity (m/s² or ft/s²)',
                                          value=9.80665, min_value=0.0)
        timestep_arg = self.float_argument(name='timestep_seconds', description='Use time step (seconds)',
                                           value=1.0, min_value=0.0)
        min_depth_arg = self.float_argument(name='min_depth', description='Minimum depth (m or ft)',
                                            value=0.1, min_value=0.0)
        out_dataset_arg = self.dataset_argument(name='gw_dataset', description='Gravity waves courant number dataset',
                                                value='', io_direction=IoDirection.OUTPUT)
        arguments = [in_dataset_arg, is_depth_arg, gravity_arg, timestep_arg, min_depth_arg, out_dataset_arg]
        self.enable_arguments(arguments)
        return arguments

    def validate_arguments(self, arguments):
        """Check the tool arguments and report any problems.

        Also stores the validated input dataset on the tool for later use by run().

        Args:
            arguments (list): The tool arguments.

        Returns:
            (dict): Maps argument name to error message for each invalid argument.
        """
        problems = {}
        self._input_dset = self._validate_input_dataset(arguments[ARG_INPUT_DATASET], problems)
        positive_msg = 'Value must be greater than 0.0.'
        # Gravity, time step and minimum depth must all be strictly positive.
        for index in (ARG_INPUT_GRAVITY_CONST, ARG_INPUT_TIMESTEP_SECONDS, ARG_INPUT_MIN_DEPTH):
            argument = arguments[index]
            if argument.value <= 0.0:
                problems[argument.name] = positive_msg
        return problems

    def enable_arguments(self, arguments):
        """Show/hide arguments, change argument values and add new arguments.

        This tool has no dependent arguments, so nothing is changed.

        Args:
            arguments (list): The tool arguments.
        """

    def run(self, arguments):
        """Compute the gravity waves courant number dataset and send it back to XMS.

        Args:
            arguments (list): The tool arguments.
        """
        self._setup_output_dataset_builder(arguments)
        dataset_is_depth = arguments[ARG_INPUT_DATASET_IS_DEPTH].value
        # The numeric inputs are parsed from the arguments' text values.
        gravity, timestep, depth_floor = (
            float(arguments[index].text_value)
            for index in (ARG_INPUT_GRAVITY_CONST, ARG_INPUT_TIMESTEP_SECONDS, ARG_INPUT_MIN_DEPTH)
        )
        grid = self.get_input_dataset_grid(arguments[ARG_INPUT_DATASET].text_value)
        calculator = CoastDatasetCalc(
            logger=self.logger,
            dataset_is_depth=dataset_is_depth,
            gravity_const=gravity,
            timestep_seconds=timestep,
            min_depth=depth_floor,
            dataset=self._input_dset,
            co_grid=grid,
            wkt=self.default_wkt,
            dataset_builder=self._builder,
        )
        calculator.gravity_waves_courant()
        self._add_output_datasets()

    def _setup_output_dataset_builder(self, arguments):
        """Create the writer for the output dataset, mirroring the input dataset's metadata.

        Args:
            arguments (list): The tool arguments.
        """
        source = self._input_dset
        self._builder = self.get_output_dataset_writer(
            name=arguments[ARG_OUTPUT_DATASET].text_value,
            geom_uuid=source.geom_uuid,
            ref_time=source.ref_time,
            time_units=source.time_units,
            null_value=source.null_value,
        )

    def _add_output_datasets(self):
        """Finish the output dataset (if one was built) and hand it back to XMS."""
        self.logger.info('Adding output dataset...')
        if self._builder is None:
            return
        self.logger.info('Writing output gravity waves courant number dataset to XMDF file...')
        self._builder.appending_finished()
        # Send the finished dataset back to XMS.
        self.set_output_dataset(self._builder)
