"""Class for reading rasters using GDAL."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
import math

# 2. Third party modules
import numpy as np
import osgeo.gdal as gdal

# 3. Aquaveo modules

# 4. Local modules


# Maps the user-facing resample algorithm names (lowercase strings) to the GDAL RasterIO resampling constants.
GDAL_RESAMPLE_ALGORITHMS = dict(
    nearest_neighbor=gdal.gdalconst.GRIORA_NearestNeighbour,
    bilinear=gdal.gdalconst.GRIORA_Bilinear,
    cubic=gdal.gdalconst.GRIORA_Cubic,
    cubic_spline=gdal.gdalconst.GRIORA_CubicSpline,
    lanczos=gdal.gdalconst.GRIORA_Lanczos,
    average=gdal.gdalconst.GRIORA_Average,
    mode=gdal.gdalconst.GRIORA_Mode,
    gauss=gdal.gdalconst.GRIORA_Gauss,
)


def _get_raster_min_max(ul, nx, ny, dxdu, dxdv, dydu, dydv):
    """Gets the raster min and max coordinates.

    Args:
        ul (Point tuple): The upper left coordinate
        nx (int): The number of image pixels in the X-direction
        ny (int): The number of image pixels in the Y-direction
        dxdu (double): The dxdu image rotation
        dxdv (double): The dxdv image rotation
        dydu (double): The dydu image rotation
        dydv (double): The dydv image rotation

    Returns:
        (Tuple of point tuples): The minimum [0] and maximum [1] coordinates.
    """
    ll = (ul[0] + ny * dxdv, ul[1] + ny * dydv)
    ur = (ul[0] + nx * dxdu, ul[1] + nx * dydu)
    lr = (ul[0] + nx * dxdu + ny * dxdv, ul[1] + nx * dydu + ny * dydv)
    return (min(ul[0], ll[0]), min(ll[1], lr[1]), 0.0), (max(ur[0], lr[0]), max(ul[1], ur[1]), 0.0)


class RasterInput:
    """Class for reading rasters.

    Wraps a GDAL dataset opened from a file. Metadata read from the dataset (geotransform, origin, pixel size,
    nodata value, unit type) is cached on first access. All band-level queries use the first raster band.
    """

    # Returns from data_type
    GDT_Unknown = gdal.GDT_Unknown
    GDT_Byte = gdal.GDT_Byte
    GDT_UInt16 = gdal.GDT_UInt16
    GDT_Int16 = gdal.GDT_Int16
    GDT_UInt32 = gdal.GDT_UInt32
    GDT_Int32 = gdal.GDT_Int32
    GDT_Float32 = gdal.GDT_Float32
    GDT_Float64 = gdal.GDT_Float64
    GDT_CInt16 = gdal.GDT_CInt16
    GDT_CInt32 = gdal.GDT_CInt32
    GDT_CFloat32 = gdal.GDT_CFloat32
    GDT_CFloat64 = gdal.GDT_CFloat64

    def __init__(self, filename, update=False):
        """Initializes the class.

        If the file cannot be opened in 'Update' mode, a second attempt is made in read-only mode.

        Args:
            filename (str): raster filename.
            update (bool): Whether to open the raster file in 'Update' mode so it can be modified.

        Raises:
            ValueError: If the raster file could not be opened.
        """
        gdal.SetConfigOption('GTIFF_REPORT_COMPD_CS', 'TRUE')
        self._raster = None
        mode = gdal.GA_Update if update else gdal.GA_ReadOnly
        open_error = None
        # Temporarily force GDAL exceptions on so that open failures can be caught below.
        gdal_use_exceptions = gdal.GetUseExceptions()
        gdal.UseExceptions()
        try:
            try:
                self._raster = gdal.Open(filename, mode)
            except Exception as err:
                open_error = err
                if update:
                    # Fall back to read-only access when the file cannot be opened for update.
                    try:
                        self._raster = gdal.Open(filename, gdal.GA_ReadOnly)
                    except Exception as fallback_err:
                        open_error = fallback_err
                        self._raster = None
        finally:
            # Always restore the caller's exception mode, including when the open failed.
            if not gdal_use_exceptions:
                gdal.DontUseExceptions()
        if self._raster is None:
            raise ValueError(f'Unable to open raster file: {filename}') from open_error
        self._geotransform = None
        self._xorigin = None
        self._yorigin = None
        self._pixel_width = None
        self._pixel_height = None
        self._nodata_value = None
        self._unit_type = None
        self._got_acting_nodata = False
        self._acting_nodata_value = None

    def _validate_resample_algorithm(self, resample_alg):
        """Ensure resample algorithm string is valid.

        The comparison is case-insensitive.

        Args:
            resample_alg (str): The algorithm string to verify

        Returns:
            int: The gdalconst enum associated with the string

        Raises:
            ValueError if algorithm string is invalid
        """
        resample_alg = resample_alg.lower()
        if resample_alg not in GDAL_RESAMPLE_ALGORITHMS:
            raise ValueError(
                f'Invalid resample algorithm: {resample_alg}. Must be one of: {GDAL_RESAMPLE_ALGORITHMS.keys()}'
            )
        return GDAL_RESAMPLE_ALGORITHMS[resample_alg]

    @staticmethod
    def _is_nodata(value, nodata):
        """Checks whether a value read from the raster matches the nodata value.

        Args:
            value (np.ndarray): The value(s) read from the raster, or None if the read failed.
            nodata (float): The nodata value to compare against, or None if there is none.

        Returns:
            (bool): True if every element of value equals nodata. Always False if either argument is None.
        """
        if value is None or nodata is None:
            return False
        if isinstance(nodata, float) and math.isnan(nodata):
            # NaN never compares equal to itself, so it needs an explicit test.
            return bool(np.all(np.isnan(value)))
        return bool(np.all(value == nodata))

    @property
    def geotransform(self):
        """Returns the raster's geotransform struct."""
        if self._geotransform is None:
            self._geotransform = self._raster.GetGeoTransform()
        return self._geotransform

    @property
    def xorigin(self):
        """Returns the X-coordinate of the raster's upper left corner."""
        if self._xorigin is None:
            self._xorigin = self.geotransform[0]
        return self._xorigin

    @property
    def yorigin(self):
        """Returns the Y-coordinate of the raster's upper left corner."""
        if self._yorigin is None:
            self._yorigin = self.geotransform[3]
        return self._yorigin

    @property
    def pixel_width(self):
        """Returns the size of a pixel in the X-direction in the raster's units."""
        if self._pixel_width is None:
            self._pixel_width = self.geotransform[1]
        return self._pixel_width

    @property
    def pixel_height(self):
        """Returns the size of a pixel in the Y-direction in the raster's units."""
        if self._pixel_height is None:
            self._pixel_height = self.geotransform[5]
        return self._pixel_height

    @property
    def nodata_value(self):
        """Returns the no data value of the raster's first band."""
        if self._nodata_value is None:
            self._nodata_value = self._raster.GetRasterBand(1).GetNoDataValue()
        return self._nodata_value

    @property
    def unit_type(self):
        """Returns the unit type of the raster's first band."""
        if self._unit_type is None:
            self._unit_type = self._raster.GetRasterBand(1).GetUnitType()
        return self._unit_type

    @property
    def resolution(self):
        """Returns a tuple containing the raster width [0] and height [1] in pixels."""
        return self._raster.RasterXSize, self._raster.RasterYSize

    @property
    def wkt(self):
        """Returns the raster's well-known text."""
        return self._raster.GetProjection()

    @property
    def data_type(self):
        """Returns the raster's data type of the first band."""
        return self._raster.GetRasterBand(1).DataType

    @property
    def data_type_name(self):
        """Returns the name of the raster's data type of the first band."""
        return gdal.GetDataTypeName(self.data_type)

    @property
    def gdal_raster(self):
        """Returns the GDAL dataset for this raster."""
        return self._raster

    def get_acting_nodata_value(self):
        """Gets the acting nodata value and whether it is valid.

        When the raster has no nodata value, the raster minimum is used as the acting nodata value if it is close
        to one of the common sentinel values -32767.0 or -9999.0.

        Returns:
            (Tuple): The acting nodata value [0] and whether it is valid [1].
        """
        # NOTE(review): this truthiness test also treats a nodata value of 0.0 as "not set" - confirm intended.
        if not self.nodata_value:
            if not self._got_acting_nodata:
                z_min, z_max = self.get_z_min_max()
                if math.isclose(z_min, -32767.0) or math.isclose(z_min, -9999.0):
                    self._got_acting_nodata = True
                    self._acting_nodata_value = z_min
            if self._got_acting_nodata:
                return self._acting_nodata_value, True
        return 0.0, False

    def get_z_min_max(self):
        """Gets the raster min and max Z coordinates.

        Uses the minimum and maximum stored with the first band when both are available; otherwise computes
        approximate values from the raster data.

        Returns:
            (Tuple): The raster min [0] and max [1] Z coordinates.
        """
        band = self._raster.GetRasterBand(1)
        z_min = band.GetMinimum()
        z_max = band.GetMaximum()
        # Compare against None explicitly: a stored minimum or maximum of 0.0 is a valid value.
        if z_min is None or z_max is None:
            return band.ComputeRasterMinMax(True)
        return z_min, z_max

    def force_compute_statistics(self):
        """Forces computing the statistics for the raster.

        Returns:
            (bool): Whether the statistics were successfully computed.
        """
        band = self._raster.GetRasterBand(1)
        stats = band.GetStatistics(0, 1)
        if isinstance(stats, (list, tuple)):
            # The Python bindings return [min, max, mean, stddev] instead of an error code. A negative standard
            # deviation cannot be a real result, so it is treated as a failure.
            return len(stats) == 4 and stats[3] >= 0.0
        return stats == gdal.CE_None

    def get_raster_bounds(self):
        """Gets the raster min and max coordinates.

        Returns:
            (Tuple): The raster min [0] and max [1] X, Y, Z coordinates.  The Z coordinates are always 0.0 (not set).
        """
        width, height = self.resolution
        if width != 0:
            transform = self.geotransform
            if transform:
                ul = (self.xorigin, self.yorigin)
                return _get_raster_min_max(ul, width, height, self.pixel_width, transform[2], transform[4],
                                           self.pixel_height)
        return (0.0, 0.0, 0.0), (0.0, 0.0, 0.0)

    def get_average_pixel_size(self):
        """Gets the average pixel size.

        Returns:
            (float): The average of the horizontal and vertical pixel size.
        """
        num_pixels = self.resolution
        bounds = self.get_raster_bounds()
        horizontal = (bounds[1][0] - bounds[0][0]) / num_pixels[0]
        vertical = (bounds[1][1] - bounds[0][1]) / num_pixels[1]
        average = (horizontal + vertical) * 0.5
        return average

    def coordinate_to_pixel(self, x, y):
        """Find the pixel containing a point coordinate.

        Only the origin and pixel size are used; the rotation terms of the geotransform are ignored.

        Args:
            x (float): X-coordinate of the location
            y (float): Y-coordinate of the location

        Returns:
            tuple (int, int): The x and y offsets of the raster pixel containing the location. Returns (-1, -1) if the
            point is not in the raster.
        """
        xsize, ysize = self.resolution
        # Use floor rather than int(): int() truncates toward zero, which would map points up to one pixel outside
        # the left/top edge onto pixel 0 instead of reporting them as out of bounds.
        xoff = math.floor((x - self.xorigin) / self.pixel_width)
        yoff = math.floor((y - self.yorigin) / self.pixel_height)
        if xoff < 0 or xoff >= xsize or yoff < 0 or yoff >= ysize:
            return -1, -1  # Out of bounds
        return xoff, yoff

    def get_raster_values(self, xoff=0, yoff=0, xsize=None, ysize=None, resample_alg='nearest_neighbor',
                          search_valid_value=False):
        """Gets all the raster values.

        Notes:
            By default, reads the entire raster into memory.

            The search for a valid value is only done for single pixel reads (xsize == 1 and ysize == 1). It checks
            the pixels at increasing distance in the +X, +Y, -X and -Y directions (in that order) and stops at the
            first one that is not the nodata value. If no such pixel exists, the nodata result is returned.

        Args:
            xoff (int): Starting pixel in x-direction
            yoff (int): Starting pixel in y-direction
            xsize (int): Number of pixels to read in the x-direction
            ysize (int): Number of pixels to read in the y-direction
            resample_alg (str): The resampling algorithm to use. Must be one of the keys in GDAL_RESAMPLE_ALGORITHMS.
            search_valid_value (bool): Whether to search in each direction until a valid value is found

        Returns:
            (np.ndarray): The array containing all the raster values

        Raises:
            ValueError: If resample_alg is not a valid algorithm string.
        """
        resample_alg = self._validate_resample_algorithm(resample_alg)
        band = self._raster.GetRasterBand(1)

        def _read(cur_x, cur_y):
            return band.ReadAsArray(xoff=cur_x, yoff=cur_y, win_xsize=xsize, win_ysize=ysize,
                                    resample_alg=resample_alg)

        ret_val = _read(xoff, yoff)
        # Check the window size first: comparing a multi-pixel array in a boolean context would raise.
        if not (search_valid_value and xsize == 1 and ysize == 1):
            return ret_val
        nodata = self.nodata_value
        if not self._is_nodata(ret_val, nodata):
            return ret_val

        num_x, num_y = self.resolution
        inc = 1
        while True:
            candidates = ((xoff + inc, yoff), (xoff, yoff + inc), (xoff - inc, yoff), (xoff, yoff - inc))
            in_bounds = [(cur_x, cur_y) for cur_x, cur_y in candidates if 0 <= cur_x < num_x and 0 <= cur_y < num_y]
            if not in_bounds:
                # Every direction has run off the raster: there is no valid value to find, so stop searching.
                break
            for cur_x, cur_y in in_bounds:
                ret_val = _read(cur_x, cur_y)
                if not self._is_nodata(ret_val, nodata):
                    return ret_val
            inc += 1
        return ret_val

    def get_raster_value_at_loc(self, x, y, interpolate=True):
        """Get raster value using bilinear interpolation. Follows GdalRasterInputBaseImpl::GetLocationElevation.

        The location must lie within the rectangle spanned by the centers of the outermost pixels. The rotation
        terms of the geotransform are ignored.

        Args:
            x (float): X-coordinate of the location
            y (float): Y-coordinate of the location
            interpolate (bool): Whether to use bilinear interpolation or just get the value at the given location

        Returns:
            (float): The elevation at the x, y location. The nodata value is returned if the location is outside
            the interpolation area or if any of the four surrounding pixels is nodata.
        """
        transform = self.geotransform

        nx, ny = self.resolution

        # half raster cell widths
        hx = transform[1] / 2.0
        hy = transform[5] / 2.0

        # calculate raster lower bound indices from point (measured from the pixel centers)
        fx = (x - (transform[0] + hx)) / transform[1]
        fy = (y - (transform[3] + hy)) / transform[5]

        ix1 = int(np.floor(fx))
        iy1 = int(np.floor(fy))

        # special case where point is on upper bounds
        if fx == (nx - 1):
            ix1 -= 1
        if fy == (ny - 1):
            iy1 -= 1

        # upper bound indices on raster
        ix2 = ix1 + 1
        iy2 = iy1 + 1

        # Test array bounds to ensure point is within raster midpoints
        if ix1 < 0 or iy1 < 0 or ix2 > (nx - 1) or iy2 > (ny - 1):
            return self.nodata_value

        if not interpolate:
            return self.get_raster_values(ix1, iy1, 1, 1)[0][0]

        # calculate differences from point to bounding raster midpoints
        dx1 = x - (transform[0] + ix1 * transform[1] + hx)
        dy1 = y - (transform[3] + iy1 * transform[5] + hy)
        dx2 = (transform[0] + ix2 * transform[1] + hx) - x
        dy2 = (transform[3] + iy2 * transform[5] + hy) - y

        # Read the four surrounding pixels with a single 2x2 window (rows iy1..iy2, columns ix1..ix2).
        window = self.get_raster_values(xoff=ix1, yoff=iy1, xsize=2, ysize=2)
        e = [window[0][0], window[0][1], window[1][0], window[1][1]]
        nodata = self.nodata_value
        if any(value == nodata for value in e):
            return nodata

        # use the differences to weigh the four raster values
        div = transform[1] * transform[5]
        b_elev = e[0] * dx2 * dy2 / div + e[1] * dx1 * dy2 / div + e[2] * dx2 * dy1 / div + e[3] * dx1 * dy1 / div

        return b_elev

    def clear_raster(self):
        """Clears the GDAL raster dataset."""
        self._raster = None
