"""WatershedFromRasterTool class."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
import os.path

# 2. Third party modules

# 3. Aquaveo modules
from xms.gdal.rasters import RasterInput
from xms.gdal.rasters.raster_utils import convert_index_raster_data_to_polygons
from xms.gdal.utilities import gdal_utils as gu
from xms.gdal.vectors.vector_input import get_poly_features
from xms.tool_core import IoDirection, Tool

# 4. Local modules
from xms.tool.algorithms.coverage import TrimCoverage
from xms.tool.utilities.coverage_conversion import (convert_polygons_to_coverage, get_arcs_from_coverage,
                                                    get_polygons_from_list)
from xms.tool.utilities.file_utils import get_raster_filename
from xms.tool.whitebox import WhiteboxToolRunner
import xms.tool.whitebox.whitebox_tool_runner as wbr

# Positions of the tool arguments, in the order WatershedFromRasterTool.initial_arguments() builds them.
(
    ARG_INPUT_RASTER,  # 'input_raster' raster argument
    ARG_INPUT_OUTLETS,  # 'outlet_coverage' coverage argument
    ARG_THRESHOLD_AREA_SQ_MI,  # 'threshold_area_sq_mi' float argument
    ARG_PREPROCESSING_ENGINE,  # 'preprocessing_engine' string (choice) argument
    ARG_WATERSHED_BOUNDARIES,  # 'watershed_boundaries' output coverage argument
    ARG_STREAM_LINES,  # 'stream_lines' output coverage argument
) = range(6)


class WatershedFromRasterTool(Tool):
    """Tool that delineates watershed boundaries and stream lines from a raster using Whitebox tools."""
    # Choices offered for the 'preprocessing_engine' argument.
    WHITEBOX_RHO8 = 'Whitebox rho8'
    WHITEBOX_FULL_WORKFLOW = 'Whitebox full workflow'

    def __init__(self):
        """Initializes the class."""
        super().__init__('Watershed from Raster')
        # run() reads element 0 of this list as the vector stream coverage.
        # NOTE(review): presumably filled in by the Whitebox runner - confirm.
        self.vector_output = []
        # The input raster; assigned at the start of run().
        self._raster = None

    def initial_arguments(self):
        """Build the default argument list for the tool.

        The position of each argument matches the module-level ``ARG_*`` constants.

        Returns:
            (list): A list of the initial tool arguments.
        """
        raster_in = self.raster_argument(name='input_raster', description='Input raster')
        outlets_in = self.coverage_argument(name='outlet_coverage', description='Outlet coverage')
        threshold = self.float_argument(name='threshold_area_sq_mi', description='Threshold area (sq miles)',
                                        value=1.0)
        engine = self.string_argument(name='preprocessing_engine', description='Pre-processing engine',
                                      value=self.WHITEBOX_RHO8,
                                      choices=[self.WHITEBOX_RHO8, self.WHITEBOX_FULL_WORKFLOW])
        boundaries_out = self.coverage_argument(name='watershed_boundaries', description='Watershed boundaries',
                                                io_direction=IoDirection.OUTPUT)
        streams_out = self.coverage_argument(name='stream_lines', description='Stream lines',
                                             io_direction=IoDirection.OUTPUT)
        return [raster_in, outlets_in, threshold, engine, boundaries_out, streams_out]

    def run(self, arguments):
        """Run the watershed delineation workflow.

        The steps are: build flow direction and flow accumulation rasters with the selected pre-processing engine,
        extract streams above the threshold area, snap the outlets to the streams, delineate the watersheds,
        convert the watershed raster to polygons, and finally output the watershed coverage together with the
        streams trimmed to the watershed polygons.

        Args:
            arguments (list): A list of the tool's arguments.
        """
        # Resolve the input raster and the base name shared by all intermediate outputs.
        dem_name = arguments[ARG_INPUT_RASTER].value
        self._raster = self.get_input_raster(dem_name)
        dem_path = self.get_input_raster_file(dem_name)
        stem, _ = os.path.splitext(os.path.basename(dem_path))
        watershed_cov_name = arguments[ARG_WATERSHED_BOUNDARIES].value
        runner = WhiteboxToolRunner(self)

        # Every intermediate output carries a suffix identifying the pre-processing engine that produced it.
        use_full_workflow = arguments[ARG_PREPROCESSING_ENGINE].value == self.WHITEBOX_FULL_WORKFLOW
        suffix = '_full' if use_full_workflow else '_rho8'
        breached_dem = f'{stem}_breached{suffix}'
        flow_dir = f'{stem}_flowdir{suffix}'
        flow_accum = f'{stem}_flowaccum{suffix}'
        stream_raster = f'{stem}_streams{suffix}'
        stream_vector = f'{stem}_streams{suffix}'
        watershed_raster = f'{stem}_watershed{suffix}'
        snapped_outlets = f'{stem}_outlets{suffix}'

        # Step 1: flow direction and flow accumulation grids.
        if use_full_workflow:
            wbr.run_wbt_tool(runner, 'FlowAccumulationFullWorkflow', {
                'input_dem_file': dem_name,
                'output_dem_file': breached_dem,
                'output_flow_pointer_file': flow_dir,
                'output_flow_accumulation_file': flow_accum,
                'output_type': 'Catchment Area'
            }, False)
        else:  # pragma: no cover - test_watershed_from_raster_tool_rho8() is failing
            # Rho8 method: breach depressions, then compute the pointer and accumulation rasters separately.
            # The search distance is fixed at one cell; an earlier, disabled formula derived it from the raster
            # resolution (int(min(self._raster.resolution) / 40.0)).
            search_cells = 1
            wbr.run_wbt_tool(runner, 'BreachDepressionsLeastCost', {
                'input_dem_file': dem_name,
                'output_file': breached_dem,
                'maximum_search_distance_cells': search_cells,
                'minimize_breach_distances': True,
                'fill_unbreached_depressions': True
            }, False)
            wbr.run_wbt_tool(runner, 'Rho8Pointer', {
                'input_dem_file': breached_dem,
                'output_file': flow_dir
            }, False)
            wbr.run_wbt_tool(runner, 'Rho8FlowAccumulation', {
                'input_dem_or_rho8_pointer_file': flow_dir,
                'output_raster_file': flow_accum,
                'output_type': 'catchment area',
                'is_the_input_raster_a_rho8_flow_pointer': True
            }, False)

        # Step 2: stream raster. The threshold is entered in square miles and converted to square map units;
        # any horizontal unit other than international or US survey feet is treated as meters.
        threshold_sq_mi = arguments[ARG_THRESHOLD_AREA_SQ_MI].value
        horiz_units = gu.get_horiz_unit_from_wkt(self._raster.wkt)
        sq_meters_per_sq_mi = 2589988.11
        sq_ft_per_sq_mi = 27878400.0
        area_factor = sq_meters_per_sq_mi
        if horiz_units in (gu.UNITS_FEET_INT, gu.UNITS_FEET_US_SURVEY):  # pragma no cover
            area_factor = sq_ft_per_sq_mi
        wbr.run_wbt_tool(runner, 'ExtractStreams', {
            'input_d8_flow_accumulation_file': flow_accum,
            'output_file': stream_raster,
            'channelization_threshold': threshold_sq_mi * area_factor
        }, False)

        # Step 3: convert the stream raster to vector lines.
        wbr.run_wbt_tool(runner, 'RasterStreamsToVector', {
            'input_streams_file': stream_raster,
            'input_d8_pointer_file': flow_dir,
            'output_file': stream_vector
        }, False)

        # Step 4: snap the outlet point(s) to the streams, writing the snapped points to a new output.
        # The snap distance is one tenth of (pixel size * resolution) along whichever axis gives the larger value.
        # NOTE(review): presumably pixel size * resolution is the raster extent in map units - confirm.
        snap_x = self._raster.pixel_width * self._raster.resolution[0] / 10.0
        snap_y = self._raster.pixel_height * self._raster.resolution[1] / 10.0
        wbr.run_wbt_tool(runner, 'JensonSnapPourPoints', {
            'input_pour_points_outlet_file': arguments[ARG_INPUT_OUTLETS].value,
            'input_streams_file': stream_raster,
            'output_file': snapped_outlets,
            'maximum_snap_distance_map_units': max(snap_x, snap_y)
        }, False)

        # Step 5: delineate the watershed(s) into a raster.
        wbr.run_wbt_tool(runner, 'Watershed', {
            'input_d8_pointer_file': flow_dir,
            'input_pour_points_outlet_file': snapped_outlets,
            'output_file': watershed_raster
        }, False)

        # Step 6: convert the watershed raster to polygons. This is done here rather than with the Whitebox
        # 'RasterToVectorPolygons' tool, which has bugs and does not create polygons for all the isobasins.
        watershed_ri = RasterInput(get_raster_filename(watershed_raster))
        vec_ds = convert_index_raster_data_to_polygons(watershed_ri, None, True)
        watershed_ri = None  # Drop the reference to the raster once the polygons are built
        hole_data, poly_data = get_polygons_from_list(get_poly_features(vec_ds))
        vec_ds = None  # Flush the buffer for the shapefile

        # Step 7: output the watershed coverage and the streams clipped to the watershed boundaries.
        arc_data = get_arcs_from_coverage(self.vector_output[0], self._raster.wkt)
        polys = [[poly['poly_pts']] for poly in poly_data]
        watershed_cov = convert_polygons_to_coverage(polys, watershed_cov_name, self._raster.wkt)
        self.set_output_coverage(watershed_cov, arguments[ARG_WATERSHED_BOUNDARIES])

        stream_cov_name = arguments[ARG_STREAM_LINES].value
        trimmer = TrimCoverage(poly_data, hole_data, arc_data, 0.0, False, stream_cov_name, self.logger, 2500)
        trimmed_streams = trimmer.generate_coverage(gu.strip_vertical(self._raster.wkt))
        # The trimmed streams are registered through a freshly built output argument carrying the coverage name.
        stream_output_arg = self.coverage_argument(name='', description='', io_direction=IoDirection.OUTPUT,
                                                   value=stream_cov_name)
        self.set_output_coverage(trimmed_streams, stream_output_arg)
