"""Class for reading rasters using GDAL."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
import os

# 2. Third party modules
import osgeo.gdal as gdal

# 3. Aquaveo modules
from xms.core.filesystem import filesystem

# 4. Local modules
from xms.gdal.rasters import RasterInput
import xms.gdal.utilities.gdal_utils as gu


def get_datatype_and_nodata(input_raster: RasterInput) -> tuple[int, float]:
    """Pick the output data type and NODATA value to hand to gdal.Warp for a raster.

    Floating point rasters (32 or 64 bit) are requested as gdal.GDT_Float32; all other data types map to
    gdal.GDT_Unknown. A raster that has no NODATA value is given -32767.0.

    Args:
        input_raster (RasterInput): The raster whose data type and NODATA value are inspected.

    Returns:
        tuple: (data type, NODATA value) to use with gdal.Warp.
    """
    floating_types = (gdal.GDT_Float32, gdal.GDT_Float64)
    warp_type = gdal.GDT_Float32 if input_raster.data_type in floating_types else gdal.GDT_Unknown

    # Fall back to a fixed sentinel so the warped output always has a usable NODATA value.
    raster_nodata = input_raster.nodata_value
    warp_nodata = -32767.0 if raster_nodata is None else raster_nodata
    return warp_type, warp_nodata


class RasterReproject:
    """Class for reprojecting rasters."""

    # Resampling algorithm constants, re-exported from GDAL for use as the ``resample_alg`` argument.
    GRA_NearestNeighbour = gdal.GRA_NearestNeighbour
    GRA_Bilinear = gdal.GRA_Bilinear
    GRA_Cubic = gdal.GRA_Cubic
    GRA_CubicSpline = gdal.GRA_CubicSpline
    GRA_Lanczos = gdal.GRA_Lanczos

    # GeoTIFF creation options shared by the Warp and Translate paths.
    _CREATION_OPTIONS = ('-co', 'compress=LZW', '-co', 'BIGTIFF=YES')
    # All-zero TOWGS84 clause that is stripped from projections before they are written to .prj files.
    _IDENTITY_TOWGS84 = ',TOWGS84[0,0,0,0,0,0,0]'

    def __init__(self, source_files: list, output_file: str, wkt_or_prj_file: str, raster_wkt: str = None,
                 resample_alg: int = gdal.GRA_NearestNeighbour):
        """Initializes the class.

        Args:
            source_files (list[str]): A list of source filenames
            output_file (str): Output filename
            wkt_or_prj_file (str): Projection well known text or PROJ.4 of the output file or a file containing WKT or
                PROJ4 strings.
            raster_wkt (str): Projection well known text or PROJ.4 of the raster to reproject.
            resample_alg (int): The resample algorithm to use, one of gdal.GRA_*
        """
        # Global GDAL setting; lets the GeoTIFF driver report compound (horizontal + vertical) coordinate systems.
        gdal.SetConfigOption('GTIFF_REPORT_COMPD_CS', 'TRUE')
        self.source_files = source_files
        self.output_file = output_file
        if os.path.isfile(wkt_or_prj_file):
            proj = gu.read_projection_file(wkt_or_prj_file)
        else:
            proj = wkt_or_prj_file
        self.output_wkt = proj
        self.raster_wkt = raster_wkt
        # Output data type and NODATA value are derived from the first source raster only.
        ri = RasterInput(source_files[0])
        self._datatype, self._nodata = get_datatype_and_nodata(ri)
        self._resample_alg = resample_alg

    @classmethod
    def _remove_identity_towgs84(cls, wkt: str) -> str:
        """Return ``wkt`` with the all-zero TOWGS84 clause removed."""
        return wkt.replace(cls._IDENTITY_TOWGS84, '')

    @staticmethod
    def _write_prj(path: str, wkt: str) -> None:
        """Write ``wkt`` to the projection file at ``path``."""
        with open(path, 'w') as prj_file:
            prj_file.write(wkt)

    def _warp_options(self, source_srs_file: str | None, target_srs_file: str) -> list[str]:
        """Build the gdal.Warp command line options as a list of arguments.

        A list is used instead of a single string so file paths containing spaces are passed through intact.

        Args:
            source_srs_file (str | None): .prj file holding the raster projection, or None if no raster WKT was given.
            target_srs_file (str): .prj file holding the output projection.

        Returns:
            list[str]: The gdal.Warp options.
        """
        options = ['-overwrite', *self._CREATION_OPTIONS]
        if self.raster_wkt is not None:
            # For a local raster projection the target projection is used as the source projection
            # (presumably so no coordinate transformation is applied).
            source_srs = target_srs_file if gu.is_local(self.raster_wkt) else source_srs_file
            options += ['-s_srs', source_srs]
        options += ['-t_srs', target_srs_file, '-dstnodata', f'{self._nodata}']
        if self._datatype == gdal.GDT_Float32:
            options += ['-ot', 'Float32']
        return options

    def run(self):
        """Runs the reproject command.

        Validates the projections and writes the output projection to a temporary .prj file. If the output projection
        is not local, the source files are warped into it with gdal.Warp. If it is local, the first source file is
        copied with gdal.Translate. When the projection reports a unit type, it is set on band 1 of the result.
        Temporary .prj files are removed even if GDAL raises.

        Returns:
            gdal.Dataset: The GDAL dataset resulting from changing the projection

        Raises:
            ValueError: If the output WKT or the raster WKT is not valid.
        """
        gdal.UseExceptions()
        if not gu.valid_wkt(self.output_wkt):
            raise ValueError(f'Unable to load WKT: {self.output_wkt}')
        if self.raster_wkt is not None and not gu.valid_wkt(self.raster_wkt):
            raise ValueError(f'Unable to load WKT: {self.raster_wkt}')
        _, wkt = gu.remove_epsg_code_if_unit_mismatch(self.output_wkt)
        wkt, unit_type = gu.fix_wkt_and_get_unit_type(wkt, False)
        wkt = self._remove_identity_towgs84(wkt)

        source_srs_file = None
        target_srs_file = filesystem.temp_filename(suffix='.prj')
        try:
            self._write_prj(target_srs_file, wkt)
            if gu.is_local(wkt):
                # Only the first source file is copied when the output projection is local.
                raster_file = gdal.Translate(self.output_file, self.source_files[0],
                                             options=list(self._CREATION_OPTIONS))
            else:
                if self.raster_wkt is not None:
                    # Remove TOWGS84 strings from the raster projection as well
                    self.raster_wkt = self._remove_identity_towgs84(self.raster_wkt)
                    source_srs_file = filesystem.temp_filename(suffix='.prj')
                    self._write_prj(source_srs_file, self.raster_wkt)
                raster_file = gdal.Warp(self.output_file, self.source_files, resampleAlg=self._resample_alg,
                                        options=self._warp_options(source_srs_file, target_srs_file))

            raster_band = raster_file.GetRasterBand(1)
            if unit_type:
                raster_band.SetUnitType(unit_type)
            raster_band.FlushCache()
            raster_band = None  # noqa: F841  (drop the band reference before returning the dataset)
        finally:
            # Always clean up the temporary projection files, even if GDAL raised.
            for prj_path in (source_srs_file, target_srs_file):
                if prj_path and os.path.exists(prj_path):
                    os.remove(prj_path)
        return raster_file
