"""ReprojectRasterTool class."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules

# 2. Third party modules

# 3. Aquaveo modules
from xms.gdal.rasters import raster_utils as ru
from xms.gdal.utilities import gdal_utils as gu
from xms.tool_core import IoDirection, Tool

# 4. Local modules

# Positions of each tool argument in the list built by ReprojectRasterTool.initial_arguments();
# keep this order in sync with that list.
ARG_INPUT_RASTER, ARG_INPUT_PROJECTION_FILE, ARG_OUTPUT_RASTER = range(3)


class ReprojectRasterTool(Tool):
    """Tool to reproject a raster into the coordinate system defined by a projection (.prj) file.

    The reprojected data is written out as a new raster. When the selected projection is a compound
    (horizontal + vertical) coordinate system, the vertical datum and units are taken from the projection
    file; otherwise the tool's current vertical datum and units are used.
    """

    def __init__(self):
        """Initializes the class."""
        super().__init__(name='Reproject Raster')

    def initial_arguments(self):
        """Get initial arguments for tool.

        Must override. The order of the returned list must match the module-level ARG_* index constants.

        Returns:
            (list): A list of the initial tool arguments.
        """
        arguments = [
            self.raster_argument(name='input_raster', description='Input raster'),
            self.file_argument(name='input_projection_file', description='Select .prj file (WKT or PROJ.4 format)',
                               io_direction=IoDirection.INPUT),
            self.raster_argument(name='output_raster', description='Output raster', io_direction=IoDirection.OUTPUT)
        ]
        return arguments

    def validate_arguments(self, arguments):
        """Called to determine if arguments are valid.

        Only the projection file argument is checked here. It must be selected, it must be readable, and its
        contents must be accepted by gu.valid_wkt.

        Args:
            arguments (list): The tool arguments.

        Returns:
            (dict): Dictionary of errors for arguments, keyed by argument name (empty when all are valid).
        """
        errors = {}
        projection_arg = arguments[ARG_INPUT_PROJECTION_FILE]
        projection_file = projection_arg.value
        if not projection_file:
            msg = 'No projection file has been selected. Please select a valid projection (.prj) file.'
            errors[projection_arg.name] = msg
            return errors

        # The path is user supplied; a missing/unreadable file should surface as a validation error on this
        # argument rather than an exception escaping validation.
        try:
            wkt = gu.read_projection_file(projection_file)
        except OSError as err:
            errors[projection_arg.name] = f'Unable to read the projection file: {err}'
            return errors

        if not gu.valid_wkt(wkt):
            msg = 'The projection file is not valid. Check to make sure this is a valid projection (.prj) file.'
            errors[projection_arg.name] = msg
        return errors

    def run(self, arguments):
        """Reproject the input raster into the projection read from the selected .prj file.

        Args:
            arguments (list): The tool arguments.
        """
        input_raster = self.get_input_raster_file(arguments[ARG_INPUT_RASTER].value)
        projection_file = arguments[ARG_INPUT_PROJECTION_FILE].value
        output_raster = self.get_output_raster(arguments[ARG_OUTPUT_RASTER].value)
        file_wkt = gu.read_projection_file(projection_file)

        # Default to the tool's current vertical settings (assumed to come from the Tool base class).
        vertical_datum = self.vertical_datum
        vertical_units = self.vertical_units
        # A compound SRS carries its own vertical component, so it overrides the tool's vertical settings.
        sr = gu.wkt_to_sr(file_wkt)
        if sr.IsCompound():
            vertical_datum = gu.get_vert_datum_from_wkt(file_wkt)
            vertical_units = gu.get_vert_unit_from_wkt(file_wkt)

        reprojected_raster = ru.reproject_raster(input_raster, output_raster, file_wkt, vertical_datum,
                                                 vertical_units)
        self.set_output_raster_file(reprojected_raster, arguments[ARG_OUTPUT_RASTER].text_value)
