"""Data2dFromData3dTool class."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules

# 2. Third party modules

# 3. Aquaveo modules
from xms.tool_core import ALLOW_ONLY_CELL_MAPPED, ALLOW_ONLY_SCALARS, IoDirection, Tool

# 4. Local modules
from xms.tool.algorithms.ugrids.data_2d_from_data_3d import Data2dFromData3d

# Positional indices of the tool arguments, matching the order of the list built in
# Data2dFromData3dTool.initial_arguments().
(
    ARG_INPUT_DATASET,
    ARG_INPUT_NAME,
    ARG_INPUT_HIGHEST_ACTIVE,
    ARG_INPUT_AVERAGE,
    ARG_INPUT_MAX,
    ARG_INPUT_MIN,
    ARG_INPUT_LAYERS,
    ARG_UGRID,
) = range(8)

# NOTE(review): not referenced in this module; presumably an initial/placeholder cell-layer
# value used by importers of this module — confirm before changing.
INIT_CELL_LAYER = -9999999


class Data2dFromData3dTool(Tool):
    """Tool to convert a 3D UGrid dataset to a 2D UGrid and datasets.

    The input must be a scalar, cell-mapped dataset whose UGrid has only 3D, vertically prismatic
    cells with layer numbers of at least 1 assigned. The 2D UGrid and the selected column values
    are computed by ``Data2dFromData3d``.
    """

    def __init__(self):
        """Initializes the class."""
        super().__init__(name='2D Data from 3D Data')
        self._in_ds = None  # input dataset, fetched in validate_arguments()
        self._args = None  # tool arguments from the latest validate_arguments()/run() call

    def initial_arguments(self):
        """Get initial arguments for tool.

        Must override. The order of this list must match the ``ARG_*`` index constants.

        Returns:
            (list): A list of the initial tool arguments.
        """
        arguments = [
            self.dataset_argument(name='dataset', description='Cell dataset',
                                  filters=[ALLOW_ONLY_SCALARS, ALLOW_ONLY_CELL_MAPPED]),
            self.string_argument(name='ugrid_name', description='Name of the output ugrid', value='',
                                 optional=True),
            self.bool_argument(name='highest_active', description='Compute highest active value in column',
                               value=True),
            self.bool_argument(name='average_value', description='Compute average value in column', value=False),
            self.bool_argument(name='max_value', description='Compute maximum value in column', value=False),
            self.bool_argument(name='min_value', description='Compute minimum value in column', value=False),
            self.bool_argument(name='layers', description='Compute value for each layer in column', value=False),
            self.grid_argument(name='ugrid', description='The ugrid', hide=True, optional=True,
                               io_direction=IoDirection.OUTPUT)
        ]
        return arguments

    def validate_arguments(self, arguments):
        """Called to determine if arguments are valid.

        Verifies that the UGrid of the input dataset can be read and is a layered grid of 3D,
        vertically prismatic cells.

        Args:
            arguments (list): The tool arguments.

        Returns:
            (dict): Dictionary of errors for arguments, keyed by argument name; empty when valid.
        """
        self._args = arguments
        dataset_arg = arguments[ARG_INPUT_DATASET]
        self._in_ds = self.get_input_dataset(dataset_arg.value)

        errors = {}
        msg = self._grid_error(self.get_input_dataset_grid(dataset_arg.text_value))
        if msg:
            errors[dataset_arg.name] = msg
        return errors

    @staticmethod
    def _grid_error(co_grid):
        """Return an error message if ``co_grid`` cannot be processed by this tool, else None.

        Args:
            co_grid: The UGrid of the input dataset, or a falsy value if it could not be read.

        Returns:
            (str | None): The first failed check's message, or None if every check passes.
        """
        if not co_grid:
            return 'Unable to read UGrid from dataset.'
        if not co_grid.check_all_cells_3d():
            return 'All cells in the UGrid must be 3D.'
        if not co_grid.check_all_cells_vertically_prismatic():
            return 'All cells in the UGrid must be vertically prismatic. Aborting.'
        cell_layers = co_grid.cell_layers
        # An empty layer sequence is treated like "no layers": min() would raise ValueError on it.
        # len() is used instead of truthiness in case cell_layers is an array.
        if cell_layers is None or len(cell_layers) == 0:
            return 'Layers must be assigned to the cells of the 3D UGrid. Aborting.'
        if min(cell_layers) < 1:
            return 'Invalid layers assigned to the cells of the 3D UGrid. Aborting.'
        return None

    def run(self, arguments):
        """Override to run the tool.

        Builds the 2D UGrid and the selected column datasets from the input 3D dataset and
        registers them as tool outputs.

        Args:
            arguments (list): The tool arguments.
        """
        # Use the arguments actually passed in instead of relying on validate_arguments() having
        # stored them; _set_outputs() reads self._args, so keep it in sync.
        self._args = arguments
        dataset_arg = arguments[ARG_INPUT_DATASET]

        # Get Algorithm Arguments
        input_dataset = self.get_input_dataset(dataset_arg.value)
        co_grid = self.get_input_dataset_grid(dataset_arg.text_value)
        compute_highest_value = arguments[ARG_INPUT_HIGHEST_ACTIVE].value
        compute_average_value = arguments[ARG_INPUT_AVERAGE].value
        compute_max_value = arguments[ARG_INPUT_MAX].value
        compute_min_value = arguments[ARG_INPUT_MIN].value
        compute_each_layer = arguments[ARG_INPUT_LAYERS].value
        output_ugrid_name = arguments[ARG_INPUT_NAME].value

        # Run Algorithm
        algorithm = Data2dFromData3d(input_dataset, co_grid, self.logger, compute_highest_value,
                                     compute_average_value, compute_max_value, compute_min_value,
                                     compute_each_layer, output_ugrid_name)
        ugrid_name, ugrid, datasets = algorithm.data_2d_from_data_3d()

        self._set_outputs(ugrid_name, ugrid, datasets)

    def _set_outputs(self, ugrid_name, ugrid, datasets):
        """Set outputs from the tool.

        Args:
            ugrid_name: Name assigned to the output UGrid argument.
            ugrid: The output 2D UGrid.
            datasets: Iterable of output datasets; each is passed to set_output_dataset().
        """
        self._args[ARG_UGRID].value = ugrid_name
        self.set_output_grid(ugrid, self._args[ARG_UGRID])
        for ds in datasets:
            self.set_output_dataset(ds)
