"""VrtFromRastersTool class."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
import os

# 2. Third party modules

# 3. Aquaveo modules
from xms.gdal.rasters import RasterReproject
from xms.gdal.rasters.raster_utils import copy_and_compress_raster, fix_data_types, make_raster_projections_consistent
from xms.gdal.utilities import gdal_wrappers as gw
from xms.tool_core import IoDirection, Tool
from xms.tool_core.tool import equivalent_arguments

# 4. Local modules

# Positions of the arguments created by VrtFromRastersTool.initial_arguments().
# Extra input rasters added at run time are appended after ARG_RASTER_2.
ARG_OUTPUT_VRT, ARG_RESAMPLE_ALGORITHM, ARG_RASTER_1, ARG_RASTER_2 = range(4)


class VrtFromRastersTool(Tool):
    """Tool to create a VRT file from multiple rasters.

    Raster 1 has the highest priority. Whenever the last raster argument is filled in, a new empty
    optional raster argument is appended (see enable_arguments), so any number of rasters can be given.
    """

    def __init__(self):
        """Initializes the class."""
        super().__init__(name='VRT from Rasters')
        self._file_count = 0
        # Resolved paths of the input raster files that exist on disk, in argument order
        # (raster 1 first). Filled by validate_arguments() and consumed by run().
        self._input_raster_filenames = []

    def initial_arguments(self):
        """Get initial arguments for tool.

        Must override.

        Returns:
            (list): A list of the initial tool arguments.
        """
        arguments = [
            self.raster_argument(name='vrt_file', description='Virtual Raster Table (VRT) file',
                                 io_direction=IoDirection.OUTPUT),
            self.string_argument(name='resample_algorithm', description='Resample algorithm', value='bilinear',
                                 choices=['bilinear', 'nearest_neighbor']),
            self.raster_argument(name='raster_1', description='Raster 1 (highest priority)'),
            self.raster_argument(name='raster_2', description='Raster 2', optional=True),
        ]
        return arguments

    @staticmethod
    def input_raster_count(arguments):
        """Get the number of input raster arguments.

        Doesn't check to see if argument is set to a value.

        Args:
            arguments(list): The tool arguments.

        Returns:
            (int): The number of input raster arguments.
            Not necessarily set to a value.
        """
        # Every argument from ARG_RASTER_1 onward is an input raster.
        return len(arguments) - ARG_RASTER_1

    def add_input_raster(self, arguments):
        """Append a new, empty, optional raster argument after the existing raster arguments.

        The new argument is named 'raster_<n>', where n is one more than the current raster count.

        Args:
            arguments(list): The tool arguments. Modified in place.
        """
        raster_count = self.input_raster_count(arguments) + 1
        raster_argument = self.raster_argument(name=f'raster_{raster_count}', description=f'Raster {raster_count}',
                                               optional=True, value=None)
        arguments.append(raster_argument)

    def enable_arguments(self, arguments):
        """Called to show/hide arguments, change argument values and add new arguments.

        If the last raster argument has a value, appends another empty raster argument so the
        user always has a free slot for one more input.

        Args:
            arguments(list): The tool arguments.
        """
        last_raster = arguments[-1]
        if last_raster.value is not None and last_raster.value != '':
            self.add_input_raster(arguments)

    def validate_arguments(self, arguments):
        """Called to determine if arguments are valid.

        As a side effect, rebuilds the list of input raster files that run() will use.

        Args:
            arguments (list): The tool arguments.

        Returns:
            (dict): Dictionary of errors for arguments (currently always empty).
        """
        errors = {}
        # Start from an empty list: this method may be called more than once before run(), and
        # _validate_input_rasters appends, so stale entries would otherwise duplicate the VRT inputs.
        self._input_raster_filenames = []
        self._validate_input_rasters(arguments, self._get_input_raster_names(arguments))
        return errors

    def validate_from_history(self, arguments):
        """Called to determine if arguments are valid from history.

        Args:
            arguments (list): The tool arguments.

        Returns:
            (bool): True if no errors, False otherwise.
        """
        default_arguments = self.initial_arguments()
        default_length = len(default_arguments)
        # There must be at least as many arguments as the defaults, and those must match the defaults.
        if len(arguments) < default_length:
            return False
        if not equivalent_arguments(arguments[:default_length], default_arguments):
            return False
        # Any extra arguments must be the dynamically added input rasters.
        return all(
            argument.io_direction == IoDirection.INPUT and argument.type == 'raster'
            for argument in arguments[default_length:]
        )

    def _validate_input_rasters(self, arguments, rasters):
        """Collect the file paths of the given input rasters that exist on disk.

        For each raster name that is the value of some argument, its file path is resolved with
        get_input_raster_file(). The path is appended to self._input_raster_filenames if it refers
        to an existing file. Rasters that do not resolve to an existing file are skipped silently.

        Args:
            arguments (list): The tool arguments.
            rasters (list): The input raster names, as returned by _get_input_raster_names().
        """
        for raster_text in rasters:
            if not any(argument.value == raster_text for argument in arguments):
                continue
            raster_filename = self.get_input_raster_file(raster_text)
            if raster_filename and os.path.isfile(raster_filename):
                self._input_raster_filenames.append(raster_filename)

    @staticmethod
    def _get_input_raster_names(arguments):
        """Get the values of the input raster arguments that have been set.

        Args:
            arguments(list): The tool arguments.

        Returns:
            list: The input raster names, in argument order (raster 1 first).
        """
        return [argument.value for argument in arguments[ARG_RASTER_1:] if argument.text_value != '']

    def run(self, arguments):
        """Override to run the tool.

        Builds the VRT from the input raster files collected by validate_arguments().

        Args:
            arguments (list): The tool arguments.
        """
        vrt_filename = self.get_output_raster(arguments[ARG_OUTPUT_VRT].value, '.vrt')
        self.logger.info('Generating merged virtual raster from inputs...')
        # Raster 1 is the highest-priority input. The list is reversed so raster 1 comes last,
        # presumably because gw.gdal_build_vrt wraps gdalbuildvrt, which takes overlapping pixels
        # from the files listed last.
        self._input_raster_filenames.reverse()
        self._input_raster_filenames = fix_data_types(self._input_raster_filenames)
        algorithm = arguments[ARG_RESAMPLE_ALGORITHM].text_value
        # NOTE(review): the 'bilinear' choice reprojects with GRA_Cubic rather than a bilinear
        # kernel - confirm this is intentional.
        resample = RasterReproject.GRA_Cubic if algorithm == 'bilinear' else RasterReproject.GRA_NearestNeighbour
        self._input_raster_filenames = make_raster_projections_consistent(
            self._input_raster_filenames, resample, self.default_wkt, self.vertical_datum, self.vertical_units)
        # Copy each input next to the VRT as '<vrt base name>_<n>.tif'. If copy_and_compress_raster
        # returns a falsy value, the VRT references the original file instead.
        base_name = os.path.splitext(vrt_filename)[0]
        input_files = self._input_raster_filenames.copy()
        for idx, filename in enumerate(self._input_raster_filenames):
            raster_name = f'{base_name}_{idx + 1}.tif'
            if copy_and_compress_raster(filename, raster_name):
                input_files[idx] = raster_name
        gw.gdal_build_vrt(input_files, vrt_filename, options=['-overwrite'])
        self.set_output_raster_file(vrt_filename, arguments[ARG_OUTPUT_VRT].text_value, 'VRT', '.vrt')
