"""ClipRastersFromElevationsTool class."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
import os
import shutil

# 2. Third party modules

# 3. Aquaveo modules
from xms.core.filesystem import filesystem
from xms.gdal.rasters import raster_utils as ru
from xms.gdal.utilities import gdal_utils as gu, GdalRunner
from xms.tool_core import IoDirection, Tool

# 4. Local modules

ARG_RASTER_TO_CLIP = 0
ARG_ELEVATION_RASTER = 1
ARG_CLIP_OPTION = 2
ARG_OUTPUT_RASTER = 3


class ClipRastersFromElevationsTool(Tool):
    """Tool to clip a raster based on an elevation raster.

    The elevation raster is resampled onto the grid of the raster to clip. Each cell of the raster to clip is then
    kept or replaced with its no data value, depending on how it compares to the elevation at that cell.
    """
    CLIP_ELEVATIONS_ABOVE = 'Clip elevations above'
    CLIP_ELEVATIONS_BELOW = 'Clip elevations below'
    # Fallback no data value used when a raster has neither an explicit nor an acting no data value.
    DEFAULT_NODATA_VALUE = -999999.0

    def __init__(self):
        """Initializes the class."""
        super().__init__(name='Clip Raster from Elevations')
        self._file_count = 0

    def initial_arguments(self):
        """Get initial arguments for tool.

        Must override.

        Returns:
            (list): A list of the initial tool arguments.
        """
        arguments = [
            self.raster_argument(name='raster_to_clip', description='Raster to clip'),
            self.raster_argument(name='elevation_raster', description='Elevation raster'),
            self.string_argument(name='clip_option', description='Clip elevations above or below elevation raster',
                                 choices=[self.CLIP_ELEVATIONS_ABOVE, self.CLIP_ELEVATIONS_BELOW]),
            self.raster_argument(name='output_raster', description='Output raster', io_direction=IoDirection.OUTPUT)
        ]
        return arguments

    def run(self, arguments):
        """Override to run the tool.

        Warps the elevation raster onto the grid of the raster to clip. It then runs gdal_calc.py to keep or clear
        each cell according to the clip option, and writes the result to the output raster argument. Temporary
        directories created along the way are removed whether or not the run succeeds.

        Args:
            arguments (list): The tool arguments.
        """
        # make a temporary directory (holds the projection file passed to gdalwarp)
        self.temp_file_path = filesystem.temp_filename()
        os.mkdir(self.temp_file_path)
        gdal_runner = None
        try:
            gdal_runner = GdalRunner()
            self._clip_raster(arguments, gdal_runner)
        finally:
            # Remove the folders containing any generated files, even when a step above raised.
            if gdal_runner is not None:
                shutil.rmtree(gdal_runner.temp_file_path, ignore_errors=True)
            shutil.rmtree(self.temp_file_path, ignore_errors=True)

    def _clip_raster(self, arguments, gdal_runner):
        """Run the GDAL steps that produce the clipped output raster and register it as the tool output.

        Args:
            arguments (list): The tool arguments.
            gdal_runner (GdalRunner): Runner used for the gdalwarp, gdal_translate and gdal_calc.py steps.
        """
        raster_to_clip = self.get_input_raster(arguments[ARG_RASTER_TO_CLIP].value)
        raster_to_clip_file = self.get_input_raster_file(arguments[ARG_RASTER_TO_CLIP].value)
        elevation_raster = self.get_input_raster(arguments[ARG_ELEVATION_RASTER].value)
        elevation_raster_file = self.get_input_raster_file(arguments[ARG_ELEVATION_RASTER].value)
        clip_option = arguments[ARG_CLIP_OPTION].value

        wkt_file = self._write_projection_file(raster_to_clip.wkt)
        clipper_file = self._build_clipper_file(gdal_runner, raster_to_clip, elevation_raster_file, wkt_file)

        # calc to generate clipped output
        to_clip_nodata_value = self._resolve_nodata_value(raster_to_clip)
        elevation_nodata_value = self._resolve_nodata_value(elevation_raster)
        expression = self._build_calc_expression(clip_option, to_clip_nodata_value, elevation_nodata_value)
        args = ['-A', raster_to_clip_file, '-B', clipper_file, '--outfile=$OUT_FILE$', expression,
                f'--NoDataValue={to_clip_nodata_value}']
        output_file = gdal_runner.run('gdal_calc.py', 'inverted.tif', args).strip('"')
        output_file = ru.reproject_raster(output_file, self.get_output_raster(arguments[ARG_OUTPUT_RASTER].value),
                                          self.default_wkt, self.vertical_datum, self.vertical_units)
        self.set_output_raster_file(output_file, arguments[ARG_OUTPUT_RASTER].text_value)

    def _write_projection_file(self, wkt):
        """Write a projection (.prj) file for the given WKT into the tool's temporary directory.

        Args:
            wkt (str): Well-known text of the raster to clip.

        Returns:
            (str | None): Path to the written projection file, or None when the WKT is not valid.
        """
        if not gu.valid_wkt(wkt):
            return None
        wkt_file = os.path.join(self.temp_file_path, 'output.prj')
        with open(wkt_file, 'wt') as file:
            file.write(wkt)
        return wkt_file

    @staticmethod
    def _build_clipper_file(gdal_runner, raster_to_clip, elevation_raster_file, wkt_file):
        """Warp the elevation raster onto the extent and pixel grid of the raster to clip.

        Args:
            gdal_runner (GdalRunner): Runner used to execute the GDAL utilities.
            raster_to_clip: Input raster whose bounds and resolution define the target grid.
            elevation_raster_file (str): Path to the elevation raster file.
            wkt_file (str | None): Projection file for the target SRS, or None to skip reprojection.

        Returns:
            (str): Path to the warped elevation raster with its no data flag removed.
        """
        lower_left, upper_right = raster_to_clip.get_raster_bounds()
        xmin, ymin = lower_left[0], lower_left[1]
        xmax, ymax = upper_right[0], upper_right[1]
        xsize, ysize = raster_to_clip.resolution
        args = ['-te', f'{xmin}', f'{ymin}', f'{xmax}', f'{ymax}', '-ts', f'{xsize}', f'{ysize}', '-r', 'bilinear']
        if wkt_file is not None:
            args.extend(['-t_srs', f'{wkt_file}'])
        args.extend(['-of', 'GTiff'])
        clipper_file_no_data = gdal_runner.run_wrapper('gdalwarp', elevation_raster_file, 'clipper_no_data.tif', args)

        # Remove the no data flag from the clipper file. Presumably this stops gdal_calc from masking clipper cells;
        # elevation no data cells are instead matched explicitly (B == elevation no data) in the calc expression.
        args = ['-a_nodata', 'none', '-of', 'GTiff']
        return gdal_runner.run_wrapper('gdal_translate', clipper_file_no_data, 'clipper.tif', args)

    @classmethod
    def _resolve_nodata_value(cls, raster):
        """Get the no data value to use for a raster.

        Uses the raster's explicit no data value if set. Otherwise it uses the acting no data value reported by
        get_acting_nodata_value() when that is flagged as present. Otherwise it uses DEFAULT_NODATA_VALUE. A no data
        value of 0 is treated as a real value, not as "unset".

        Args:
            raster: Input raster object.

        Returns:
            The no data value for the raster.
        """
        nodata_value = raster.nodata_value
        if nodata_value is None:
            acting_nodata = raster.get_acting_nodata_value()
            if acting_nodata[1]:
                nodata_value = acting_nodata[0]
        if nodata_value is None:
            nodata_value = cls.DEFAULT_NODATA_VALUE
        return nodata_value

    @classmethod
    def _build_calc_expression(cls, clip_option, to_clip_nodata_value, elevation_nodata_value):
        """Build the --calc argument for gdal_calc.py.

        A is the raster to clip and B is the elevation (clipper) raster. Multiplying by a logical term acts like an
        "if" on each cell:
          - B is elevation no data: the result is A.
          - otherwise: the result is A where A {compare} B, and A's no data value where A {not_compare} B.
        Cells that are no data in A are left to gdal_calc's own input no data handling.

        Args:
            clip_option (str): CLIP_ELEVATIONS_ABOVE or CLIP_ELEVATIONS_BELOW.
            to_clip_nodata_value: No data value of the raster to clip (also used for the output).
            elevation_nodata_value: No data value of the elevation raster.

        Returns:
            (str): The '--calc=...' command line argument.
        """
        if clip_option == cls.CLIP_ELEVATIONS_ABOVE:
            compare, not_compare = '<=', '>'
        else:
            compare, not_compare = '>=', '<'
        return (
            f'--calc=A*logical_and(A{compare}B,B!={elevation_nodata_value})'
            f'+{to_clip_nodata_value}*(B!={elevation_nodata_value})*(A{not_compare}B)'
            f'+A*(B=={elevation_nodata_value})'
        )

# def main():
#     """Main function, for testing."""
#     pass
#     from xms.tool_gui.tool_dialog import ToolDialog
#     from xms.guipy.dialogs.xms_parent_dlg import ensure_qapplication_exists
#     from xms.tool.utilities.file_utils import get_test_files_path
#
#     qapp = ensure_qapplication_exists()
#     tool = ClipRastersFromElevationsTool()
#     tool.set_gui_data_folder(get_test_files_path())
#     arguments = tool.initial_arguments()
#     tool_dialog = ToolDialog(None, arguments, tool.name, tool=tool)
#     if tool_dialog.exec():
#         tool.run_tool(tool_dialog.tool_arguments)
#     qapp = None
#
#
# if __name__ == "__main__":
#     main()
