"""WseDepthRasterFromDatasetTool class."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules

# 2. Third party modules

# 3. Aquaveo modules
from xms.tool_core import ALLOW_ONLY_POINT_MAPPED, ALLOW_ONLY_SCALARS, Argument, IoDirection, Tool

# 4. Local modules
from xms.tool.algorithms.rasters.create_wse_depth_raster import CreateWseDepthRaster

# Positions of each argument in the list returned by WseDepthRasterFromDatasetTool.initial_arguments().
# The unpacking order below must stay in sync with that list.
(
    ARG_INPUT_ELEVATION_RASTER,
    ARG_INPUT_DATASET,
    ARG_INPUT_TIMESTEP,
    ARG_NUM_PIXELS_EXTRAPOLATE,
    ARG_OUTPUT_WSE_RASTER,
    ARG_OUTPUT_DEPTH_RASTER,
) = range(6)


class WseDepthRasterFromDatasetTool(Tool):
    """Tool to make WSE and depth rasters from a WSE dataset.

    Inputs are an elevation raster and a scalar, point-mapped WSE dataset at a chosen timestep. The outputs are a WSE
    raster and a depth raster. The actual raster computation is delegated to CreateWseDepthRaster.
    """

    def __init__(self):
        """Initializes the class."""
        super().__init__(name='WSE/Depth Raster from Dataset')

    def initial_arguments(self):
        """Get initial arguments for tool.

        Must override. The order of this list must match the module-level ARG_* index constants.

        Returns:
            (list): A list of the initial tool arguments.
        """
        arguments = [
            self.raster_argument(name='elevation_raster', description='Elevation Raster'),
            self.dataset_argument(name='dataset', description='WSE Dataset',
                                  filters=[ALLOW_ONLY_SCALARS, ALLOW_ONLY_POINT_MAPPED]),
            self.timestep_argument(name='time_step', description='Timestep', value=Argument.NONE_SELECTED),
            self.integer_argument(name='num_pixels_extrapolate', description='Number of Pixels to Extrapolate (Use 0 '
                                                                             'for no extrapolation)', value=0),
            self.raster_argument(name='output_wse_raster', description='Output WSE Raster',
                                 io_direction=IoDirection.OUTPUT),
            self.raster_argument(name='output_depth_raster', description='Output Depth Raster',
                                 io_direction=IoDirection.OUTPUT)
        ]
        return arguments

    def enable_arguments(self, arguments):
        """Called to show/hide arguments, change argument values and add new arguments.

        The timestep argument is only relevant once a dataset is selected, so it is hidden until then.

        Args:
            arguments(list): The tool arguments.
        """
        dataset_arg = arguments[ARG_INPUT_DATASET]
        if dataset_arg.value:
            arguments[ARG_INPUT_TIMESTEP].enable_timestep(dataset_arg)
        else:
            arguments[ARG_INPUT_TIMESTEP].hide = True

    def validate_arguments(self, arguments):
        """Called to determine if arguments are valid.

        Args:
            arguments (list): The tool arguments.

        Returns:
            (dict): Dictionary of errors for arguments, keyed by argument name.
        """
        errors = {}
        arguments[ARG_INPUT_TIMESTEP].validate_timestep(arguments[ARG_INPUT_DATASET], errors)
        if self.default_wkt is None:
            # run() passes self.default_wkt to CreateWseDepthRaster, so it must be defined.
            errors[arguments[ARG_OUTPUT_WSE_RASTER].name] = 'Must specify an output coordinate system.'

        elevation_arg = arguments[ARG_INPUT_ELEVATION_RASTER]
        # run() treats the elevation raster as optional (it passes None when nothing is selected). Only inspect its
        # band count when one is selected and resolves to a raster; otherwise gdal_raster would be read from nothing.
        if elevation_arg.value:
            input_raster = self.get_input_raster(elevation_arg.value)
            # NOTE(review): '> 2' still accepts two-band rasters even though the message says "more than one band";
            # confirm this threshold is intentional (e.g. to allow a data band plus a mask band).
            if input_raster is not None and input_raster.gdal_raster.RasterCount > 2:
                errors[elevation_arg.name] = (
                    'This tool should only be used on rasters with a single dataset and is not designed to '
                    'be used with RGB images or rasters with more than one band (dataset).'
                )
        return errors

    def run(self, arguments):
        """Override to run the tool.

        Resolves the input raster, dataset, timestep and output paths, then builds both output rasters with
        CreateWseDepthRaster and registers the resulting files with the tool outputs.

        Args:
            arguments (list): The tool arguments.
        """
        elevation_arg = arguments[ARG_INPUT_ELEVATION_RASTER]
        dataset_arg = arguments[ARG_INPUT_DATASET]
        wse_arg = arguments[ARG_OUTPUT_WSE_RASTER]
        depth_arg = arguments[ARG_OUTPUT_DEPTH_RASTER]

        # The elevation raster is optional here; None is passed through when it is not selected.
        elevation_raster = self.get_input_raster(elevation_arg.text_value) if elevation_arg.value else None
        out_wse_path = self.get_output_raster(wse_arg.value)
        out_depth_path = self.get_output_raster(depth_arg.value)
        co_grid = self.get_input_dataset_grid(dataset_arg.text_value)
        wse_dataset = self.get_input_dataset(dataset_arg.text_value)
        timestep = arguments[ARG_INPUT_TIMESTEP].get_timestep(dataset_arg)

        creator = CreateWseDepthRaster(elevation_raster, co_grid, wse_dataset, timestep,
                                       arguments[ARG_NUM_PIXELS_EXTRAPOLATE].value, out_wse_path,
                                       out_depth_path, self.default_wkt, self.vertical_datum,
                                       self.vertical_units, self.logger)
        creator.create()

        self.set_output_raster_file(out_wse_path, wse_arg.value)
        self.set_output_raster_file(out_depth_path, depth_arg.value)
