"""Class for reading rasters using GDAL."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules

# 2. Third party modules
import numpy as np
import osgeo.gdal as gdal

# 3. Aquaveo modules

# 4. Local modules
from xms.gdal.utilities import gdal_utils as gu


class RasterOutput:
    """Write single-band GeoTIFF rasters with GDAL.

    The raster's size, georeferencing, projection, no-data value and pixel type are configured at construction.
    ``write_template_raster`` creates a GeoTIFF with those settings but no pixel data; ``write_raster`` also fills the
    band with an array of values.
    """

    def __init__(self, xorigin=0.0, yorigin=0.0, width=0, height=0, pixel_width=0, pixel_height=0,
                 nodata_value=-9999.0, wkt='', data_type=gdal.GDT_Float32, geotransform=None):
        """Initializes the class.

        Args:
            xorigin (float): X-coordinate of the raster's upper left corner
            yorigin (float): Y-coordinate of the raster's upper left corner
            width (int): Total number of pixels in the x-direction
            height (int): Total number of pixels in the y-direction
            pixel_width (float): Width of one pixel
            pixel_height (float): Height of one pixel.  Only its magnitude is used; the written geotransform always
                uses a negative pixel height.
            nodata_value (float): No data value for the raster, or None to leave it unset
            wkt (str): Well-known text of the raster's projection.  Ignored if None or not valid WKT.
            data_type (int): GDAL pixel type of the raster to create (e.g. ``gdal.GDT_Float32``)
            geotransform (float[6]): Tuple of the raster's Geotransform, if available.  If this is defined, only the
                width and height need to be passed into this class to write the raster.  The xorigin, yorigin,
                pixel_width, and pixel_height are not required.
        """
        # Ask GDAL's GTiff driver to report compound (horizontal + vertical) coordinate systems.  This is a
        # process-wide GDAL configuration option, not something scoped to this instance.
        gdal.SetConfigOption('GTIFF_REPORT_COMPD_CS', 'TRUE')
        self.xorigin = xorigin
        self.yorigin = yorigin
        self.width = width
        self.height = height
        self.pixel_width = pixel_width
        self.pixel_height = pixel_height
        self.nodata_value = nodata_value
        self.wkt = wkt
        self._type = data_type
        self._geotransform = geotransform

    def _ensure_numpy_array(self, sequence):
        """Ensure that a sequence of values is a numpy array.

        Args:
            sequence (Sequence): The array values

        Returns:
            np.ndarray: ``sequence`` itself if it is already an ndarray (or subclass), otherwise a new array built
            from it.
        """
        return sequence if isinstance(sequence, np.ndarray) else np.array(sequence)

    def write_template_raster(self, filename):
        """Create a single-band GeoTIFF with this RasterOutput's specifications but no pixel data.

        If a geotransform was supplied it is written as-is.  Otherwise one is built from the origin and pixel size,
        with the pixel height forced negative.  When ``self.wkt`` is valid WKT, the projection is written, along with
        the band's unit type when one is reported.  ``self.wkt`` is also replaced by the string returned from
        ``gu.remove_epsg_code_if_unit_mismatch``.

        Args:
            filename (str): Path to the output file to write.

        Returns:
            gdal.Dataset: The open GDAL dataset.  The caller owns it and must drop its reference (e.g. ``ds = None``)
            to close the file.

        Raises:
            Exception: Any error raised by GDAL (exceptions are enabled via ``gdal.UseExceptions()``).  The dataset is
                released before the error propagates so the file is not left open.
        """
        # Make GDAL raise exceptions on failure rather than returning None/error codes (process-wide setting).
        gdal.UseExceptions()
        driver = gdal.GetDriverByName('GTiff')
        raster_file = driver.Create(filename, self.width, self.height, 1, self._type)
        raster_band = None
        try:
            # Set the raster dimensions and origin.
            if self._geotransform is not None:
                raster_file.SetGeoTransform(self._geotransform)
            else:
                # GDAL geotransform: (x origin, pixel width, row rotation, y origin, column rotation, pixel height).
                # A negative pixel height makes y decrease going down the rows from the upper-left origin.
                raster_file.SetGeoTransform((self.xorigin, self.pixel_width, 0, self.yorigin, 0,
                                             -abs(self.pixel_height)))

            raster_band = raster_file.GetRasterBand(1)
            if self.wkt is not None and gu.valid_wkt(self.wkt):
                # Write the projection info to the file.
                # See imUpdateRasterCoordSys in the c++ code.
                # Note: this stores the adjusted WKT back on the instance.
                _, self.wkt = gu.remove_epsg_code_if_unit_mismatch(self.wkt)
                wkt, unit_type = gu.fix_wkt_and_get_unit_type(self.wkt, True)
                raster_file.SetProjection(wkt)
                if unit_type != '':
                    raster_band.SetUnitType(unit_type)
            if self.nodata_value is not None:
                raster_band.SetNoDataValue(self.nodata_value)
            raster_band.FlushCache()
        except Exception:
            # Drop every reference to the dataset so GDAL closes the file before the error propagates.
            raster_band = None  # noqa: F841
            raster_file = None  # noqa: F841
            raise
        raster_band = None  # noqa: F841
        return raster_file

    def write_raster(self, filename, raster_array):
        """Create the raster file and write ``raster_array`` into its single band.

        Args:
            filename (str): Path to the output file to write
            raster_array (numpy.array): The raster data.  A 2D array is written as-is.  A 1D array is reshaped
                row-major to (height, width) and its row order is reversed before writing.  Presumably flattened
                input lists the bottom row first; confirm with callers.

        Raises:
            ValueError: If a 1D ``raster_array`` does not contain exactly ``height * width`` values.  This is checked
                before the output file is created, so no empty file is left behind.
        """
        # Prepare the data before touching the file system so invalid input does not leave an empty raster on disk.
        raster_array = self._ensure_numpy_array(raster_array)
        if raster_array.ndim == 1:  # Reshape 1D array to a 2D grid
            expected_size = self.height * self.width
            if raster_array.size != expected_size:
                raise ValueError(
                    f'1D raster data has {raster_array.size} values; expected {expected_size} '
                    f'({self.height} rows x {self.width} columns).'
                )
            # Unflatten the array, then reverse the order of the rows.
            raster_array = np.flipud(raster_array.reshape(self.height, self.width))

        ds = self.write_template_raster(filename)
        try:
            ds.GetRasterBand(1).WriteArray(raster_array)
            ds.FlushCache()
        finally:
            # Dropping the last reference closes the dataset, even if writing failed.
            ds = None  # noqa: F841
