"""MergeElevationRastersTool class."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
import math
import os
import shutil

# 2. Third party modules

# 3. Aquaveo modules
from xms.gdal.rasters import raster_utils as ru
from xms.gdal.rasters import RasterInput
from xms.gdal.utilities import gdal_utils as gu, GdalRunner
from xms.tool_core import IoDirection, Tool

# 4. Local modules

# Positions of the leading entries in the tool's argument list. They line up with the order that
# MergeElevationRastersTool.initial_arguments builds: raster_1 (0), blend_distance_1 (1), raster_2 (2).
# NOTE(review): no use of these names is visible in this part of the file - confirm they are still needed.
ARG_FIRST_PRIMARY = 0
ARG_FIRST_BLEND_DISTANCE = 1
ARG_FIRST_SECONDARY = 2


class MergeElevationRastersTool(Tool):
    """Tool to merge two rasters."""

    def __init__(self):
        """Register the tool under its display name and reset the file counter."""
        display_name = 'Merge Elevation Rasters'
        super().__init__(name=display_name)
        # Starts at zero; nothing in the code shown here reads or updates it.
        self._file_count = 0

    def initial_arguments(self):
        """Build the argument list the tool starts out with.

        Must override.

        Returns:
            (list): Raster 1, the blend distance for raster 1, raster 2 and finally the merged output raster.
        """
        first_raster = self.raster_argument(name='raster_1', description='Raster 1')
        first_blend = self.float_argument(name='blend_distance_1', description='Blend distance', min_value=0.0)
        second_raster = self.raster_argument(name='raster_2', description='Raster 2')
        merged_output = self.raster_argument(name='merged_raster', description='Merged raster',
                                             io_direction=IoDirection.OUTPUT)
        # The output raster always stays last; additional inputs get inserted in front of it later.
        return [first_raster, first_blend, second_raster, merged_output]

    @staticmethod
    def input_raster_count(arguments):
        """Get the number of input raster arguments.

        Doesn't check to see if argument is set to a value.

        Args:
            arguments(list): The tool arguments.

        Returns:
            (int): The number of input raster arguments.
            Not necessarily set to a value.
        """
        return len(arguments) // 2

    def add_input_raster(self, arguments):
        """Called to show/hide arguments, change argument values and add new arguments.

        Args:
            arguments(list): The tool arguments.

        Returns:
            (int): The number of input raster arguments.
            Not necessarily set to a value.
        """
        raster_count = self.input_raster_count(arguments)

        # add a blend argument for last raster
        default_blend_value = 0.0
        if len(arguments) > 3:
            default_blend_value = arguments[-3].value
        blend_argument = self.float_argument(name=f'blend_distance_{raster_count}',
                                             description=f'Blend distance {raster_count} (in primary raster units)',
                                             value=default_blend_value, min_value=0.0)
        arguments.insert(-1, blend_argument)

        # add new last raster argument
        raster_count += 1
        raster_argument = self.raster_argument(name=f'raster_{raster_count}', description=f'Raster {raster_count}',
                                               optional=True, value=None)
        arguments.insert(-1, raster_argument)

    def enable_arguments(self, arguments):
        """Called to show/hide arguments, change argument values and add new arguments.

        Args:
            arguments(list): The tool arguments.
        """
        last_raster = arguments[-2]
        if last_raster.value is not None and last_raster.value != '':
            self.add_input_raster(arguments)

    def validate_arguments(self, arguments):
        """Called to determine if arguments are valid.

        Args:
            arguments (list): The tool arguments.

        Returns:
            (dict): Dictionary of errors for arguments.
        """
        errors = {}
        rasters, secondary_raster = self._get_input_rasters(arguments)
        rasters.append((secondary_raster, 0.0))
        for idx, (raster_name, _) in enumerate(rasters):
            raster = RasterInput(raster_name)
            if raster.gdal_raster.RasterCount > 2:
                errors[arguments[idx * 2].name] = 'This tool should only be used on rasters with a single dataset ' \
                                                  'and is not designed to be used with RGB images or rasters with ' \
                                                  'more than one band (dataset).'
        return errors

    def validate_from_history(self, arguments):
        """Called to determine if arguments are valid from history.

        Args:
            arguments (list): The tool arguments.

        Returns:
            (bool): True if no errors, False otherwise.
        """
        # Make sure there is an even number of arguments
        if (len(arguments) % 2) != 0 or len(arguments) < 4:
            return False
        for i in range(self.input_raster_count(arguments) - 1):
            # Check raster input argument
            if arguments[i * 2].io_direction != IoDirection.INPUT:
                return False
            if arguments[i * 2].type != 'raster':
                return False
            # Check blend distance argument
            if arguments[i * 2 + 1].io_direction != IoDirection.INPUT:
                return False
            if arguments[i * 2 + 1].type != 'float':
                return False
        # Check the last raster input
        if arguments[-2].io_direction != IoDirection.INPUT:
            return False
        if arguments[-2].type != 'raster':
            return False
        # Check the raster output
        if arguments[-1].io_direction != IoDirection.OUTPUT:
            return False
        if arguments[-1].type != 'raster':
            return False
        return True

    def run(self, arguments):
        """Merge the selected rasters and register the result as the tool's output raster.

        The last selected raster is the starting base. The remaining rasters are blended over the running result
        from last to first, so raster 1 is applied last. The merged raster is then reprojected and set as the
        output. The GDAL runner's temporary folder is removed whether the merge succeeds or raises.

        Args:
            arguments (list): The tool arguments.
        """
        gdal_runner = GdalRunner()
        try:
            primaries, secondary_file = self._get_input_rasters(arguments)

            # Each merge result becomes the secondary (underlying) raster of the next merge.
            for primary_file, blend_distance in reversed(primaries):
                secondary_file = merge_rasters_with_blend(self, gdal_runner, primary_file, blend_distance,
                                                          secondary_file)

            # Reproject raster to the display projection
            output_raster = ru.reproject_raster(secondary_file, self.get_output_raster(arguments[-1].value),
                                                self.default_wkt, self.vertical_datum, self.vertical_units)
            self.set_output_raster_file(output_raster, arguments[-1].text_value)
        except BaseException:
            # Do not leave the generated files behind when the merge fails. Cleanup is best effort here so a
            # cleanup problem cannot hide the error that is being propagated.
            shutil.rmtree(gdal_runner.temp_file_path, ignore_errors=True)
            raise

        # Remove the folder containing all the generated files
        shutil.rmtree(gdal_runner.temp_file_path)

    def _get_input_rasters(self, arguments):
        """Get the input rasters and blend distances for primaries.

        Args:
            arguments list[str]: The arguments.

        Returns:
            tuple[list, str]: A tuple of primary rasters and the final secondary raster. The primary list of contains
            a 2 tuple with raster file and blend distance.
        """
        primaries = []
        for i in range(0, len(arguments), 2):
            if arguments[i].text_value != '':
                raster = arguments[i].value
                blend_distance = arguments[i + 1].value
                primaries.append((self.get_input_raster_file(raster), blend_distance))
        # remove last item as first secondary raster
        secondary = []
        if len(primaries):
            secondary, _ = primaries.pop()
        return primaries, secondary


def merge_rasters_with_blend(tool, gdal_runner, primary_file, blend_distance, secondary_file):
    """Overlay a primary raster on a secondary raster, tapering the primary near its nodata cells.

    Steps, in order:

    1. Resample both rasters onto one grid that covers their combined extent, using the primary raster's cell size
       (the secondary is also warped to the primary's projection when the primary WKT is valid).
    2. Multiply the secondary by a vertical unit multiplier when the two rasters' vertical units differ.
    3. Build a proximity raster measured from the primary's nodata cells and from a one pixel frame added around
       the primary, scale it to 8 bits and stack it on the primary as an extra band.
    4. Add an alpha band to the secondary and let gdalwarp (``-srcalpha``) lay the primary over it.

    Args:
        tool (Tool): The tool. Used to load raster files.
        gdal_runner (GdalRunner): The GDAL runner. Runs every GDAL command and supplies the temporary folder.
        primary_file (str): Path to primary raster. May have double quotes if desired.
        blend_distance (float): The distance to blend the secondary inside the primary. It is divided by the
            resampled primary's average pixel size to get a distance in pixels; when it is not greater than 0 no
            ``-maxdist`` limit is passed to gdal_proximity.
        secondary_file (str): Path to secondary raster. May have double quotes if desired.

    Returns:
        (str): The blended output raster file path.
    """
    # strip quotes if there are any
    primary_file = primary_file.strip('"')
    secondary_file = secondary_file.strip('"')

    primary_raster = tool.load_raster_file(primary_file)
    secondary_raster = tool.load_raster_file(secondary_file)
    # May be None; every nodata-related GDAL option below is only added when it is set.
    primary_no_data_value = primary_raster.nodata_value

    # Get the raster bounds into the same projection
    # The primary WKT is written to a .prj file so it can be given to gdalwarp through -t_srs further down. When
    # the WKT is not valid no file is written and -t_srs is left out.
    primary_wkt_file = None
    if gu.valid_wkt(primary_raster.wkt):
        primary_wkt_file = os.path.join(gdal_runner.temp_file_path, 'primary.prj')
        with open(primary_wkt_file, 'wt') as file:
            file.write(primary_raster.wkt)
    secondary_bounds = secondary_raster.get_raster_bounds()
    secondary_bounds = gu.transform_points_from_wkt(secondary_bounds, secondary_raster.wkt, primary_raster.wkt)
    primary_bounds = primary_raster.get_raster_bounds()

    # Make a temporary copy of the primary and secondary rasters on disk, and reproject secondary to the primary
    # raster's SRS.  Get the same origin and cell sizes for both the rasters.
    # The bounds are used as [0] = minimum corner and [1] = maximum corner; the extent below is the union of both.
    x_min = min(primary_bounds[0][0], secondary_bounds[0][0])
    y_min = min(primary_bounds[0][1], secondary_bounds[0][1])
    x_max = max(primary_bounds[1][0], secondary_bounds[1][0])
    y_max = max(primary_bounds[1][1], secondary_bounds[1][1])
    # gdal_translate -projwin takes the upper-left then the lower-right corner, hence y_max before y_min.
    args = ['-projwin', f'{x_min}', f'{y_max}', f'{x_max}', f'{y_min}', '-tr',
            f'{primary_raster.pixel_width}', f'{primary_raster.pixel_height}', '-r', 'bilinear']
    if primary_no_data_value is not None:
        args.extend(['-a_nodata', f'{primary_no_data_value}'])
    primary_template = gdal_runner.run_wrapper('gdal_translate', primary_file, 'primary_template.vrt', args)
    primary_template_raster = tool.load_raster_file(primary_template)
    # Used below as the width and height passed to -srcwin.
    # NOTE(review): presumably (columns, rows) of the resampled grid - confirm against the raster class.
    pixels = primary_template_raster.resolution
    average_pixel_size = primary_template_raster.get_average_pixel_size()

    # gdalwarp -te takes xmin ymin xmax ymax; same extent and cell size as the primary template.
    args = ['-of', 'GTiff', '-novshift', '-te', f'{x_min}', f'{y_min}', f'{x_max}', f'{y_max}', '-tr',
            f'{primary_raster.pixel_width}', f'{primary_raster.pixel_height}', '-r', 'bilinear', '-co', 'compress=LZW',
            '-co', 'BIGTIFF=YES']
    if primary_no_data_value is not None:
        args.extend(['-dstnodata', f'{primary_no_data_value}'])
    if primary_wkt_file is not None:
        args.extend(['-t_srs', primary_wkt_file])
    args.append('-overwrite')
    secondary_template = gdal_runner.run_wrapper('gdalwarp', secondary_file, 'secondary_template.tif', args)

    # Convert the vertical units if necessary.
    # The multiplier is looked up from the secondary's vertical unit to the primary's; the extra gdal_calc pass is
    # skipped when it is (close to) 1.
    vm = gu.get_vertical_multiplier(gu.get_vert_unit_from_wkt(secondary_raster.wkt),
                                    gu.get_vert_unit_from_wkt(primary_raster.wkt))
    if not math.isclose(vm, 1.0):
        args = ['-A', secondary_template, '--co="compress=LZW"', '--co="BIGTIFF=YES"', '--outfile=$OUT_FILE$',
                f'--calc=A*{vm}']
        secondary_template_corrected = gdal_runner.run('gdal_calc.py', 'secondary_template_corrected.tif', args)
    else:
        secondary_template_corrected = secondary_template

    # add 1 pixel frame to outer edges
    # The window starts one pixel before the grid and is two pixels wider and taller than it.
    args = ['-of', 'GTiff', '-srcwin', '-1', '-1', f'{pixels[0] + 2}', f'{pixels[1] + 2}', '-co', 'compress=LZW',
            '-co', 'BIGTIFF=YES']
    if primary_no_data_value is not None:
        args.extend(['-a_nodata', f'{primary_no_data_value}'])
    nodata_edges_file = gdal_runner.run_wrapper('gdal_translate', primary_template, 'nodata_edges.tif', args)

    # remove nodata so we can set to data
    # NOTE(review): increment=False presumably makes the runner reuse the existing 'nodata_edges.tif' name so that
    # gdal_edit changes the file written above in place - confirm against GdalRunner.run.
    if primary_no_data_value is not None:
        args = ['-unsetnodata', '$OUT_FILE$']
        nodata_edges_file = gdal_runner.run('gdal_edit.py', 'nodata_edges.tif', args, increment=False)

    # invert data and nodata: the calc expression yields 999 where a cell equals the primary nodata value and 0
    # everywhere else. Without a nodata value the padded file is used unchanged.
    if primary_no_data_value is not None:
        args = ['-A', nodata_edges_file, '--co="compress=LZW"', '--co="BIGTIFF=YES"', '--outfile=$OUT_FILE$',
                f'--calc=999*(A=={primary_no_data_value})', f'--NoDataValue={primary_no_data_value}']
        inverted_file = gdal_runner.run('gdal_calc.py', 'inverted.tif', args)
    else:
        inverted_file = nodata_edges_file

    # find proximity to primary nodata pixels
    # The blend distance is converted to pixels; the search is only capped with -maxdist for a positive distance.
    blend_pixels = blend_distance / average_pixel_size
    args = [inverted_file, '$OUT_FILE$', '-co', 'BIGTIFF=YES', '-co', 'compress=LZW']
    if blend_distance > 0:
        args.extend(['-maxdist', f'{blend_pixels}'])
    proximity_file = gdal_runner.run('gdal_proximity.py', 'proximity.tif', args)

    # scale proximity to 0 to 255 alpha
    # Source range 0..blend_pixels is mapped onto the Byte output; 0 is flagged as nodata.
    args = ['-scale', '0', f'{blend_pixels}', '-ot', 'Byte', '-a_nodata', '0']
    proximity_8_frame_file = gdal_runner.run_wrapper('gdal_translate', proximity_file, 'proximity_8_frame.vrt', args)

    # remove outer 1 pixel frame no data from proximity file
    # The window starts at pixel (1, 1) and has the original grid size, undoing the padding added earlier.
    args = ['-of', 'GTiff', '-srcwin', '1', '1', f'{pixels[0]}', f'{pixels[1]}', '-co', 'compress=LZW', '-co',
            'BIGTIFF=YES']
    if primary_no_data_value is not None:
        args.extend(['-a_nodata', f'{primary_no_data_value}'])
    proximity_8_file = gdal_runner.run_wrapper('gdal_translate', proximity_8_frame_file, 'proximity_8.tif', args)

    # combine 8-bit proximity with data bands (but not as alpha yet)
    # -separate stacks the inputs as separate bands, so the proximity ends up as the last band of the output.
    args = ['-separate', '-co', 'BIGTIFF=YES', '-co', 'compress=LZW', '-o', '$OUT_FILE$', primary_template,
            f'{proximity_8_file}']
    primary_tapered_file = gdal_runner.run('gdal_merge.py', 'primary_tapered.tif', args)

    # add alpha band to secondary
    # Note: I changed this from using gdalwarp to just adding the alpha band using code.  Using gdalwarp converted
    # the elevations from meters to feet (even though they were already feet) in Mantis issue #13788
    # This code runs faster than gdalwarp anyway.
    # The file is written next to the tapered primary file.
    secondary_alpha_file = os.path.join(os.path.dirname(os.path.abspath(primary_tapered_file)), 'secondary_alpha.tif')
    ru.add_alpha_band(secondary_template_corrected, secondary_alpha_file)

    # overlay primary on secondary raster
    # -srcalpha makes gdalwarp treat the last band of each source as its alpha band. The secondary is listed first
    # and the tapered primary second.
    # NOTE(review): presumably later sources are drawn over earlier ones - confirm against the gdalwarp docs.
    args = ['-of', 'GTiff', '-r', 'bilinear', '-srcalpha', '-ot', primary_raster.data_type_name, '-co', 'compress=LZW',
            '-co', 'BIGTIFF=YES']
    if primary_no_data_value is not None:
        args.extend(['-dstnodata', f'{primary_no_data_value}'])
    args.append('-overwrite')
    source_files = [secondary_alpha_file, primary_tapered_file]
    overlayed_file = gdal_runner.run_wrapper('gdalwarp', source_files, 'overlayed.tif', args)
    return overlayed_file
