"""Raster utility functions: copying, compressing, reprojecting, polygonizing, and contouring rasters, plus helpers."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
from contextlib import contextmanager
import math
import os.path
import shutil

# 2. Third party modules
import numpy as np
from osgeo import gdal, ogr, osr
import rasterio
from rasterio.windows import Window

# 3. Aquaveo modules
from xms.core.filesystem import filesystem

# 4. Local modules
from xms.gdal.rasters import RasterInput, RasterOutput, RasterReproject
from xms.gdal.rasters.raster_reproject import get_datatype_and_nodata
from xms.gdal.utilities import gdal_utils as gu, GdalRunner
from xms.gdal.utilities.gdal_wrappers import gdal_translate

# Fallback nodata value that interpolate_raster_to_points reports when the raster itself has no nodata value.
DSET_NULL_VALUE = -9999999.0


def copy_raster(from_filename, to_filename, file_format='GTiff'):
    """Copy a raster file to a new location using the given GDAL driver.

    Args:
        from_filename (str): Full path of the raster to copy.
        to_filename (str): Full path of the copy to create.
        file_format (str): GDAL short driver name for the output.  See https://gdal.org/drivers/raster/index.html
            for a list of possible strings.

    Returns:
        (bool): True if the copy was written; False if the source file is missing, the driver name is unknown, or
        the copy failed.
    """
    # Nothing to copy if the source does not exist on disk.
    if not os.path.exists(from_filename):
        return False
    source_raster = RasterInput(from_filename)
    # Bail out early when GDAL has no driver registered under this short name.
    if gdal.GetDriverByName(file_format) is None:
        return False
    return copy_raster_from_raster_input(source_raster, to_filename, file_format)


def copy_and_compress_raster(from_filename: str, to_filename: str) -> bool:
    """Copy a raster to a new file using gdal.Translate with LZW compression and BigTIFF enabled.

    Args:
        from_filename (str): The full path and filename of the source file.
        to_filename (str): The full path and filename of the destination file.

    Returns:
        (bool): False if the source file does not exist or gdal.Translate fails, True otherwise.
    """
    if not os.path.exists(from_filename):
        return False
    creation_options = ["COMPRESS=LZW", "BIGTIFF=YES"]
    translated = gdal.Translate(to_filename, from_filename, creationOptions=creation_options)
    return translated is not None


def copy_raster_from_raster_input(from_raster, to_filename, file_format='GTiff'):
    """Write a copy of an already opened raster to a file with the given GDAL driver.

    Args:
        from_raster (RasterInput): The opened source raster (None is tolerated and yields False).
        to_filename (str): The full path and filename of the destination file.
        file_format (str): GDAL short driver name for the output.  See https://gdal.org/drivers/raster/index.html
            for a list of possible strings.

    Returns:
        (bool): True if driver.CreateCopy produced a dataset, False otherwise.
    """
    if from_raster is None:
        return False
    driver = gdal.GetDriverByName(file_format)
    if driver is None:
        return False
    copied = driver.CreateCopy(to_filename, from_raster.gdal_raster, strict=0)
    if copied is None:
        return False
    # Drop the reference so GDAL flushes and closes the new dataset before success is reported.
    del copied
    return True


@contextmanager
def without_proj_lib():
    """Temporarily take the PROJ_LIB environment variable out of the process environment.

    Inside the ``with`` block PROJ_LIB is absent; on exit (normal or via exception) it is put back with its original
    value, if it had one.

    The purpose of this function is to handle rasterio calls correctly since they use a different version of PROJ
    (and of the proj.db file) from our installation of GDAL (3.4.1).  Without PROJ_LIB, rasterio searches its own
    packaged location for proj.db instead of the "gdalbin" directory used by our GDAL 3.4.1 installation, whose
    proj.db is incompatible with the one used by the latest rasterio.
    """
    saved_value = None
    if "PROJ_LIB" in os.environ:
        saved_value = os.environ["PROJ_LIB"]
        del os.environ["PROJ_LIB"]
    try:
        yield
    finally:
        # Only restore the variable when it was present on entry.
        if saved_value is not None:
            os.environ["PROJ_LIB"] = saved_value


def add_alpha_band(input_file, output_file_with_alpha, window_size=2048):
    """Add an alpha band to a raster using rasterio, processing in larger windows for speed.

    The output is an LZW-compressed, tiled (256x256) BigTIFF that contains every band of the input followed by one
    extra alpha band. Pixels whose first-band value equals the input's nodata value get alpha 0; all others get 255.
    A NaN nodata value is detected with ``numpy.isnan``, because ``data == nan`` is never true. When the input has no
    nodata value, every pixel gets alpha 255.

    Args:
        input_file (str): Path to input raster.
        output_file_with_alpha (str): Path to output raster with alpha band.
        window_size (int): Size of the square processing window in pixels (default 2048). Must be positive.

    Raises:
        ValueError: If window_size is not positive.
    """
    if window_size <= 0:
        raise ValueError(f'window_size must be positive, got {window_size}.')
    with without_proj_lib():
        with rasterio.open(input_file) as src:
            profile = src.profile.copy()
            no_data = src.nodata
            num_bands = src.count
            width, height = src.width, src.height
            # NaN never compares equal to itself, so a NaN nodata value needs an isnan() test instead of ==.
            nodata_is_nan = no_data is not None and math.isnan(no_data)

            # Update profile to add alpha band
            profile.update(
                driver='GTiff',
                count=num_bands + 1,
                compress='lzw',
                BIGTIFF='YES',
                tiled=True,
                blockxsize=256,
                blockysize=256
            )

            # Row-major tiling into windows of at most window_size x window_size pixels; edge windows are clipped
            # to the raster extent.
            windows = [
                Window(x_off, y_off, min(window_size, width - x_off), min(window_size, height - y_off))
                for y_off in range(0, height, window_size)
                for x_off in range(0, width, window_size)
            ]

            # Open output raster
            with rasterio.open(output_file_with_alpha, 'w', **profile) as dst:

                # Copy existing bands window by window for memory efficiency
                for band_index in range(1, num_bands + 1):
                    for window in windows:
                        dst.write(src.read(band_index, window=window), band_index, window=window)

                # Create the alpha band from band 1's nodata mask
                alpha_band_index = num_bands + 1
                for window in windows:
                    if no_data is None:
                        alpha_block = np.full((window.height, window.width), 255, dtype=np.uint8)
                    else:
                        data = src.read(1, window=window)
                        is_nodata = np.isnan(data) if nodata_is_nan else data == no_data
                        alpha_block = np.where(is_nodata, 0, 255).astype(np.uint8)
                    dst.write(alpha_block, alpha_band_index, window=window)


def convert_index_raster_data_to_polygons(input_raster, filename=None, ignore_nodata=False):
    """Convert the valid data areas of a raster to polygons with the index as the value, excluding the NODATA areas.

    This function assumes the raster is an index raster. Band 1 is polygonized with gdal.Polygonize, and each
    polygon's pixel value is written to an integer 'VALUE' field of a multipolygon layer.

    Args:
        input_raster (RasterInput): The input raster to be converted to polygons.
        filename (str): An optional filename used for writing a shapefile. Otherwise, the data is written to the
            in-memory shapefile '/vsimem/bounds.shp'.
        ignore_nodata (bool): If True and the raster has a nodata value, cells equal to it (or NaN cells when the
            nodata value is NaN) are masked out, so no polygons are created for them.

    Returns:
        (ogr.DataSource | None): The in-memory DataSource containing the layer with the polygon(s), or None when
        filename is given (the DataSource is released so the shapefile is written to disk).
    """
    raster_sr = None
    if gu.valid_wkt(input_raster.wkt):
        raster_sr = gu.wkt_to_sr(input_raster.wkt)
    if filename:
        bounds_file = filename
        layer_name = os.path.splitext(os.path.basename(filename))[0]
    else:
        bounds_file = '/vsimem/bounds.shp'  # In-memory shapefile
        layer_name = 'bounds'
    vec_ds = ogr.GetDriverByName('ESRI Shapefile').CreateDataSource(bounds_file)
    vec_lyr = vec_ds.CreateLayer(layer_name, srs=raster_sr, geom_type=ogr.wkbMultiPolygon)
    vec_lyr.CreateField(ogr.FieldDefn('VALUE', ogr.OFTInteger))
    classify_band = input_raster.gdal_raster.GetRasterBand(1)
    mask_band = None
    mask_raster = None  # Keeps the mask dataset referenced while Polygonize uses its band.
    if ignore_nodata:
        nodata = input_raster.nodata_value
        if nodata is not None:
            # Reclassify as 1 (valid data) or 0 (nodata). NaN needs isnan() since NaN != NaN.
            raster_values = input_raster.get_raster_values()
            if math.isnan(nodata):
                is_valid = ~np.isnan(raster_values)
            else:
                is_valid = raster_values != nodata
            active_values = is_valid.astype(np.int32)

            # Write the reclassified values to an in-memory raster whose band 1 serves as the Polygonize mask
            ro = RasterOutput(width=input_raster.resolution[0], height=input_raster.resolution[1],
                              data_type=gdal.GDT_Int32, nodata_value=0, wkt=input_raster.wkt,
                              geotransform=input_raster.geotransform)
            reclassify_file = '/vsimem/mask_band.tif'
            ro.write_raster(reclassify_file, active_values)
            mask_raster = RasterInput(reclassify_file)
            mask_band = mask_raster.gdal_raster.GetRasterBand(1)
    gdal.Polygonize(classify_band, mask_band, vec_lyr, 0)  # 0 is the index of the VALUE field
    if filename:
        # Release the layer, then the DataSource, so the shapefile is flushed and closed.
        vec_lyr = None
        vec_ds = None
    return vec_ds


def convert_raster_data_to_polygons(input_raster):
    """Convert the valid data areas of a raster to polygons, excluding the NODATA areas.

    This function converts the valid data areas to polygons with values of "1" and the invalid data areas to polygons
    with values of "0" and is used for determining polygons defining valid data areas of rasters. A NaN nodata value
    is handled with ``numpy.isnan`` because NaN never compares equal to itself.

    Args:
        input_raster (RasterInput): The input raster to be converted to polygons.

    Returns:
        (ogr.DataSource): The DataSource containing the layer with the boundary polygon(s).
    """
    # Get the raster values, and reclassify as either 1 (valid data) or 0 (nodata)
    raster_values = input_raster.get_raster_values()
    nodata = input_raster.nodata_value
    if nodata is not None and math.isnan(nodata):
        is_valid = ~np.isnan(raster_values)
    else:
        is_valid = raster_values != nodata
    active_values = np.where(is_valid, 1, 0).astype(np.int32)

    # Write to an in memory raster with the reclassify values
    ro = RasterOutput(width=input_raster.resolution[0], height=input_raster.resolution[1],
                      data_type=gdal.GDT_Int32, nodata_value=0, wkt=input_raster.wkt,
                      geotransform=input_raster.geotransform)
    reclassify_file = '/vsimem/reclassify.tif'
    ro.write_raster(reclassify_file, active_values)
    ri = RasterInput(reclassify_file)
    return convert_index_raster_data_to_polygons(ri)


def fill_nodata(input_raster, max_search_distance=100, smoothing_iterations=0, target_band=1, mask_band=0):
    """Interpolate values into the nodata cells of one raster band using gdal.FillNodata.

    Args:
        input_raster (RasterInput): The raster to fill. It should be opened in 'Update' mode so the filled values can
            be written back.
        max_search_distance (float): The maximum distance (in pixels) to search out for values to interpolate.
        smoothing_iterations (int): The number of 3x3 average filter smoothing iterations to run after the
         interpolation to dampen artifacts.
        target_band (int): The band to operate on.
        mask_band (int): The band that marks cells as valid (nonzero) or invalid (0); 0 or less means no mask.
    """
    dataset = input_raster.gdal_raster
    # A non-positive mask_band means "no validity mask", which GDAL expresses as None.
    validity_mask = dataset.GetRasterBand(mask_band) if mask_band > 0 else None
    band_to_fill = dataset.GetRasterBand(target_band)
    gdal.FillNodata(band_to_fill, validity_mask, max_search_distance, smoothing_iterations)


def get_raster_format_short_name(filename):
    """Return the short name of the GDAL driver that reads the given raster file.

    Args:
        filename (str): The raster file.

    Returns:
        (str): The driver short name (e.g. 'GTiff'), or 'UNDEFINED' if RasterInput rejects the file with a ValueError.
    """
    try:
        raster = RasterInput(filename)
    except ValueError:
        raster = None
    return 'UNDEFINED' if raster is None else raster.gdal_raster.GetDriver().ShortName


def fix_data_types(filenames):
    """Given the filenames passed in, check if they all have the same datatype in the raster band.

    If any datatypes are different, this function determines the "best" datatype. Each raster that doesn't match
    it is converted to a temporary GeoTIFF. If an unsigned integer type would win while some input is a signed
    integer, the target is widened so negative values are not clipped: UInt16 becomes Int32 and UInt32 becomes
    Float64. Files that RasterInput cannot open (ValueError) are dropped from the returned list.
    Some GDAL utilities (like making a VRT) need to have the same raster band data types, otherwise you get an error.

    Args:
        filenames (list of str):  List of filenames to check.

    Returns:
        (list of str):  The list of filenames, possibly modified.
    """
    # Collect (filename, data type) for every raster that opens; unreadable files are dropped.
    valid_rasters = []
    for filename in filenames:
        try:
            data_type = RasterInput(filename).data_type
        except ValueError:
            continue
        valid_rasters.append((filename, data_type))
    return_filenames = [filename for filename, _ in valid_rasters]
    data_types = [data_type for _, data_type in valid_rasters]
    if not data_types or len(set(data_types)) == 1:
        # Nothing to check, or every data type is already the same
        return return_filenames
    # Rank of each type, higher is preferred (types not listed score 0):
    # Float64 > CFloat64 > Float32 > CFloat32 > UInt32 > Int32 > CInt32 > UInt16 > Int16 > CInt16 > Byte
    desirability_scores = {gdal.GDT_Float64: 11, gdal.GDT_CFloat64: 10, gdal.GDT_Float32: 9, gdal.GDT_CFloat32: 8,
                           gdal.GDT_UInt32: 7, gdal.GDT_Int32: 6, gdal.GDT_CInt32: 5, gdal.GDT_UInt16: 4,
                           gdal.GDT_Int16: 3, gdal.GDT_CInt16: 2, gdal.GDT_Byte: 1}
    # max() returns the first of equally scored types, matching a "strictly greater" scan.
    highest_type = max(data_types, key=lambda t: desirability_scores.get(t, 0))
    # An unsigned target would clip negative values of signed inputs; widen to a type that holds both ranges.
    signed_int_types = {gdal.GDT_Int16, gdal.GDT_Int32}
    if any(data_type in signed_int_types for data_type in data_types):
        if highest_type == gdal.GDT_UInt16:
            highest_type = gdal.GDT_Int32
        elif highest_type == gdal.GDT_UInt32:
            highest_type = gdal.GDT_Float64
    # GetDataTypeName covers every GDAL type, including ones without a desirability score.
    type_name = gdal.GetDataTypeName(highest_type)
    for idx, (filename, data_type) in enumerate(valid_rasters):
        if data_type != highest_type:
            # Convert the raster to the highest type
            temp_file = filesystem.temp_filename(suffix='.tif')
            if gdal_translate(filename, temp_file, options=f'-of GTiff -ot {type_name}'):
                return_filenames[idx] = temp_file
    return return_filenames


def make_raster_projections_consistent(filenames, resample_alg: int = gdal.GRA_NearestNeighbour,
                                       default_wkt: str = None, vertical_datum: str = None, vertical_units: str = None):
    """Given the filenames passed in, convert the projections of all the rasters to a common target projection.

    The target projection is taken from default_wkt, with vertical_datum and vertical_units added, when default_wkt
    is a valid WKT. Otherwise it is the projection of the first raster that could be opened. In that case, a vertical
    projection is added from the first unit type that makes it compound.
    Any raster whose projection differs from the target is reprojected to a temporary '.tif' file using the given
    resample_alg, and its entry in the returned list is replaced. Rasters with an invalid or local projection are
    left as-is, and nothing is reprojected when the target projection is local. Files that RasterInput cannot open
    (ValueError) are dropped from the returned list.
    Some GDAL utilities (like making a VRT) need to have the same raster projections, otherwise you get an error.

    Args:
        filenames (list of str):  List of filenames to check.
        resample_alg (int): The resample algorithm to use, one of gdal.GRA_*
        default_wkt (str): The WKT from the target projection.
        vertical_datum (str): The vertical datum from the target projection.
        vertical_units (str): The vertical units from the target projection.

    Returns:
        (array):  The list of filenames, possibly modified.
    """
    raster_wkts = []
    unit_types = []
    return_filenames = filenames.copy()
    del_items = []
    for filename in filenames:
        # Store the WKT and unit type of each raster that can be opened
        try:
            raster = RasterInput(filename)
            raster_wkts.append(raster.wkt)
            unit_types.append(raster.unit_type)
        except ValueError:
            del_items.append(filename)
            continue
    for item in del_items:
        return_filenames.remove(item)
    if len(raster_wkts) > 0:
        to_sr = None
        # Prefer the caller-supplied projection as the target
        if default_wkt is not None:
            if gu.valid_wkt(default_wkt):
                _, default_wkt = gu.remove_epsg_code_if_unit_mismatch(default_wkt)
                to_sr = gu.wkt_to_sr(default_wkt)
            if to_sr is not None:
                _, to_sr = gu.add_vertical_projection(to_sr, vertical_datum, vertical_units)
        # Otherwise fall back to the projection of the first raster
        if to_sr is None and gu.valid_wkt(raster_wkts[0]):
            _, to_wkt = gu.remove_epsg_code_if_unit_mismatch(raster_wkts[0])
            if gu.valid_wkt(to_wkt):
                to_sr = gu.wkt_to_sr(to_wkt)
                if not to_sr.IsCompound():
                    # Check if any of the rasters have a compound projection because we need to set the target
                    # projection to a compound projection since the reproject_raster function will automatically
                    # add a vertical projection to any raster with a compound projection and all the raster projections
                    # must be the same to build a VRT (you can't have compound and non-compound projections)
                    for unit_type in unit_types:
                        if unit_type:
                            to_sr = gu.add_vertical_projection_from_unit_type(to_sr, unit_type)
                            if to_sr.IsCompound():
                                break
        if to_sr is not None and not to_sr.IsLocal():
            # Split the target into the horizontal WKT and vertical datum/units that reproject_raster expects
            default_wkt = gu.strip_vertical(to_sr.ExportToWkt())
            vertical_datum = gu.get_vert_datum_from_wkt(to_sr.ExportToWkt())
            vertical_units = gu.get_vert_unit_from_wkt(to_sr.ExportToWkt())
            # raster_wkts lines up with return_filenames: both skip the files that failed to open.
            for idx, filename in enumerate(return_filenames):
                if gu.valid_wkt(raster_wkts[idx]):
                    from_sr = gu.wkt_to_sr(raster_wkts[idx])
                    if from_sr is not None and not from_sr.IsLocal():
                        if not to_sr.IsSame(from_sr):
                            # All the rasters need to have the exact same projection if we're building a VRT.
                            temp_file = filesystem.temp_filename(suffix='.tif')
                            reproject_raster(filename, temp_file, default_wkt, vertical_datum, vertical_units,
                                             resample_alg)
                            return_filenames[idx] = temp_file
    return return_filenames


def is_index_raster(raster):
    """Tell whether a raster looks like an index (classified) raster.

    A raster qualifies when it is exactly a RasterInput (subclasses are rejected), has no more than one band, and
    band 1 carries a color table, as index rasters like NLCD or C-CAP do.

    Args:
        raster (RasterInput): The input raster to check.

    Returns:
        (bool): True if the raster is an index raster (like nlcd or ccap).
    """
    is_single_band_input = type(raster) is RasterInput and raster.gdal_raster.RasterCount <= 1
    if not is_single_band_input:
        return False
    return raster.gdal_raster.GetRasterBand(1).GetColorTable() is not None


def is_float_raster(raster):
    """Determines if the raster is a single-band floating-point raster.

    Args:
        raster (:obj:`RasterInput`): the raster

    Returns:
        (:obj:`bool`): True if the raster is exactly a RasterInput with at most one band whose data type is
        Float32 or Float64
    """
    if type(raster) is not RasterInput:
        return False
    if raster.gdal_raster.RasterCount > 1:
        return False
    # data_type holds a GDAL type code, so compare against the gdal.GDT_* constants directly rather than
    # relying on the RasterInput instance re-exporting them.
    return raster.data_type in (gdal.GDT_Float32, gdal.GDT_Float64)


def interpolate_raster_to_points(raster_file, xy_locations, xy_locations_wkt, xy_locations_vertical_units,
                                 resample_alg='bilinear'):
    """Interpolates a raster to points.

    The locations are transformed from xy_locations_wkt into the raster's projection. The cell under each location
    is sampled with the given resample algorithm, and the sampled value is scaled from the raster's vertical units to
    xy_locations_vertical_units. Some locations keep the returned nodata value, which is not unit-scaled:
    - locations outside the raster,
    - cells that read back as None,
    - cells holding the raster's nodata value.

    Args:
        raster_file (str): The raster file to interpolate.
        xy_locations (list): The list of locations to interpolate to.
        xy_locations_wkt (str): The WKT of the locations.
        xy_locations_vertical_units (str): The vertical units.
        resample_alg (str): The resampling algorithm to use. Must be one of the keys in GDAL_RESAMPLE_ALGORITHMS.

    Returns:
        tuple(list, float): The interpolated values at the locations and the no data value (the raster's nodata
        value, or DSET_NULL_VALUE if the raster has none).
    """
    raster = RasterInput(raster_file)
    transformer = gu.get_coordinate_transformation(xy_locations_wkt, raster.wkt)
    src_units = gu.get_vert_unit_from_wkt(raster.wkt, True)
    vm = gu.get_vertical_multiplier(src_units, xy_locations_vertical_units)

    locs = gu.transform_points(xy_locations, transformer)
    raster_nodata = raster.nodata_value
    nodata_value = DSET_NULL_VALUE if raster_nodata is None else raster_nodata
    # raster.resolution is (columns, rows); this module passes it as RasterOutput's width/height.
    num_cols, num_rows = raster.resolution[0], raster.resolution[1]
    elevs = [nodata_value] * len(locs)
    for i, loc in enumerate(locs):
        x_off, y_off = raster.coordinate_to_pixel(loc[0], loc[1])
        if not (0 <= x_off < num_cols and 0 <= y_off < num_rows):
            continue  # The location falls outside the raster.
        v = raster.get_raster_values(xoff=x_off, yoff=y_off, xsize=1, ysize=1, resample_alg=resample_alg)
        if v is None:
            continue
        value = v[0][0]
        if raster_nodata is not None and value == raster_nodata:
            continue  # Keep nodata cells as nodata instead of scaling them into a bogus elevation.
        elevs[i] = vm * value

    return elevs, nodata_value


def reconcile_projection_to_raster(raster_file: str | RasterInput, default_wkt: str, vertical_datum: str,
                                   vertical_units: str) -> tuple[osr.SpatialReference, osr.SpatialReference, float]:
    """Reconciles a given projection to a raster so the raster can be warped to the returned projection.

    Both spatial references are returned as None when the raster's WKT is not valid. The target spatial reference is
    also None when default_wkt is not valid. The raster spatial reference is None when its WKT, after
    gu.remove_epsg_code_if_unit_mismatch, is not valid. Any non-None spatial reference is round-tripped through
    gu.add_hor_auth_code_and_name before it is returned.

    Args:
        raster_file (str | RasterInput): The raster file, or an already opened RasterInput.
        default_wkt (str): The WKT from the target projection.
        vertical_datum (str): The vertical datum from the target projection.
        vertical_units (str): The vertical units from the target projection.

    Returns:
        tuple[osr.SpatialReference, osr.SpatialReference, float]: A tuple containing the target spatial reference,
        the raster spatial reference (possibly modified), and the vertical multiplier. The multiplier stays 1.0
        unless all of the following hold:
        - both spatial references were built,
        - the raster's spatial reference is not compound,
        - no vertical projection was added to it.
        In that case the multiplier is gu.get_vertical_multiplier from the raster's horizontal units to the target's
        horizontal units.
    """
    raster_sr = None
    target_sr = None
    vm = 1.0
    input_raster = raster_file
    if isinstance(raster_file, str):
        input_raster = RasterInput(raster_file)
    if gu.valid_wkt(input_raster.wkt):
        _, raster_wkt = gu.remove_epsg_code_if_unit_mismatch(input_raster.wkt)
        if gu.valid_wkt(raster_wkt):
            raster_sr = gu.wkt_to_sr(raster_wkt)
            if not raster_sr.IsCompound():
                # Try to give the raster a vertical projection derived from its unit type.
                raster_sr = gu.add_vertical_projection_from_unit_type(raster_sr, input_raster.unit_type)
        if gu.valid_wkt(default_wkt):
            target_sr = gu.wkt_to_sr(default_wkt)
            _, target_sr = gu.add_vertical_projection(target_sr, vertical_datum, vertical_units)
        if raster_sr is not None and target_sr is not None:
            if raster_sr.IsCompound():
                # GDAL will automatically convert the vertical values when re-projecting.
                if vertical_units == gu.UNITS_UNKNOWN:
                    # Fall back to the raster's own vertical units for the target.
                    vertical_units = gu.get_vert_unit_from_wkt(raster_wkt)
                _, target_sr = gu.add_vertical_projection(target_sr, vertical_datum, vertical_units)
            else:
                # In this case, we need to set the vertical units for the new raster since XMS assumes the vertical
                # units match the horizontal units.
                vertical_added = False
                if target_sr.IsCompound():
                    vertical_added, raster_sr = gu.add_vertical_projection(raster_sr,
                                                                           gu.get_vert_datum_from_wkt(default_wkt),
                                                                           gu.get_horiz_unit_from_wkt(raster_wkt))
                if not vertical_added:
                    # GDAL can't convert the values itself, so the caller must scale them by this multiplier.
                    vm = gu.get_vertical_multiplier(gu.get_horiz_unit_from_wkt(raster_wkt),
                                                    gu.get_horiz_unit_from_wkt(default_wkt))
    if target_sr is not None:
        target_wkt = gu.add_hor_auth_code_and_name(target_sr.ExportToWkt())
        target_sr = gu.wkt_to_sr(target_wkt)
    if raster_sr is not None:
        raster_wkt = gu.add_hor_auth_code_and_name(raster_sr.ExportToWkt())
        raster_sr = gu.wkt_to_sr(raster_wkt)
    return target_sr, raster_sr, vm


def set_vertical_units(reprojected_file: str, vertical_datum: str, vertical_units: str, vertical_multiplier: float,
                       output_file: str):
    """Scales a raster's values if needed and records its vertical units in the raster's metadata.

    When vertical_multiplier is not (close to) 1.0, every value of reprojected_file is multiplied by it with
    gdal_calc.py, and the result is copied, LZW-compressed, to output_file. Otherwise output_file is edited as it
    is; reproject_raster passes the same file for both arguments in that case. The units string built from
    vertical_datum and vertical_units is then written to output_file with gdal_edit.py. The GdalRunner working
    folder is removed even if one of the steps raises.

    Args:
        reprojected_file (str): The name of the reprojected file to edit.
        vertical_datum (str): The vertical datum.
        vertical_units (str): The vertical units to set.
        vertical_multiplier (float): The vertical multiplier.
        output_file (str): The output raster filename.
    """
    gdal_runner = GdalRunner()
    try:
        if not math.isclose(vertical_multiplier, 1.0):
            args = ['-A', reprojected_file, '--outfile=$OUT_FILE$', f'--calc=A*{vertical_multiplier}']
            vertical_corrected_file = gdal_runner.run('gdal_calc.py', 'vertical_corrected.tif', args)
            copy_and_compress_raster(vertical_corrected_file, output_file)
        units = gu.get_unit_string(vertical_datum, vertical_units)
        if units:
            args = ['-units', units, output_file]
            gdal_runner.run('gdal_edit.py', '', args)
    finally:
        # Remove the folder containing all the generated files. Cleanup is best-effort so it cannot mask the
        # real outcome, or an error raised above.
        shutil.rmtree(gdal_runner.temp_file_path, ignore_errors=True)


def reproject_raster(raster_file: str, output_file: str, default_wkt: str, vertical_datum: str,
                     vertical_units: str, resample_alg: int = gdal.GRA_NearestNeighbour) -> str:
    """Reprojects a raster to the given projection and writes the result to output_file.

    The target projection is reconciled with the raster's projection by reconcile_projection_to_raster. When both
    projections are known and the reprojection succeeds, set_vertical_units then scales the values by the vertical
    multiplier and writes the vertical units to the output. In every other case the raster is copied unchanged
    (LZW-compressed) to output_file.

    Args:
        raster_file (str): The input raster file.
        output_file (str): The target raster file.
        default_wkt (str): The WKT from the target projection.
        vertical_datum (str): The vertical datum from the target projection.
        vertical_units (str): The vertical units from the target projection.
        resample_alg (int): The resample algorithm to use, one of gdal.GRA_*

    Returns:
        str: The output raster filename.

    Raises:
        RuntimeError: If the output file could not be created.
    """
    temp_copy = None
    # If the input and output files are the same, work from a temporary copy so the input isn't overwritten.
    if os.path.exists(raster_file) and os.path.exists(output_file) and os.path.samefile(raster_file, output_file):
        temp_copy = filesystem.temp_filename(suffix='.tif')
        copy_and_compress_raster(raster_file, temp_copy)
        raster_file = temp_copy
    try:
        target_sr, raster_sr, vertical_multiplier = reconcile_projection_to_raster(raster_file, default_wkt,
                                                                                   vertical_datum, vertical_units)
        # raster_sr can be None even when target_sr is not, so both are needed to reproject.
        if target_sr is not None and raster_sr is not None:
            convert_vertical = not math.isclose(vertical_multiplier, 1.0)
            # Reproject straight into output_file unless the values must also be scaled afterwards.
            reprojected_file = filesystem.temp_filename(suffix='.tif') if convert_vertical else output_file
            reproject = RasterReproject([raster_file], reprojected_file, target_sr.ExportToWkt(),
                                        raster_sr.ExportToWkt(), resample_alg)
            if reproject.run() is not None:
                set_vertical_units(reprojected_file, vertical_datum, vertical_units, vertical_multiplier, output_file)
                if convert_vertical:
                    os.remove(reprojected_file)
                return output_file
        if copy_and_compress_raster(raster_file, output_file):
            return output_file
        raise RuntimeError(f'Could not create raster output file {output_file}.')
    finally:
        if temp_copy is not None and os.path.exists(temp_copy):
            try:
                os.remove(temp_copy)
            except OSError:
                pass  # Deliberately best-effort: a leftover temp copy must not fail the reprojection.


def get_datatype_and_nodata_for_warp(input_raster: RasterInput) -> tuple[int, float]:
    """Return the data type and nodata value to pass to gdal.Warp for a raster.

    This is a thin alias that delegates to raster_reproject.get_datatype_and_nodata.

    Args:
        input_raster (RasterInput): The raster.

    Returns:
        tuple: Whatever get_datatype_and_nodata returns for the raster, unchanged.
    """
    warp_settings = get_datatype_and_nodata(input_raster)
    return warp_settings


def raster_to_shapefile_contours(input_raster: RasterInput, values: list, shape_filename: str):
    """Converts a raster to contours and stores the contours as a shapefile.

    Contours are generated from band 1 at each of the fixed levels in values. They are written to a layer named
    'contour' that has an integer 'ID' field and a real 'elev' field. The raster's nodata value is passed to GDAL
    as the NODATA option only when the raster has one.

    Args:
        input_raster (RasterInput): The raster.
        values (list): The values to contour.
        shape_filename (str): The shapefile filename.
    """
    ogr_ds = ogr.GetDriverByName("ESRI Shapefile").CreateDataSource(shape_filename)
    contour_shp = ogr_ds.CreateLayer('contour', srs=gu.wkt_to_sr(input_raster.wkt), geom_type=ogr.wkbLineStringZM)
    contour_shp.CreateField(ogr.FieldDefn("ID", ogr.OFTInteger))
    contour_shp.CreateField(ogr.FieldDefn("elev", ogr.OFTReal))

    band = input_raster.gdal_raster.GetRasterBand(1)
    values_str = ','.join(str(f) for f in values)
    # ID_FIELD/ELEV_FIELD are the indices of the "ID" and "elev" fields created above.
    options = [f'FIXED_LEVELS={values_str}', 'ID_FIELD=0', 'ELEV_FIELD=1']
    nodata = input_raster.nodata_value
    if nodata is not None:
        # Without this guard the option would read 'NODATA=None', which is not a number.
        options.append(f'NODATA={nodata}')
    gdal.ContourGenerateEx(band, contour_shp, options)
    # Release the layer before the DataSource so the shapefile is flushed and closed.
    contour_shp = None
    ogr_ds = None
