"""InterpPriorityRastersTool class."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
import os
import uuid

# 2. Third party modules

# 3. Aquaveo modules
from xms.core.filesystem import filesystem
from xms.gdal.rasters import RasterInput, RasterReproject
from xms.gdal.rasters.raster_utils import fix_data_types, make_raster_projections_consistent
from xms.gdal.utilities import gdal_utils as gu
from xms.gdal.utilities import gdal_wrappers as gw
from xms.tool_core import Argument, IoDirection, Tool
from xms.tool_core.tool import equivalent_arguments

# 4. Local modules

# Positions of the tool arguments in the list built by InterpPriorityRastersTool.initial_arguments().
# Further raster arguments (raster_3, raster_4, ...) are appended after ARG_RASTER_2 as the user fills them in.
(
    ARG_INPUT_GRID,          # target grid
    ARG_INPUT_DEFAULT_DSET,  # dataset whose values are kept wherever no raster value is sampled
    ARG_OUTPUT_DATASET,      # output dataset name
    ARG_RESAMPLE_ALGORITHM,  # 'bilinear' or 'nearest_neighbor'
    ARG_RASTER_1,            # highest-priority raster
    ARG_RASTER_2,            # second raster; any later rasters follow it
) = range(6)


class InterpPriorityRastersTool(Tool):
    """Tool to interpolate multiple rasters to a UGrid with priority.

    The selected rasters are merged into one virtual raster (VRT) in which raster 1 has the highest
    priority, then raster 2, and so on. The merged raster is sampled at the grid points or the cell
    centroids, depending on whether the default dataset has one value per point or per cell.
    Locations with a negative pixel offset or a nodata raster value keep the default dataset's value.
    """
    DATASET_TYPE_UNDEFINED = -1
    DATASET_TYPE_CELLS = 0
    DATASET_TYPE_POINTS = 1

    def __init__(self):
        """Initializes the class."""
        super().__init__(name='Interpolate Priority Rasters')
        self._file_count = 0
        self._input_cogrid = None  # Grid read from the grid argument during validation
        self._ugrid = None  # UGrid of self._input_cogrid
        self._input_raster_filenames = []  # Raster file paths in priority order (raster 1 first)
        self._default_dset = None  # Default dataset read during validation
        self._dataset_list = []
        self._dataset_type = self.DATASET_TYPE_UNDEFINED  # Whether output goes on points or cells

    def initial_arguments(self):
        """Get initial arguments for tool.

        Must override.

        Returns:
            (list): A list of the initial tool arguments.
        """
        arguments = [
            self.grid_argument(name='grid', description='Grid'),
            self.dataset_argument(name='default_value_dataset', description='Default dataset on selected grid',
                                  value=Argument.NONE_SELECTED, optional=False),
            self.dataset_argument(name='default_name', description='Output dataset name', value="new dataset",
                                  io_direction=IoDirection.OUTPUT),
            self.string_argument(name='resample_algorithm', description='Resample algorithm', value='bilinear',
                                 choices=['bilinear', 'nearest_neighbor']),
            self.raster_argument(name='raster_1', description='Raster 1 (highest priority)'),
            self.raster_argument(name='raster_2', description='Raster 2', optional=True),
        ]
        return arguments

    @staticmethod
    def input_raster_count(arguments):
        """Get the number of input raster arguments.

        Doesn't check to see if argument is set to a value.

        Args:
            arguments(list): The tool arguments.

        Returns:
            (int): The number of input raster arguments.
            Not necessarily set to a value.
        """
        return len(arguments) - ARG_RASTER_1

    def add_input_raster(self, arguments):
        """Append a new, empty, optional raster argument after the existing raster arguments.

        The new argument is named ``raster_<n>``, where n is one more than the current raster count.

        Args:
            arguments(list): The tool arguments. Modified in place.
        """
        raster_count = self.input_raster_count(arguments) + 1
        raster_argument = self.raster_argument(name=f'raster_{raster_count}', description=f'Raster {raster_count}',
                                               optional=True, value=None)
        arguments.append(raster_argument)

    def enable_arguments(self, arguments):
        """Called to show/hide arguments, change argument values and add new arguments.

        Adds a new raster argument whenever the last raster argument has been given a value, so there
        is always an empty raster field available.

        Args:
            arguments(list): The tool arguments.
        """
        last_raster = arguments[-1]
        if last_raster.value is not None and last_raster.value != '':
            self.add_input_raster(arguments)
        arguments[ARG_INPUT_DEFAULT_DSET].enable_for_grid(arguments[ARG_INPUT_GRID])

    def validate_arguments(self, arguments):
        """Called to determine if arguments are valid.

        Also caches the grid, raster file paths, default dataset and output location (points or cells)
        for use by run().

        Args:
            arguments (list): The tool arguments.

        Returns:
            (dict): Dictionary of errors for arguments.
        """
        errors = {}
        # Forget the result of any previous validation so a stale value cannot leak into run().
        self._dataset_type = self.DATASET_TYPE_UNDEFINED
        # Validate the input grid
        self._validate_input_grid(errors, arguments[ARG_INPUT_GRID])
        # Validate rasters are specified and can be found
        self._validate_input_rasters(errors, arguments, self._get_input_raster_names(arguments))

        # Default option
        self._default_dset = self._validate_input_dataset(arguments[ARG_INPUT_DEFAULT_DSET], errors)
        if not self._default_dset:
            errors[arguments[ARG_INPUT_DEFAULT_DSET].name] = \
                'Unable to read the selected dataset.'
        elif self._ugrid:
            # The default dataset's size decides whether output values go on points or cells.
            if self._default_dset.num_values == self._ugrid.point_count:
                self._dataset_type = self.DATASET_TYPE_POINTS
            elif self._default_dset.num_values == self._ugrid.cell_count:
                self._dataset_type = self.DATASET_TYPE_CELLS
            else:
                errors[arguments[ARG_INPUT_DEFAULT_DSET].name] = \
                    'Number of values in default dataset must match number of points or cells ' \
                    'in target geometry.'
        return errors

    def validate_from_history(self, arguments):
        """Called to determine if arguments are valid from history.

        Args:
            arguments (list): The tool arguments.

        Returns:
            (bool): True if no errors, False otherwise.
        """
        # Make sure there are 6 or more arguments and the first ones match the initial arguments
        default_arguments = self.initial_arguments()
        default_length = len(default_arguments)
        if len(arguments) < default_length:
            return False
        if not equivalent_arguments(arguments[:default_length], default_arguments):
            return False
        # Any extra arguments must be input rasters like those added by add_input_raster()
        return all(argument.io_direction == IoDirection.INPUT and argument.type == 'raster'
                   for argument in arguments[default_length:])

    def _validate_input_grid(self, errors, argument):
        """Validate the grid can be read and cache it and its UGrid.

        Args:
            errors (dict): Dictionary of errors keyed by argument name.
            argument (GridArgument): The grid argument.
        """
        self._input_cogrid = self.get_input_grid(argument.text_value)
        if not self._input_cogrid:
            # Don't keep a UGrid from an earlier successful validation.
            self._ugrid = None
            errors[argument.name] = 'Could not read grid.'
        else:
            self._ugrid = self._input_cogrid.ugrid

    def _validate_input_rasters(self, errors, arguments, rasters):
        """Validate the selected input rasters and collect their file paths.

        Rebuilds ``self._input_raster_filenames`` from scratch, in the order of ``rasters``. Fields left
        at '-- None Selected --' are skipped. A selected raster whose file cannot be found is reported
        as an error on its argument.

        Args:
            errors (dict): Dictionary of errors keyed by argument name.
            arguments (list): The tool arguments.
            rasters (list): The input raster names.
        """
        # Start over on every validation so rasters from earlier selections do not accumulate.
        self._input_raster_filenames = []
        raster_arguments = arguments[ARG_RASTER_1:]
        for raster_text in rasters:
            if raster_text == '-- None Selected --':
                continue  # No raster has been selected for this field, may be Ok.
            key = next((argument.name for argument in raster_arguments if argument.value == raster_text), None)
            if key is None:
                continue
            raster_filename = self.get_input_raster_file(raster_text)
            if raster_filename and os.path.isfile(raster_filename):
                self._input_raster_filenames.append(raster_filename)
            else:
                errors[key] = 'Could not read raster.'

    @staticmethod
    def _get_input_raster_names(arguments):
        """Get the values of the raster arguments that have non-empty text, in argument order.

        Args:
            arguments(list): The tool arguments.

        Returns:
            list: List of the input raster names
        """
        return [arguments[i].value for i in range(ARG_RASTER_1, len(arguments)) if arguments[i].text_value != '']

    def write_output_dataset(self, dataset_values, dataset_name, null_value):
        """Write the output dataset.

        Args:
            dataset_values (np.ndarray): The dataset values to write
            dataset_name (str): Name to give the dataset
            null_value (float): Null value to assign the dataset
        """
        self.logger.info('Writing output dataset...')
        dset_uuid = str(uuid.uuid4())
        builder = self.get_output_dataset_writer(
            name=dataset_name,
            dset_uuid=dset_uuid,
            geom_uuid=self._input_cogrid.uuid,
            null_value=null_value,
            location='points' if self._dataset_type == self.DATASET_TYPE_POINTS else 'cells',
        )
        builder.write_xmdf_dataset(times=[0.0], data=[dataset_values])
        self.set_output_dataset(builder)

    def run(self, arguments):
        """Override to run the tool.

        Merges the validated rasters into a virtual raster and samples it at the grid points or cell
        centroids. The result is written as a new dataset on the input grid.

        Args:
            arguments (list): The tool arguments.
        """
        temp_folder = filesystem.temp_filename()  # Gets deleted with process
        os.mkdir(temp_folder)
        vrt_filename = os.path.join(temp_folder, 'work.vrt')
        self.logger.info('Generating merged virtual raster from inputs...')
        # Where sources overlap, gdalbuildvrt takes content from the source listed last, so the list is
        # reversed to put raster 1 (highest priority) last. Reversing a copy keeps the validated list
        # intact, so running again does not invert the priority.
        raster_filenames = fix_data_types(list(reversed(self._input_raster_filenames)))
        algorithm = arguments[ARG_RESAMPLE_ALGORITHM].text_value
        # NOTE(review): 'bilinear' selects cubic resampling for reprojection - confirm this is intended.
        reproject_algorithm = RasterReproject.GRA_Cubic if algorithm == 'bilinear' \
            else RasterReproject.GRA_NearestNeighbour
        raster_filenames = make_raster_projections_consistent(raster_filenames, reproject_algorithm)
        gw.gdal_build_vrt(raster_filenames, vrt_filename)
        raster_input = RasterInput(vrt_filename)

        # Interpolate from the VRT raster to the input grid locations
        self.logger.info('Interpolating from source raster to target UGrid...')
        if self._dataset_type == self.DATASET_TYPE_POINTS:
            locations = self._ugrid.locations
        else:
            locations = [self._ugrid.get_cell_centroid(cell_idx)[1] for cell_idx in range(self._ugrid.cell_count)]
        # Transform from the tool's default projection into the merged raster's projection.
        locations = gu.transform_points_from_wkt(locations, self.default_wkt, raster_input.wkt)
        num_points = len(locations)
        # Start from the default dataset; raster values overwrite it wherever one is sampled.
        # NOTE(review): if values[0] is a view, this also modifies the default dataset's values in memory.
        dataset_values = self._default_dset.values[0]
        self.logger.info(f'Processing point 1 of {num_points}...')
        # Multiplier from the raster's vertical units to self.vertical_units.
        vm = gu.get_vertical_multiplier(gu.get_vert_unit_from_wkt(raster_input.wkt), self.vertical_units)
        for i, location in enumerate(locations):
            if (i + 1) % 10000 == 0:
                self.logger.info(f'Processing point {i + 1} of {num_points}...')
            xoff, yoff = raster_input.coordinate_to_pixel(location[0], location[1])
            if xoff >= 0 and yoff >= 0:
                # in bounds (negative offsets are treated as outside the raster)
                value = raster_input.get_raster_values(xoff=xoff, yoff=yoff, xsize=1, ysize=1, resample_alg=algorithm)
                if value != raster_input.nodata_value:
                    dataset_values[i] = value * vm
        self.write_output_dataset(dataset_values, arguments[ARG_OUTPUT_DATASET].text_value,
                                  raster_input.nodata_value)
        if os.path.isfile(vrt_filename):
            os.remove(vrt_filename)