"""TrimRasterTool class."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
import math
import os

# 2. Third party modules

# 3. Aquaveo modules
from xms.core.filesystem import filesystem
from xms.gdal.rasters import raster_utils as ru
from xms.gdal.utilities import gdal_utils as gu
from xms.gdal.utilities import gdal_wrappers as gw
from xms.tool_core import IoDirection, Tool

# 4. Local modules
from xms.tool.algorithms.coverage import PolygonFromRasterBounds
from xms.tool.utilities.coverage_conversion import get_polygons_from_coverage, polygons_to_shapefile
from xms.tool.whitebox import WhiteboxToolRunner
import xms.tool.whitebox.whitebox_tool_runner as wbt

# Positions of each argument in the list returned by TrimRasterTool.initial_arguments().
ARG_INPUT_RASTER, ARG_INPUT_COVERAGE, ARG_OUTPUT_RASTER = range(3)


class TrimRasterTool(Tool):
    """Tool to trim a raster via a coverage with polygons, writing out a new raster."""

    def __init__(self):
        """Initializes the class."""
        super().__init__(name='Trim Raster')
        self._input_raster = None
        self._input_coverage = None
        # Where the coverage polygons lie relative to the raster bounds; recomputed on every validation.
        self._all_inside_raster = True
        self._all_outside_raster = True
        # Read after the Whitebox 'Clip' step; presumably filled in by WhiteboxToolRunner -- confirm.
        self.vector_output = []

    def initial_arguments(self):
        """Get initial arguments for tool.

        Must override.

        Returns:
            (list): A list of the initial tool arguments.
        """
        arguments = [
            self.raster_argument(name='input_raster', description='Input raster'),
            self.coverage_argument(name='input_coverage', description='Coverage with polygons'),
            self.raster_argument(name='output_raster', description='Output raster', io_direction=IoDirection.OUTPUT)
        ]
        return arguments

    def _get_coverage_location_info(self):
        """Classify the input coverage's polygons by location relative to the input raster's bounds.

        Every exterior-ring vertex of each Polygon geometry is reprojected from the display projection to the
        raster's projection and tested against the raster bounding box. Afterwards:

        - ``self._all_inside_raster`` is True when no vertex fell outside the raster bounds.
        - ``self._all_outside_raster`` is True when no vertex fell inside the raster bounds.

        Both flags are reset on entry so that validating again with different inputs never reuses stale results.
        """
        self._all_inside_raster = True
        self._all_outside_raster = True

        locations = []
        for geom in self._input_coverage.geometry:
            if geom is not None and geom.geom_type == "Polygon":
                # Only exterior rings are tested; interior rings (holes) are ignored.
                locations.extend(geom.exterior.coords)
        if not locations:
            return

        min_pt, max_pt = self._input_raster.get_raster_bounds()
        # Small relative tolerance so vertices on the raster edge (up to reprojection round-off) count as inside.
        tol = 1.0e-9
        x_lo = min_pt[0] - abs(min_pt[0]) * tol
        y_lo = min_pt[1] - abs(min_pt[1]) * tol
        x_hi = max_pt[0] + abs(max_pt[0]) * tol
        y_hi = max_pt[1] + abs(max_pt[1]) * tol
        locations = gu.transform_points_from_wkt(locations, self.default_wkt, self._input_raster.wkt)
        for pt in locations:
            if x_lo <= pt[0] <= x_hi and y_lo <= pt[1] <= y_hi:
                self._all_outside_raster = False
            else:
                self._all_inside_raster = False
            if not (self._all_inside_raster or self._all_outside_raster):
                break  # Both flags are settled; remaining vertices cannot change them.

    def validate_arguments(self, arguments):
        """Called to determine if arguments are valid.

        Also caches the input raster and coverage and computes the coverage location flags used by ``run``.

        Args:
            arguments (list): The tool arguments.

        Returns:
            (dict): Dictionary of errors for arguments.
        """
        errors = {}
        raster_arg = arguments[ARG_INPUT_RASTER]
        coverage_arg = arguments[ARG_INPUT_COVERAGE]
        self._input_raster = self.get_input_raster(raster_arg.value)
        self._input_coverage = self.get_input_coverage(coverage_arg.value)

        # Projection checks: both the display projection and the raster must be non-local.
        if self.default_wkt and gu.is_local(self.default_wkt):
            errors[coverage_arg.name] = 'Please define a display projection.'
        if self._input_raster is None:
            errors[raster_arg.name] = 'Please select an existing raster.'
        elif not self._input_raster.wkt or gu.is_local(self._input_raster.wkt):
            errors[raster_arg.name] = 'Please select a raster that has a projection.'
        if errors:
            return errors

        # Coverage checks: must exist, contain polygons, and overlap the raster.
        if self._input_coverage is None:
            errors[coverage_arg.name] = 'Please select an existing coverage file.'
        elif not (self._input_coverage['geometry_types'] == 'Polygon').any():
            errors[coverage_arg.name] = 'Please select a coverage with polygons for trimming.'
        else:
            self._get_coverage_location_info()
            if self._all_outside_raster:
                errors[coverage_arg.name] = 'Cannot trim the raster because the coverage is located ' \
                                            'outside the raster area.'
        return errors

    def run(self, arguments):
        """Override to run the tool.

        Writes the trimmed raster and, on success, registers it as the tool's output raster.

        Args:
            arguments (list): The tool arguments.
        """
        out_path = self.get_output_raster(arguments[ARG_OUTPUT_RASTER].value)

        if self._create_trimmed_raster(out_path, arguments[ARG_INPUT_COVERAGE].value):
            self.set_output_raster_file(out_path, arguments[ARG_OUTPUT_RASTER].text_value)

    def _clip_coverage_to_raster(self, trim_cov_name):
        """Clip the trim coverage to a polygon covering the input raster's bounds (Whitebox 'Clip' tool).

        Args:
            trim_cov_name (str): The name of the coverage to be clipped.

        Returns:
            The clipped coverage, or None if the clip step produced no output.
        """
        # Create a boundary polygon for the raster data and write it to a temporary shapefile.
        temp_file = filesystem.temp_filename()
        boundary_cov_name = os.path.basename(temp_file)
        polygon_creator = PolygonFromRasterBounds(self._input_raster, boundary_cov_name, self.default_wkt,
                                                  self.vertical_datum, self.vertical_units, self.logger)
        boundary_shapefile = f'{temp_file}.shp'
        boundary_cov = polygon_creator.create_boundary_coverage()
        polygons_to_shapefile(get_polygons_from_coverage(boundary_cov), boundary_shapefile, self.default_wkt)

        new_trim_cov_name = os.path.basename(filesystem.temp_filename())
        arg_values = {
            'input_vector_file': trim_cov_name,
            'input_clip_polygon_vector_file': boundary_cov_name,
            'output_vector_file': new_trim_cov_name
        }
        try:
            wbt.run_wbt_tool(WhiteboxToolRunner(self), 'Clip', arg_values, False)
        finally:
            # The boundary shapefile is only needed while the clip runs.
            gu.delete_vector_file(boundary_shapefile)
        # Use the most recent output so a list that accumulates across runs never yields a stale coverage.
        return self.vector_output[-1] if self.vector_output else None

    def _create_trimmed_raster(self, out_path, trim_cov_name):
        """Creates a trimmed raster using GDAL Warp, and the coverage polygons as a cutline.

        When some coverage vertices fall outside the raster, the coverage is first clipped to the raster bounds.
        The output is reprojected to the display projection and cropped to the cutline. If a vertical unit
        conversion is needed, the warp goes to a temporary file before ``ru.set_vertical_units`` writes
        ``out_path``.

        Args:
            out_path (str): The path of the raster to create.
            trim_cov_name (str): The name of the coverage to be used for trimming

        Returns:
            bool:  True if success, else False
        """
        display_sr, raster_sr, vertical_multiplier = ru.reconcile_projection_to_raster(self._input_raster,
                                                                                       self.default_wkt,
                                                                                       self.vertical_datum,
                                                                                       self.vertical_units)
        # Pick the coverage whose polygons form the cutline.
        if self._all_inside_raster:
            cov = self._input_coverage
        else:
            cov = self._clip_coverage_to_raster(trim_cov_name)
            if cov is None:
                self.logger.error('Unable to clip the coverage to the raster boundary.')
                return False

        polygons = get_polygons_from_coverage(cov)
        num_features = len(polygons[1])
        cutline_file = f'{filesystem.temp_filename()}.shp'
        polygons_to_shapefile(polygons, cutline_file, display_sr.ExportToWkt())

        plural = '' if num_features == 1 else 's'
        self.logger.info(f'Using {num_features} polygon{plural} for clipping...')

        convert_vertical = not math.isclose(vertical_multiplier, 1.0)
        # Warp straight to the output unless a vertical conversion step must write it afterward.
        trimmed_file = filesystem.temp_filename(suffix='.tif') if convert_vertical else out_path
        data_type, nodata = ru.get_datatype_and_nodata_for_warp(self._input_raster)
        try:
            # Reformat the raster to carry an explicit NODATA value when it only has an "acting" one.
            warp_input = self._input_raster.gdal_raster
            acting_nodata = self._input_raster.get_acting_nodata_value()
            if acting_nodata[1]:
                # NOTE(review): this /vsimem dataset is never unlinked and stays in GDAL's memory filesystem.
                warp_input = gw.gdal_translate_get_dataset(self._input_raster.gdal_raster,
                                                           '/vsimem/raster_with_nodata.tif',
                                                           noData=acting_nodata[0])

            # Call GDAL Warp with the vector layer of polygons as a cutline layer.
            rval = gw.gdal_warp(warp_input, trimmed_file, cutlineDSName=cutline_file, cropToCutline=True,
                                dstNodata=nodata, srcSRS=raster_sr, dstSRS=display_sr,
                                outputType=data_type, options=['-overwrite', '-co', 'compress=LZW'])
        finally:
            gu.delete_vector_file(cutline_file)

        # Set the units in the band of the output raster (out_path).
        if rval:
            ru.set_vertical_units(trimmed_file, self.vertical_datum, self.vertical_units, vertical_multiplier, out_path)

        return rval
