"""Tests for MergeGridTool class."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
from enum import auto, Enum
import math
import string

# 2. Third party modules
from osgeo import gdal, ogr, osr

# 3. Aquaveo modules

# 4. Local modules

# XMS unit labels. These are the strings returned by the unit helpers in this module (e.g. _conv_factor_to_units)
# and accepted by the functions that build vertical projections.
UNITS_FEET_INT = 'FEET (INTERNATIONAL)'
UNITS_FEET_US_SURVEY = 'FEET (U.S. SURVEY)'
UNITS_METERS = 'METERS'
UNITS_DEGREES = 'DEGREES'
UNITS_UNKNOWN = 'UNKNOWN'  # Returned when a conversion factor matches none of the labels above.

# XMS vertical datum labels. DATUM_LOCAL is the fallback when no known vertical datum is recognized.
DATUM_NGVD29 = 'NGVD29'
DATUM_NAVD88 = 'NAVD88'
DATUM_LOCAL = 'LOCAL'


class GdalResult(Enum):
    """Status codes returned by the wkt_from_*_horizontal_datum helpers.

    Values are numbered consecutively from 1, in declaration order.
    """

    success = 1  # The WKT string was produced.
    bad_horizontal_units = 2  # Not returned by the helpers shown in this module.
    bad_horizontal_datum = 3  # GDAL rejected the horizontal datum name.
    bad_zone = 4  # GDAL rejected the coordinate zone.
    unknown = 5  # ExportToWkt gave no usable string for an undocumented reason.


def _is_vertical(wkt):
    """Determines whether this wkt has vertical information.

    Args:
        wkt (str):  A WKT1 projection string.

    Returns:
        (bool):  Whether this projection has vertical information.
    """
    if wkt.find('VERT_CS[') != -1:
        return True
    if wkt.find('VERTCS[') != -1:
        return True
    if wkt.find('VERTCRS[') != -1:
        return True
    return False


def _get_vertical_location(wkt):
    """Finds the location of the vertical projection information in this WKT string.

    Args:
        wkt (str):  A WKT1 projection string.

    Returns:
        (int):  The location of the start of the vertical projection information in the string or -1 if not found.
    """
    pos = wkt.find(',VERTCS')
    if pos == -1:
        pos = wkt.find(',VERT_CS')
    if pos == -1:
        pos = wkt.find(',VERTCRS')
    if pos == -1:
        pos = wkt.find(', VERTCS')
    if pos == -1:
        pos = wkt.find(', VERT_CS')
    if pos == -1:
        pos = wkt.find(', VERTCRS')
    if pos == -1:
        pos = wkt.find('VERTCS')
    if pos == -1:
        pos = wkt.find('VERT_CS')
    if pos == -1:
        pos = wkt.find('VERTCRS')
    return pos


def wkt_to_sr(wkt):
    """Converts a projection string to an osr.SpatialReference.

    The string is tried, in order, as ESRI WKT, as generic WKT and finally as PROJ.4.

    Args:
        wkt (str):  A WKT or PROJ.4 projection string.

    Returns:
        (osr.SpatialReference):  An osr.SpatialReference generated from the projection string.

    Raises:
        ValueError: If none of the import methods accept the string.
    """
    sr = osr.SpatialReference()
    importers = (
        lambda: sr.ImportFromESRI([wkt]),
        lambda: sr.ImportFromWkt(wkt),
        lambda: sr.ImportFromProj4(wkt),
    )
    for do_import in importers:
        try:
            err = do_import()
        except RuntimeError:
            # With GDAL/OSR exceptions enabled (add_vertical_to_wkt calls gdal.UseExceptions()), a failed import
            # raises instead of returning an error code. Treat that as a failed attempt so the next format is tried.
            err = ogr.OGRERR_FAILURE
        if err == ogr.OGRERR_NONE:
            return sr
    # GDAL does not support this projection string in any of the tried formats.
    raise ValueError(f'Unable to load WKT: {wkt}')


def is_int(value):
    """Checks whether ``int(value)`` succeeds.

    Args:
        value: Any object, e.g. a code string such as '4326' or a number.

    Returns:
        (bool): True if ``int(value)`` succeeds. False otherwise, including for ``None`` and other non-numeric
        objects (which raise TypeError) and for infinities (which raise OverflowError).
    """
    try:
        int(value)
    except (TypeError, ValueError, OverflowError):
        return False
    return True


def wkt_from_epsg(epsg_code):
    """Builds a WKT string from an authority code.

    The code is tried first as an EPSG integer code (only when it converts to int), then as 'ESRI:<code>' and
    finally as 'IGNF:<code>'.

    Args:
        epsg_code (string or int):  A string or integer representing the EPSG code, ESRI code, or the IGNF code.

    Returns:
        (bool, string):  Whether any lookup succeeded, and the WKT exported from the spatial reference (exported
        even when every lookup failed).
    """
    sr = osr.SpatialReference()
    err = sr.ImportFromEPSG(int(epsg_code)) if is_int(epsg_code) else ogr.OGRERR_FAILURE
    for authority in ('ESRI', 'IGNF'):
        if err == ogr.OGRERR_NONE:
            break
        err = sr.SetFromUserInput(f'{authority}:{epsg_code}')
    return err == ogr.OGRERR_NONE, sr.ExportToWkt()


def valid_wkt(wkt):
    """Tell whether a projection string can be loaded by wkt_to_sr.

    Args:
        wkt (str):  A WKT or PROJ.4 projection string or anything else that might be invalid.

    Returns:
        (bool):  False for None or for any input that makes wkt_to_sr raise an Exception; True otherwise.
    """
    if wkt is None:
        return False
    try:
        wkt_to_sr(wkt)
    except Exception:
        return False
    return True


def _conv_factor_to_units(factor):
    """Map a GDAL unit conversion factor to one of the XMS unit labels.

    The factor is compared with math.isclose against each known factor in turn. For linear units this is
    meters per unit; get_horiz_unit_from_wkt also passes angular factors (radians per unit) for degrees.

    Args:
        factor (float):  The conversion factor reported by GDAL.

    Returns:
        (string):  The matching UNITS_* label, or UNITS_UNKNOWN if nothing matches.
    """
    known_units = (
        (osr.SRS_UL_FOOT_CONV, UNITS_FEET_INT),
        (1.0, UNITS_METERS),
        (osr.SRS_UL_US_FOOT_CONV, UNITS_FEET_US_SURVEY),
        (osr.SRS_UA_DEGREE_CONV, UNITS_DEGREES),
    )
    for conversion, label in known_units:
        if math.isclose(float(conversion), factor):
            return label
    return UNITS_UNKNOWN


def _get_horizontal_location(wkt):
    """Finds the location of the horizontal projection information in this WKT string.

    Args:
        wkt (str):  A WKT1 projection string.

    Returns:
        (int):  The location of the start of the horizontal projection information in the string or -1 if not found.
    """
    pos = wkt.find('PROJCS')
    if pos == -1:
        pos = wkt.find('PROJ_CS')
    if pos == -1:
        pos = wkt.find('PROJCRS')
    if pos == -1:
        pos = wkt.find('GEOGCS')
    if pos == -1:
        pos = wkt.find('GEOG_CS')
    if pos == -1:
        pos = wkt.find('LOCALCS')
    if pos == -1:
        pos = wkt.find('LOCAL_CS')
    return pos


def read_projection_file(filename):
    """Load a projection file and collapse it onto one line.

    Each line is stripped of surrounding whitespace and the pieces are concatenated with no separator.

    Args:
        filename (str):  Projection filename containing WKT or PROJ.4.

    Returns:
        (str):  The single-line projection string, or '' if the file cannot be decoded as text.
    """
    pieces = []
    try:
        with open(filename, 'r') as handle:
            for raw_line in handle:
                pieces.append(raw_line.strip())
    except UnicodeDecodeError:
        return ''
    return ''.join(pieces)


def add_vertical_projection(sr, vertical_datum, vertical_units, set_local=True):
    """Attempts to add the vertical datum and vertical units to the given spatial reference.

    A vertical spatial reference is built from the datum/units pair: EPSG codes for NAVD88 and NGVD29, or an inline
    local VERT_CS otherwise. It is then combined with ``sr`` into a compound CS. If ``sr`` is already compound and
    its vertical units match ``vertical_units`` (or cannot be determined), ``sr`` is returned unchanged.

    Args:
        sr (osr.SpatialReference): The spatial reference.
        vertical_datum (str): The vertical datum (DATUM_NAVD88, DATUM_NGVD29, or anything else for a local datum).
        vertical_units (str): The vertical units (one of the UNITS_* labels; only feet and meters are handled).
        set_local (bool): Whether to set the vertical units if the vertical_datum is local.

    Returns:
        tuple (bool, osr.SpatialReference): Whether the vertical info was added and the modified spatial reference
        (the original ``sr`` when nothing was added).
    """
    sr_v = osr.SpatialReference()
    add_vertical = False
    # With set_local False, only a non-empty, non-local datum gets a vertical projection.
    if set_local or (vertical_datum and vertical_datum != DATUM_LOCAL):
        if vertical_datum == DATUM_NAVD88:
            # Both foot variants map to the same EPSG code here.
            if vertical_units == UNITS_FEET_INT or vertical_units == UNITS_FEET_US_SURVEY:
                # NAVD 88 Height, Feet
                sr_v.ImportFromEPSG(6360)
                add_vertical = True
            elif vertical_units == UNITS_METERS:
                # NAVD 88 Height, Meters
                sr_v.ImportFromEPSG(5703)
                add_vertical = True
        elif vertical_datum == DATUM_NGVD29:
            if vertical_units == UNITS_FEET_INT or vertical_units == UNITS_FEET_US_SURVEY:
                # NGVD 29 Height, Feet
                sr_v.ImportFromEPSG(5702)
                add_vertical = True
            elif vertical_units == UNITS_METERS:
                # NGVD 29 Height, Meters
                sr_v.ImportFromEPSG(7968)
                add_vertical = True
        else:
            if vertical_units == UNITS_FEET_INT or vertical_units == UNITS_FEET_US_SURVEY:
                # Local datum, Feet.
                # NOTE(review): the US survey foot definition is used even when UNITS_FEET_INT is requested.
                sr_v.ImportFromWkt('VERT_CS["Local ftUS",VERT_DATUM["Local",0],UNIT["US survey foot",0.304800609601219,'
                                   'AUTHORITY["EPSG","9003"]],AXIS["Up",UP]]')
                add_vertical = True
            elif vertical_units == UNITS_METERS:
                # Local datum, Meters
                sr_v.ImportFromWkt('VERT_CS["Local Meter",VERT_DATUM["Local",0],UNIT["metre",1,'
                                   'AUTHORITY["EPSG","9001"]],AXIS["Up",UP]]')
                add_vertical = True
    if add_vertical:
        sr_h = sr
        set_compound_projection = True
        if sr_h.IsCompound():
            wkt = sr_h.ExportToWkt()
            proj_vert_units = get_vert_unit_from_wkt(wkt)
            if proj_vert_units != UNITS_UNKNOWN and proj_vert_units != vertical_units:
                # Existing vertical units differ: drop the old vertical part and rebuild from the horizontal part.
                wkt = strip_vertical(wkt)
                sr_h = wkt_to_sr(wkt)
            else:
                # Units already match (or are unknown): keep the existing compound CS as-is.
                # NOTE(review): this also keeps the existing vertical datum even if vertical_datum differs.
                set_compound_projection = False
        if set_compound_projection:
            sr_new = osr.SpatialReference()
            # Compound name: "<horizontal name> + <datum> + <units>" (datum omitted when empty).
            cs_name = sr_h.GetName()
            if vertical_datum:
                cs_name += ' + ' + vertical_datum
            cs_name += ' + ' + vertical_units
            err = sr_new.SetCompoundCS(cs_name, sr_h, sr_v)
            if err == ogr.OGRERR_NONE:
                sr = sr_new
            else:
                add_vertical = False
    return add_vertical, sr


def strip_vertical(wkt):
    """Return only the horizontal part of a WKT string.

    Everything from the vertical marker onward is cut, then anything before the horizontal node. Surrounding
    commas and whitespace are trimmed.

    Args:
        wkt (str): The wkt for the spatial reference.

    Returns:
        (str): The WKT without the vertical information.
    """
    vert_start = _get_vertical_location(wkt)
    horizontal = wkt if vert_start == -1 else wkt[:vert_start]
    horiz_start = _get_horizontal_location(horizontal)
    if horiz_start != -1:
        horizontal = horizontal[horiz_start:]
    horizontal = horizontal.strip(',' + string.whitespace)
    # GDAL sometimes nests the vertical node inside the horizontal PROJCS/GEOGCS; cutting it then loses the
    # horizontal node's closing bracket, so append one to rebalance.
    if horizontal.count('[') - horizontal.count(']') > 0:
        horizontal += "]"
    return horizontal


def _extract_vertical_from_wkt(wkt):
    """Return only the vertical part of a WKT string.

    The string is taken from the vertical marker onward (or kept whole when there is none), then surrounding
    commas and whitespace are trimmed.

    Args:
        wkt (str): The wkt for the spatial reference.

    Returns:
        (str): The WKT for only the vertical information.
    """
    start = _get_vertical_location(wkt)
    vertical = wkt[start:] if start != -1 else wkt
    vertical = vertical.strip(string.whitespace + ',')
    # When GDAL nests the vertical node inside the horizontal node, the tail carries the horizontal node's
    # closing bracket; drop one trailing character to rebalance.
    if vertical.count(']') > vertical.count('['):
        vertical = vertical[:-1]
    return vertical


def add_vertical_to_wkt(wkt, vertical_datum, vertical_units, set_local=True):
    """Replace any vertical information in a WKT with the given datum and units.

    Args:
        wkt (str): The wkt for the spatial reference.
        vertical_datum (str): The vertical datum.
        vertical_units (str): The vertical units.
        set_local (bool): Whether to set the vertical units if the vertical_datum is local.

    Returns:
        (str): The WKT for the projection with the vertical datum and unit information.
    """
    # Both calls change global GDAL state, not just state for this call.
    gdal.SetConfigOption('GTIFF_REPORT_COMPD_CS', 'TRUE')
    gdal.UseExceptions()
    horizontal_sr = wkt_to_sr(strip_vertical(wkt))
    _added, combined_sr = add_vertical_projection(horizontal_sr, vertical_datum, vertical_units, set_local)
    return combined_sr.ExportToWkt()


def get_vert_unit_from_wkt(wkt: str, horiz_to_vert: bool = False):
    """Return the XMS vertical unit label for a WKT.

    Args:
        wkt (str): The wkt in WKT1 format.
        horiz_to_vert (bool): When the WKT has no vertical node, return the horizontal units instead of
            UNITS_UNKNOWN.

    Returns:
        (str): One of the UNITS_* labels.
    """
    if _is_vertical(wkt):
        vertical_sr = wkt_to_sr(_extract_vertical_from_wkt(wkt))
        return _conv_factor_to_units(vertical_sr.GetTargetLinearUnits(None))
    return get_horiz_unit_from_wkt(wkt) if horiz_to_vert else UNITS_UNKNOWN


def get_horiz_unit_from_wkt(wkt):
    """Return the XMS horizontal unit label for a WKT.

    Projected and local systems use their linear units and geographic systems their angular units. Anything else
    maps to UNITS_UNKNOWN.

    Args:
        wkt (str): The wkt in WKT1 format.

    Returns:
        (str): One of the UNITS_* labels.
    """
    sr = wkt_to_sr(strip_vertical(wkt))
    if sr.IsProjected() or sr.IsLocal():
        return _conv_factor_to_units(sr.GetLinearUnits())
    if sr.IsGeographic():
        return _conv_factor_to_units(sr.GetAngularUnits())
    # Sentinel factor that matches no known unit.
    return _conv_factor_to_units(-1.0)


def is_geographic(wkt):
    """Tell whether the horizontal part of a WKT is a (non-local) geographic system.

    Args:
        wkt (str): The wkt in WKT1 format.

    Returns:
        (bool): False for local systems; otherwise the result of IsGeographic().
    """
    horizontal_sr = wkt_to_sr(strip_vertical(wkt))
    if horizontal_sr.IsLocal():
        return False
    return horizontal_sr.IsGeographic()


def is_local(wkt):
    """Tell whether a WKT describes a local projection.

    Args:
        wkt (str): The wkt in WKT1 format.

    Returns:
        (bool): True for an empty/None WKT (XMS's convention for "local"), otherwise the result of IsLocal() on
        the horizontal part.
    """
    # An empty WKT is not technically valid, but XMS uses it to mean "local", so accept it here.
    if not wkt:
        return True
    return wkt_to_sr(strip_vertical(wkt)).IsLocal()


def _extract_wkt_names(wkt, prefix):
    """Gets the name of a WKT node.

    Args:
        wkt (str):  The original WKT.
        prefix (str):  The prefix in the WKT to read from.  Ex:  PROJCS, GEOGCS, PARAMETER, etc.

    Returns:
        (list(str)): The list of names found just after the prefix specified.
    """
    names = []
    search_str = prefix + '["'
    pos = 0
    name_index = wkt.find(search_str, pos)
    while name_index != -1:
        # find start of name
        name_index += len(search_str)
        # remove wkt prior to name
        name = wkt[name_index:]
        pos = name_index
        # find end of name
        name_index = name.find('"')
        pos += name_index
        # remove wkt after name
        name = name[:name_index]
        # Trim whitespace from both sides of name
        name = name.strip()
        names.append(name)
        name_index = wkt.find(search_str, pos)
    return names


def _vert_datum_from_string(datum):
    """Map a vertical datum name to one of the XMS datum labels.

    Matching is case-insensitive and accepts both space- and underscore-separated spellings of the full datum names.

    Args:
        datum (str):  the vertical datum string.

    Returns:
        (str): DATUM_NGVD29, DATUM_NAVD88, or DATUM_LOCAL when neither name is found.
    """
    lowered = datum.lower()
    ngvd29_names = ('national geodetic vertical datum 1929', 'national_geodetic_vertical_datum_1929')
    navd88_names = ('north american vertical datum 1988', 'north_american_vertical_datum_1988')
    if any(name in lowered for name in ngvd29_names):
        return DATUM_NGVD29
    if any(name in lowered for name in navd88_names):
        return DATUM_NAVD88
    return DATUM_LOCAL


def get_vert_datum_from_wkt(wkt):
    """Gets the vertical datum from the WKT as a string.

    Args:
        wkt (str): The wkt in WKT1 format.

    Returns:
        (str): DATUM_NGVD29 or DATUM_NAVD88 when recognized from the vertical datum node name, otherwise DATUM_LOCAL
        (including when the WKT has no vertical part).
    """
    if not _is_vertical(wkt):
        return DATUM_LOCAL
    wkt_vert = _extract_vertical_from_wkt(wkt)
    lowered = wkt_vert.lower()
    names = []
    # NOTE(review): the node is detected case-insensitively but extracted with upper-case prefixes, so a
    # lower/mixed-case node yields no names and falls back to DATUM_LOCAL.
    if 'vert_datum' in lowered:
        names = _extract_wkt_names(wkt_vert, "VERT_DATUM")
    elif 'vdatum' in lowered:
        names = _extract_wkt_names(wkt_vert, "VDATUM")
    datum = names[0] if names else DATUM_LOCAL
    return _vert_datum_from_string(datum)


def wkt_from_geographic_horizontal_datum(horizontal_datum: str) -> tuple[GdalResult, str]:
    """
    Build a geographic well-known-text from a horizontal datum name.

    Args:
        horizontal_datum: A name accepted by GDAL's SetWellKnownGeogCS, e.g. 'NAD83' or 'NAD27'.

    Returns:
        (GdalResult.success, wkt) on success. Otherwise (GdalResult.bad_horizontal_datum, '') when GDAL rejects the
        datum, or (GdalResult.unknown, '') when the export yields no usable string.
    """
    spatial_ref = osr.SpatialReference()
    if spatial_ref.SetWellKnownGeogCS(horizontal_datum) != ogr.OGRERR_NONE:
        return GdalResult.bad_horizontal_datum, ''

    wkt = spatial_ref.ExportToWkt()
    if wkt and isinstance(wkt, str):
        return GdalResult.success, wkt
    # The error code suggests an empty export can happen, but GDAL does not document when.
    return GdalResult.unknown, ''


def wkt_from_utm_horizontal_datum(horizontal_datum: str, coordinate_zone: int) -> tuple[GdalResult, str]:
    """
    Create a UTM well-known-text on a horizontal datum.

    Args:
        horizontal_datum: A name accepted by GDAL's SetWellKnownGeogCS, e.g. 'NAD83' or 'NAD27'.
        coordinate_zone: The UTM zone number. A positive value selects the northern hemisphere and a non-positive
            value the southern hemisphere; the zone passed to GDAL is the absolute value. -999 is the data_objects
            N/A flag (it is out of range and is expected to be rejected by GDAL, giving bad_zone).

    Returns:
        If successful, a tuple of (GdalResult.success, wkt), where wkt is the created well-known-text. On failure, a
        tuple of (result, ''), where result identifies the specific reason for failure.
    """
    spatial_ref = osr.SpatialReference()

    display_name = 'UTM'
    if coordinate_zone != -999:  # -999 is the data_objects N/A flag
        display_name = f'{display_name} {coordinate_zone}'
    spatial_ref.SetProjCS(display_name)

    ogerr = spatial_ref.SetWellKnownGeogCS(horizontal_datum)
    if ogerr != ogr.OGRERR_NONE:
        return GdalResult.bad_horizontal_datum, ''

    north_hemisphere = 1 if coordinate_zone > 0 else 0
    # SetUTM takes the zone number (0..60) and the hemisphere separately. The sign of coordinate_zone encodes the
    # hemisphere, so pass its magnitude; a negative zone would always be rejected.
    ogerr = spatial_ref.SetUTM(abs(coordinate_zone), north_hemisphere)
    if ogerr != ogr.OGRERR_NONE:
        return GdalResult.bad_zone, ''

    wkt = spatial_ref.ExportToWkt()
    if not wkt or not isinstance(wkt, str):
        return GdalResult.unknown, ''  # The error code suggests this can happen, but there's no documentation on when.

    return GdalResult.success, wkt


def wkt_from_stateplane_horizontal_datum(horizontal_datum: str,
                                         horizontal_units: str,
                                         coordinate_zone: int = 0) -> tuple[GdalResult, str]:
    """
    Create a State Plane well-known-text on a horizontal datum.

    Args:
        horizontal_datum: 'NAD83' selects NAD83; any other value is passed to GDAL as NAD27.
        horizontal_units: The projection's horizontal units. One of the constants from the top of this file:
            UNITS_FEET_INT, UNITS_FEET_US_SURVEY, UNITS_METERS, UNITS_DEGREES, UNITS_UNKNOWN. Unrecognized values
            are treated as UNITS_UNKNOWN (no unit override).
        coordinate_zone: The projection's coordinate zone. Appears to be the projection's USGS code, but GDAL's
            documentation is bad enough that it could be anything.

    Returns:
        If successful, a tuple of (GdalResult.success, wkt), where wkt is the created well-known-text. On failure, a
        tuple of (result, ''), where result identifies the specific reason for failure.
    """
    spatial_ref = osr.SpatialReference()
    is_nad83 = 1 if horizontal_datum == 'NAD83' else 0
    # (unit name, conversion factor to meters) handed to GDAL as an override of the zone's legal units.
    xms_to_gdal_units = {
        UNITS_DEGREES: ('degrees', 0.0174532925199433),
        UNITS_FEET_INT: ('Foot (International)', 0.3048),
        # The US survey foot is exactly 1200/3937 m; 0.3048 would mis-scale the projection's parameters.
        UNITS_FEET_US_SURVEY: ('US survey foot', 1200.0 / 3937.0),
        # Explicit meters so zones whose legal unit is feet (e.g. NAD27) are still produced in meters.
        UNITS_METERS: ('metre', 1.0),
        UNITS_UNKNOWN: ('', 0.0),  # a 0.0 factor means no override: GDAL keeps the zone's default units
    }
    if horizontal_units not in xms_to_gdal_units:
        horizontal_units = UNITS_UNKNOWN

    gdal_units, gdal_conversion = xms_to_gdal_units[horizontal_units]
    ogerr = spatial_ref.SetStatePlane(coordinate_zone, is_nad83, gdal_units, gdal_conversion)
    if ogerr != ogr.OGRERR_NONE:
        return GdalResult.bad_zone, ''

    wkt = spatial_ref.ExportToWkt()
    if not wkt or not isinstance(wkt, str):
        return GdalResult.unknown, ''  # The error code suggests this can happen, but there's no documentation on when.

    return GdalResult.success, wkt


def get_vertical_multiplier(from_units, to_units):
    """Return the factor that converts elevations between XMS unit labels.

    Both foot variants are treated as the international foot (0.3048 m). Any pair other than feet<->meters yields
    1.0.

    Args:
        from_units (str): vertical units of source.
        to_units (str): vertical units of destination.

    Returns:
        (float): Multiplier to convert.
    """
    feet_labels = (UNITS_FEET_INT, UNITS_FEET_US_SURVEY)
    if from_units in feet_labels and to_units == UNITS_METERS:
        return 0.3048
    if from_units == UNITS_METERS and to_units in feet_labels:
        return 3.28083989501
    return 1.0


def _get_linear_units(wkt):
    """Return GDAL's linear unit factor (meters per unit) for a WKT.

    Args:
        wkt (str): The wkt in WKT1 format.

    Returns:
        (float): The value of GetLinearUnits() on the loaded spatial reference.
    """
    return wkt_to_sr(wkt).GetLinearUnits()


def _get_authority_code_and_name(sr: osr.SpatialReference):
    """Read the root authority code and name from a spatial reference.

    Args:
        sr (osr.SpatialReference): An osr.SpatialReference generated by a WKT string.

    Returns:
        tuple (bool, str, str): (True, code, name) when both are non-empty, otherwise (False, '', '').
    """
    code = sr.GetAuthorityCode(None)
    # The name is only queried once a code has been found.
    name = sr.GetAuthorityName(None) if code else None
    if code and name:
        return True, code, name
    return False, '', ''


def _get_from_user_input(name_and_code):
    """Resolve an 'AUTHORITY:CODE' string to WKT via SetFromUserInput.

    Args:
        name_and_code (str): Authority name and code (i.e.: EPSG:5678, IGNF:WGS84G, NKG:ETRF00, OGC:CRS83,
         ESRI:104000, etc.)

    Returns:
        tuple (bool, str): (True, wkt) on success, (False, '') otherwise.
    """
    sr = osr.SpatialReference()
    status = sr.SetFromUserInput(name_and_code)
    if status != ogr.OGRERR_NONE:
        return False, ''
    return True, sr.ExportToWkt()


def _remove_auth_code_and_name(wkt, auth_code, auth_name):
    """Removes the authority code and name from the WKT.

    Args:
        wkt (str): The wkt in WKT1 format.
        auth_code (str): The authority code
        auth_name (str): The authority name

    Returns:
        tuple (bool, str): Whether the EPSG code was removed (bool) and the wkt without the EPSG code (str).
    """
    orig_wkt = wkt
    wkt.upper()
    auth_name.upper()
    auth_code.upper()
    search = ',AUTHORITY["' + auth_name + '",' + auth_code + ']'
    pos = wkt.find(search)
    if pos != -1:
        wkt = wkt[:pos] + wkt[pos + len(search):]
        return True, wkt
    # Sometimes the EPSG code will have quotes around it; check for that
    search = ',AUTHORITY["' + auth_name + '","' + auth_code + '"]'
    pos = wkt.find(search)
    if pos != -1:
        new_wkt = wkt[:pos] + wkt[pos + len(search):]
        return True, new_wkt
    return False, orig_wkt


def remove_epsg_code_if_unit_mismatch(wkt):
    """Remove EPSG code if unit mismatch.

    The horizontal part is auto-identified. The identified authority's own definition is then looked up, and if
    its linear units differ from the WKT's, the authority node is stripped from the full (original) WKT.

    Args:
        wkt (str): The wkt in WKT1 format.

    Returns:
        tuple (bool, str): Whether the EPSG code was removed (bool) and the wkt without the EPSG code if there's a
        mismatch between the code and the WKT units (str); otherwise the unchanged input wkt.
    """
    horizontal_wkt = strip_vertical(wkt)
    sr = wkt_to_sr(horizontal_wkt)
    if sr.AutoIdentifyEPSG() != ogr.OGRERR_NONE:
        return False, wkt
    found, auth_code, auth_name = _get_authority_code_and_name(sr)
    if not found:
        return False, wkt
    found, authority_wkt = _get_from_user_input(auth_name + ':' + auth_code)
    if not found:
        return False, wkt
    if math.isclose(_get_linear_units(horizontal_wkt), _get_linear_units(authority_wkt)):
        return False, wkt
    return _remove_auth_code_and_name(wkt, auth_code, auth_name)


def identify_epsg(wkt: str) -> tuple[bool, str, str]:
    """Identifies the EPSG code of the horizontal portion of a wkt string.

    AutoIdentifyEPSG is tried first. Its result is accepted only when the authority's definition has the same
    horizontal units. Otherwise the FindMatches candidates with confidence above 40 are tried in order, and the
    first one with matching horizontal units wins.

    Args:
        wkt (str): The wkt in WKT1 format.

    Returns:
        tuple (bool, str, str): Whether an authority was found, the authority (EPSG) code and the authority name
         ('EPSG', 'ESRI', 'OGC').
    """
    horizontal_wkt = strip_vertical(wkt)
    if not valid_wkt(horizontal_wkt):
        return False, '', ''
    sr = wkt_to_sr(horizontal_wkt)
    candidate = sr.Clone()

    def _same_horizontal_units(candidate_wkt):
        # Compares against the original horizontal reference (sr is only reassigned right before the search stops).
        return get_horiz_unit_from_wkt(candidate_wkt) == get_horiz_unit_from_wkt(sr.ExportToWkt())

    searching = True
    if candidate.AutoIdentifyEPSG() == ogr.OGRERR_NONE:
        found, code, name = _get_authority_code_and_name(candidate)
        if found:
            found, authority_wkt = _get_from_user_input(name + ':' + code)
            if found and _same_horizontal_units(authority_wkt):
                sr = candidate
                searching = False
    if searching:
        for match in candidate.FindMatches():
            # Each match is a (spatial reference, confidence) pair.
            if match[1] <= 40:
                continue
            if _same_horizontal_units(match[0].ExportToWkt()):
                sr = match[0]
                break
    return _get_authority_code_and_name(sr)


def add_hor_auth_code_and_name(wkt: str) -> str:
    """Adds the auth code and name to the horizontal portion of a wkt string.

    The horizontal part is replaced by the identified authority's own definition. Any vertical part is re-attached
    in a compound CS that keeps the original name.

    Args:
        wkt (str): The wkt in WKT1 format.

    Returns:
        The new WKT string, with the Authority code and name added to the horizontal portion of the string; the
        input's spatial reference is re-exported unchanged when no authority is identified.
    """
    original_sr = wkt_to_sr(wkt)
    vertical_sr = wkt_to_sr(_extract_vertical_from_wkt(wkt)) if _is_vertical(wkt) else None
    found, auth_code, auth_name = identify_epsg(wkt)
    horizontal_sr = None
    if found:
        found, authority_wkt = _get_from_user_input(auth_name + ':' + auth_code)
        if found:
            horizontal_sr = wkt_to_sr(authority_wkt)

    result_sr = original_sr
    if horizontal_sr is not None and vertical_sr is not None:
        compound_sr = osr.SpatialReference()
        if compound_sr.SetCompoundCS(original_sr.GetName(), horizontal_sr, vertical_sr) == ogr.OGRERR_NONE:
            result_sr = compound_sr
    elif horizontal_sr:
        result_sr = horizontal_sr
    return result_sr.ExportToWkt()


def transform_points_from_wkt(points, from_wkt, to_wkt, keep_vertical=False, set_traditional_axis_mapping=True):
    """Reproject points between two WKT projections.

    Args:
        points (list[tuple]): The points to transform.
        from_wkt (str): The WKT of the source projection.
        to_wkt (str): The WKT of the destination projection.
        keep_vertical (bool): Whether to keep the vertical from the WKT projections.  We usually want to strip because
            of GDAL bugs converting between vertical projections.  Sometimes we need to keep the vertical projection
            because we need to convert elevations for some reason.
        set_traditional_axis_mapping (bool): Whether to set the axis mapping strategy to traditional for the spatial
            references. This returns coordinates as X,Y (Longitude,Latitude) instead of Y,X (Lat,Lon) when converting
            to/from Geographic projections.

    Returns:
        (list[tuple]): The transformed points, or the input unchanged when either WKT is empty/None.
    """
    transformation = get_coordinate_transformation(
        from_wkt,
        to_wkt,
        keep_vertical=keep_vertical,
        set_traditional_axis_mapping=set_traditional_axis_mapping,
    )
    return transform_points(points, transformation)


def get_coordinate_transformation(from_wkt, to_wkt, keep_vertical=False, set_traditional_axis_mapping=True):
    """Build an osr.CoordinateTransformation between two WKT projections.

    Args:
        from_wkt (str): The WKT of the source projection.
        to_wkt (str): The WKT of the destination projection.
        keep_vertical (bool): Whether to keep the vertical from the WKT projections.  We usually want to strip because
            of GDAL bugs converting between vertical projections.  Sometimes we need to keep the vertical projection
            because we need to convert elevations for some reason.
        set_traditional_axis_mapping (bool): Whether to set the axis mapping strategy to traditional for the spatial
            references.  This returns coordinates as X,Y (Longitude,Latitude) instead of Y,X (Lat,Lon) when converting
            to/from Geographic projections.

    Returns:
        (osr.CoordinateTransformation): The coordinate transformation, or None when either WKT is empty/None.
    """
    if not (from_wkt and to_wkt):
        return None

    def _to_sr(projection_wkt):
        # Load one side, optionally dropping its vertical part first.
        if not keep_vertical:
            projection_wkt = strip_vertical(projection_wkt)
        return wkt_to_sr(projection_wkt)

    from_sr = _to_sr(from_wkt)
    to_sr = _to_sr(to_wkt)
    if set_traditional_axis_mapping:
        for reference in (from_sr, to_sr):
            reference.SetAxisMappingStrategy(osr.OAMS_TRADITIONAL_GIS_ORDER)
    return osr.CreateCoordinateTransformation(from_sr, to_sr)


def transform_points(points, coord_trans):
    """Apply a coordinate transformation to a sequence of points.

    Args:
        points (list[tuple]): The points to transform.
        coord_trans (osr.CoordinateTransformation): The coordinate transformation to use, or None.

    Returns:
        points (list[tuple]): The result of TransformPoints, or the input object itself when coord_trans is None.
    """
    if coord_trans is None:
        return points
    return coord_trans.TransformPoints(points)


def transform_point(x, y, coord_trans):
    """Apply a coordinate transformation to a single point.

    Args:
        x (float): The X location of the point to transform.
        y (float): The Y location of the point to transform.
        coord_trans (osr.CoordinateTransformation): The coordinate transformation to use, or None.

    Returns:
        point (tuple): The result of TransformPoint, or (x, y) unchanged when coord_trans is None.
    """
    if coord_trans is None:
        return (x, y)
    return coord_trans.TransformPoint(x, y)


def set_dont_use_exceptions():
    """Set the GDAL flag to not use exceptions.

    Calls gdal.DontUseExceptions(), undoing the gdal.UseExceptions() call made by add_vertical_to_wkt.
    """
    gdal.DontUseExceptions()


def delete_vector_file(filename, file_format='ESRI Shapefile'):
    """Delete a vector data source using the named OGR driver.

    Args:
        filename (str): The file to be deleted.
        file_format (str): The OGR driver name for the file's format.
    """
    driver = ogr.GetDriverByName(file_format)
    driver.DeleteDataSource(filename)


def convert_lat_lon_pts_to_utm(pts):
    """Converts a list of lat/lon points to UTM.

    The UTM zone and hemisphere are chosen from the centre of the points' bounding box, on the WGS84 datum.

    Args:
        pts (list): A non-empty list of (lon, lat, ...) points in WGS84 degrees.

    Returns:
        (list): A list of UTM points.
        (osr.SpatialReference): The spatial reference of the UTM points.
    """
    # get the bounding box of the points
    pt_x = [pt[0] for pt in pts]
    pt_y = [pt[1] for pt in pts]
    lon = (min(pt_x) + max(pt_x)) / 2.0
    lat = (min(pt_y) + max(pt_y)) / 2.0

    # get the utm zone from the middle of the box
    spatial_ref = osr.SpatialReference()
    spatial_ref.SetWellKnownGeogCS("WGS84")

    utm_spatial_ref = spatial_ref.CloneGeogCS()
    # Zones are 6 degrees wide starting at -180; clamp so lon == 180 lands in zone 60 instead of the invalid 61.
    zone = min(int((lon + 180.0) / 6) + 1, 60)
    # The equator belongs to the northern-hemisphere UTM zones.
    b_north = lat >= 0
    utm_spatial_ref.SetUTM(zone, b_north)

    # convert the points to utm (traditional axis order, so points stay lon/lat)
    trans = get_coordinate_transformation(spatial_ref.ExportToWkt(), utm_spatial_ref.ExportToWkt())
    out_pts = trans.TransformPoints(pts)

    return out_pts, utm_spatial_ref


def fix_wkt_and_get_unit_type(wkt: str, remove_vertical: bool) -> tuple[str, str]:
    """Optionally drop a mismatched vertical part from a WKT and compute its vertical unit string.

    Args:
        wkt: The WKT in WKT1 format.
        remove_vertical: When True, strip the vertical part of a compound WKT whose vertical units differ from its
            horizontal units. When False, a non-compound WKT reports its horizontal units as the vertical units.

    Returns:
        The (possibly stripped) WKT and the unit string from get_unit_string.
    """
    # GDAL 3.x appears to write the same units for the horizontal and vertical parts when they differ, sometimes
    # changing the horizontal units to match the vertical ones. So in that case only the horizontal projection is
    # kept, and the vertical units travel separately through the returned unit string (presumably recorded on the
    # raster band by the caller). See ScatterToRasterIntermediateTests::test_global_xy_meters_z_feet_asc: gdalwarp
    # re-saving the ASCII grid as a GeoTIFF dropped the vertical projection and converted feet to meters, while the
    # band unit string still said "ft", so XMS misread the vertical units.
    vert_units = get_vert_unit_from_wkt(wkt)
    vert_datum = get_vert_datum_from_wkt(wkt)
    horiz_units = get_horiz_unit_from_wkt(wkt)
    is_compound = wkt_to_sr(wkt).IsCompound()
    if remove_vertical and is_compound and vert_units != horiz_units:
        wkt = strip_vertical(wkt)
    elif not remove_vertical and not is_compound:
        vert_units = horiz_units
    return wkt, get_unit_string(vert_datum, vert_units)


def get_unit_string(vert_datum: str, vert_units: str) -> str:
    """Build the band unit string for a vertical datum/units pair.

    Feet (US survey) gives 'ft', meters gives 'm' and international feet gives 'Foot (International)'. NAVD88 and
    NGVD29 append '_navd88' / '_ngvd29'. Unknown units give ''.
    """
    unit_labels = (
        (UNITS_FEET_US_SURVEY, 'ft'),
        (UNITS_METERS, 'm'),
        (UNITS_FEET_INT, 'Foot (International)'),
    )
    label = next((text for units, text in unit_labels if vert_units == units), '')
    if not label:
        return label
    if vert_datum == DATUM_NAVD88:
        return label + '_navd88'
    if vert_datum == DATUM_NGVD29:
        return label + '_ngvd29'
    return label


def add_vertical_projection_from_unit_type(sr: osr.SpatialReference, unit_type: str) -> osr.SpatialReference:
    """Attach a vertical projection to ``sr`` based on a band unit string.

    Plain unit names ('m', 'ft', 'foot (international)', ...) map to a local datum. Strings containing 'navd88' or
    'ngvd29' select that datum, and the '<unit>_<datum>' prefix selects the units. Matching is case-insensitive.
    """
    lowered = unit_type.lower()
    datum = DATUM_LOCAL
    units = UNITS_UNKNOWN
    if lowered in ('meter', 'metre', 'm'):
        units = UNITS_METERS
    elif lowered in ('foot_us', 'foot', 'ft', 'us survey foot'):
        units = UNITS_FEET_US_SURVEY
    elif lowered == 'foot (international)':
        units = UNITS_FEET_INT
    else:
        for datum_tag, datum_name in (('navd88', DATUM_NAVD88), ('ngvd29', DATUM_NGVD29)):
            if datum_tag in lowered:
                datum = datum_name
                units_by_string = {
                    f'm_{datum_tag}': UNITS_METERS,
                    f'ft_{datum_tag}': UNITS_FEET_US_SURVEY,
                    f'foot (international)_{datum_tag}': UNITS_FEET_INT,
                }
                units = units_by_string.get(lowered, UNITS_UNKNOWN)
                break
    _, sr = add_vertical_projection(sr, datum, units)
    return sr
