"""RasterDifferenceTool class."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
import os

# 2. Third party modules
import numpy as np

# 3. Aquaveo modules
from xms.gdal.rasters import raster_utils as ru, RasterReproject
from xms.gdal.rasters import RasterInput, RasterOutput
from xms.gdal.utilities import gdal_utils as gu
from xms.gdal.utilities import gdal_wrappers as gw
from xms.tool_core import IoDirection, Tool

# 4. Local modules

# Positions of each tool argument within the list built by RasterDifferenceTool.initial_arguments():
# base raster first, secondary raster second, output raster last.
ARG_INPUT_BASE_RASTER, ARG_INPUT_SECONDARY_RASTER, ARG_OUTPUT_RASTER = range(3)


class RasterDifferenceTool(Tool):
    """Tool to diff two rasters, writing out a new raster containing the differences.

    The secondary raster is warped onto the base raster's projection, cell size, and bounds, and the output
    holds ``base - secondary`` for every cell where both rasters have data.
    """

    def __init__(self):
        """Initializes the class."""
        super().__init__(name='Raster Difference')

    def initial_arguments(self):
        """Get initial arguments for tool.

        Must override.

        Returns:
            (list): A list of the initial tool arguments, ordered to match the ``ARG_*`` index constants
            (base raster, secondary raster, output raster).
        """
        arguments = [
            self.raster_argument(name='base_raster', description='Input base raster'),
            self.raster_argument(name='secondary_raster', description='Input secondary raster'),
            self.raster_argument(name='output_raster', description='Output raster', io_direction=IoDirection.OUTPUT)
        ]
        return arguments

    def run(self, arguments):
        """Override to run the tool.

        Reads both input rasters, writes the difference raster, and on success reprojects it using the tool's
        default projection and vertical settings before registering it as the tool output.

        Args:
            arguments (list): The tool arguments.
        """
        output_arg = arguments[ARG_OUTPUT_RASTER]
        base_raster = self.get_input_raster(arguments[ARG_INPUT_BASE_RASTER].value)
        secondary_raster = self.get_input_raster(arguments[ARG_INPUT_SECONDARY_RASTER].value)
        out_path = self.get_output_raster(output_arg.value)

        if self._create_diff_raster(base_raster, secondary_raster, out_path):
            # NOTE(review): get_output_raster() is called a second time to obtain the reprojection destination;
            # whether it returns the same path as above is not visible from this file - confirm.
            out_path = ru.reproject_raster(out_path, self.get_output_raster(output_arg.value),
                                           self.default_wkt, self.vertical_datum, self.vertical_units)
            self.set_output_raster_file(out_path, output_arg.text_value)

    @staticmethod
    def _data_mask(values, nodata):
        """Build a boolean mask that is True wherever ``values`` holds real (non-nodata) data.

        Args:
            values: Array of raster cell values (presumably a numpy array, as produced by get_raster_values()).
            nodata: The raster's nodata value, or None if the raster does not define one.

        Returns:
            (numpy.ndarray): Boolean array shaped like ``values``; False marks nodata cells.
        """
        if nodata is None:
            # Without a nodata value every cell counts as data
            return np.ones(np.shape(values), dtype=bool)
        try:
            nodata_is_nan = bool(np.isnan(nodata))
        except TypeError:
            # Non-numeric nodata cannot be NaN; fall through to the plain comparison
            nodata_is_nan = False
        if nodata_is_nan:
            # NaN never compares equal to itself, so "!=" cannot be used to find NaN nodata cells
            return ~np.isnan(values)
        return values != nodata

    def _create_diff_raster(self, base_raster, secondary_raster, out_path):
        """Create a new raster containing the difference between two rasters.

        The secondary raster is copied and warped so that it matches the base raster's projection, pixel size,
        and bounds. Cells where either raster is nodata are written as the base raster's nodata value (when the
        base raster defines one); every other cell is ``base - secondary``.

        Args:
            base_raster (RasterInput):  The base raster.
            secondary_raster (RasterInput):  The secondary raster.
            out_path (str): The path of the raster to create.

        Returns:
            bool:  True if the output raster file exists after writing, else False (including when the
            secondary raster could not be copied).
        """
        # Only a valid base projection is passed on as the warp target and output projection
        base_wkt = base_raster.wkt
        if not gu.valid_wkt(base_wkt):
            base_wkt = None

        # Make a temporary in-memory copy of the secondary raster, then warp it so that all rasters have the
        # same projection, pixel dimensions, and bounds
        # NOTE(review): neither /vsimem file is removed in this method - confirm whether they should be unlinked.
        secondary_copy_name = '/vsimem/secondary_diff.tif'
        self.logger.info('Copying secondary raster...')
        if not ru.copy_raster_from_raster_input(secondary_raster, secondary_copy_name):
            return False

        bounds = base_raster.get_raster_bounds()
        secondary_temp_raster_file = '/vsimem/secondary_temp_raster.tif'
        self.logger.info('Warping secondary raster to a temp file so it is in the same projection as the base...')
        gw.gdal_warp(secondary_copy_name, secondary_temp_raster_file, dstSRS=base_wkt, xRes=base_raster.pixel_width,
                     yRes=base_raster.pixel_height, resampleAlg=RasterReproject.GRA_Cubic,
                     outputType=secondary_raster.data_type,
                     outputBounds=(bounds[0][0], bounds[0][1], bounds[1][0], bounds[1][1]), options=['-novshift'])

        # Get the values from the rasters and write an output raster
        base_values = base_raster.get_raster_values()
        secondary_ri = RasterInput(secondary_temp_raster_file)
        secondary_values = secondary_ri.get_raster_values()
        self.logger.info('Calculating the raster difference...')
        base_nodata = base_raster.nodata_value
        secondary_nodata = secondary_raster.nodata_value
        # Compare against None explicitly: a nodata value of 0 is legitimate and must still be masked
        if base_nodata is not None:
            has_data = self._data_mask(base_values, base_nodata) & self._data_mask(secondary_values, secondary_nodata)
            diff_values = np.where(has_data, base_values - secondary_values, base_nodata)
        else:
            # The output has no nodata value to flag cells with, so every cell is differenced
            diff_values = base_values - secondary_values

        self.logger.info('Writing the output...')
        output = RasterOutput(xorigin=base_raster.xorigin, yorigin=base_raster.yorigin,
                              width=base_raster.resolution[0], height=base_raster.resolution[1],
                              pixel_width=base_raster.pixel_width, pixel_height=base_raster.pixel_height,
                              nodata_value=base_nodata, wkt=base_wkt)
        output.write_raster(out_path, diff_values)
        # Clear the raster's memory
        output = None

        rval = os.path.isfile(out_path)
        if rval:
            # Calculate statistics
            out_ds = RasterInput(out_path)
            self.logger.info('Computing statistics for the difference raster...')
            out_ds.force_compute_statistics()
        return rval
