"""ExtendRasterTool class."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"


# 1. Standard Python modules
import os

# 2. Third party modules
from geopandas import GeoDataFrame

# 3. Aquaveo modules
from xms.core.filesystem import filesystem
from xms.gdal.rasters import raster_utils as ru
from xms.gdal.rasters import RasterInput
from xms.gdal.utilities import gdal_utils as gu
from xms.gdal.utilities import gdal_wrappers as gw
from xms.gdal.vectors import VectorOutput
from xms.tool_core import IoDirection, Tool

# 4. Local modules
from xms.tool.utilities.coverage_conversion import get_polygons_from_coverage, polygons_to_shapefile

# Positions of each tool argument in the list built by ExtendRasterTool.initial_arguments().
(
    ARG_INPUT_RASTER,
    ARG_INPUT_COVERAGE,
    ARG_USE_POLY_ELEVATIONS,
    ARG_INPUT_MAX_DISTANCE,
    ARG_OUTPUT_RASTER,
) = range(5)


class ExtendRasterTool(Tool):
    """Tool to extend a raster to the extent of polygons in a coverage, writing out a new raster."""

    def __init__(self):
        """Initializes the class."""
        super().__init__(name='Extend Raster')

    def initial_arguments(self):
        """Get initial arguments for tool.

        Must override.

        Returns:
            (list): A list of the initial tool arguments, ordered to match the ARG_* index constants.
        """
        arguments = [
            self.raster_argument(name='input_raster', description='Input raster'),
            self.coverage_argument(name='input_coverage', description='Coverage with polygons'),
            self.bool_argument(name='use_poly_elevations',
                               description='Use polygon elevations when extending raster.'),
            self.integer_argument(name='max_distance', description='Maximum distance (in pixels) to interpolate',
                                  min_value=1, value=100),
            self.raster_argument(name='output_raster', description='Output raster', io_direction=IoDirection.OUTPUT)
        ]
        return arguments

    def validate_arguments(self, arguments):
        """Called to determine if arguments are valid.

        The only check performed is that the input coverage contains at least one polygon, since the polygons
        define the extent the raster is extended to.

        Args:
            arguments (list): The tool arguments.

        Returns:
            (dict): Dictionary of errors for arguments, keyed by argument name.
        """
        errors = {}
        # Make sure there are polygons in the coverage
        input_coverage = self.get_input_coverage(arguments[ARG_INPUT_COVERAGE].value)
        if not (input_coverage['geometry_types'] == 'Polygon').any():
            errors[arguments[ARG_INPUT_COVERAGE].name] = (f'The "{arguments[ARG_INPUT_COVERAGE].value}" coverage does '
                                                          'not have any polygons. Please create and build polygons in '
                                                          f'this coverage before running the {self.name} tool.')
        return errors

    def run(self, arguments):
        """Override to run the tool.

        Builds the extended raster and, on success, reprojects it and registers it as the tool's output raster.

        Args:
            arguments (list): The tool arguments.
        """
        input_raster = self.get_input_raster(arguments[ARG_INPUT_RASTER].value)
        input_coverage = self.get_input_coverage(arguments[ARG_INPUT_COVERAGE].value)
        use_poly_elevations = arguments[ARG_USE_POLY_ELEVATIONS].value
        max_distance = int(arguments[ARG_INPUT_MAX_DISTANCE].value)
        out_path = self.get_output_raster(arguments[ARG_OUTPUT_RASTER].value)

        if self._create_extended_raster(input_raster, input_coverage, use_poly_elevations, max_distance, out_path):
            # NOTE(review): get_output_raster is called a second time for the reprojection target; this relies on
            # its behavior across repeated calls (same vs. fresh path) — confirm against xms.tool_core.
            out_path = ru.reproject_raster(out_path, self.get_output_raster(arguments[ARG_OUTPUT_RASTER].value),
                                           self.default_wkt, self.vertical_datum, self.vertical_units)
            self.set_output_raster_file(out_path, arguments[ARG_OUTPUT_RASTER].text_value)

    def _create_extended_raster(self, input_raster: RasterInput, coverage: GeoDataFrame, use_poly_elevations: bool,
                                max_distance: int, out_path: str) -> bool:
        """Creates an extended raster using GDAL Warp, and the coverage polygons as a boundary.

        Steps: clip/extend the input raster to the polygons' extent (warp with a cutline), optionally stamp the
        polygon boundary elevations into it, fill the nodata cells, then warp again to the final output.
        All temporary files are removed even if a step fails.

        Args:
            input_raster (RasterInput): The input raster.
            coverage (GeoDataFrame): The input coverage.
            use_poly_elevations (bool): Whether to use the polygon elevations to interpolate.
            max_distance (int): The maximum distance in pixels to search out to interpolate.
            out_path (str): The path of the raster to create.

        Returns:
            bool: True if success, else False.
        """
        polygons = get_polygons_from_coverage(coverage)
        display_sr, raster_sr, vertical_multiplier = ru.reconcile_projection_to_raster(input_raster,
                                                                                       self.default_wkt,
                                                                                       self.vertical_datum,
                                                                                       self.vertical_units)

        # In-memory shapefile of the polygons, used as the warp cutline.
        cutline_file = '/vsimem/cutline.shp'
        polygons_to_shapefile(polygons, cutline_file, raster_sr.ExportToWkt(), display_sr.ExportToWkt())

        # A real temp file is used instead of /vsimem: the original author noted /vsimem does not work if
        # gdal_edit.py or gdal_calc.py need to be run on this raster.
        extents_path = filesystem.temp_filename(suffix='.tif')
        ext_ds = None
        try:
            num_features = len(polygons[1])
            plural = '' if num_features == 1 else 's'
            self.logger.info(f'Using {num_features} polygon{plural} for clipping...')

            # Call GDAL Warp with the vector layer of polygons as a cutline layer
            data_type, nodata = ru.get_datatype_and_nodata_for_warp(input_raster)
            ext_ds = gw.gdal_warp_get_dataset(input_raster.gdal_raster, extents_path, cutlineDSName=cutline_file,
                                              cropToCutline=True, dstNodata=nodata, srcSRS=raster_sr,
                                              dstSRS=display_sr, outputType=data_type,
                                              options=['-overwrite', '-co', 'compress=LZW'])
            if ext_ds is None:
                self.logger.error('Unable to clip the input raster to the coverage polygons.')
                return False
            if use_poly_elevations:
                self._burn_polygon_boundaries(ext_ds, polygons, raster_sr, display_sr)
            # Release the dataset so it is written to disk before other utilities open the file.
            ext_ds = None
            ru.set_vertical_units(extents_path, self.vertical_datum, self.vertical_units, vertical_multiplier,
                                  extents_path)

            # Fill no data
            ru.fill_nodata(RasterInput(extents_path, True), max_distance)

            # NOTE(review): extents_path was written in display_sr by the warp above, yet srcSRS=raster_sr is
            # passed here; confirm this is intended when the horizontal systems differ.
            rval = gw.gdal_warp(extents_path, out_path, cutlineDSName=cutline_file, cropToCutline=True,
                                dstNodata=nodata, srcSRS=raster_sr, dstSRS=display_sr,
                                outputType=data_type, options=['-overwrite', '-co', 'compress=LZW'])
            if rval:
                ru.set_vertical_units(out_path, self.vertical_datum, self.vertical_units, vertical_multiplier,
                                      out_path)
            return bool(rval)
        finally:
            # Drop any open handle first so the temp file can be deleted (matters on Windows, and when an
            # exception's traceback would otherwise keep this frame's locals alive).
            ext_ds = None
            if os.path.isfile(extents_path):
                os.remove(extents_path)
            gu.delete_vector_file(cutline_file)

    @staticmethod
    def _burn_polygon_boundaries(ext_ds, polygons, raster_sr, display_sr):
        """Stamp the polygon boundary Z values into band 1 of the extents dataset.

        Args:
            ext_ds: The open extents dataset returned by the first warp.
            polygons: Result of get_polygons_from_coverage; as iterated here, polygons[0] holds each polygon's
                holes and polygons[1] each polygon's outer ring, both as dicts with a 'poly_pts' entry.
            raster_sr: Spatial reference of the input raster (the boundary layer is written in this system).
            display_sr: Spatial reference the polygon points are converted from.
        """
        boundary_file = '/vsimem/boundaryLine.shp'
        vo = VectorOutput()
        vo.initialize_file(boundary_file, raster_sr.ExportToWkt(), from_wkt=display_sr.ExportToWkt())
        try:
            for holes_data, poly_data in zip(polygons[0], polygons[1]):
                # Outer ring of the polygon, then any interior rings.
                vo.write_arc(poly_data['poly_pts'])
                for hole in holes_data:
                    vo.write_arc(hole['poly_pts'])
            # Burn each arc's Z value (burn value 0 + Z) into every cell the arcs touch.
            # ALL_TOUCHED is the option key GDALRasterizeLayers recognizes ("allTouched" is not).
            gw.gdal_rasterize_layer(vo.ogr_layer, ext_ds, [1], [0], ['BURN_VALUE_FROM=Z', 'ALL_TOUCHED=TRUE'])
        finally:
            # Release the layer before deleting its backing in-memory file.
            vo = None
            gu.delete_vector_file(boundary_file)
