"""Helper class to run GDAL utility scripts and executables."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"

# 1. Standard Python modules
import os
import subprocess
import sys
from typing import Optional, Sequence, Union

# 2. Third party modules
from osgeo import gdal

# 3. Aquaveo modules
from xms.core.filesystem import filesystem

# 4. Local modules
from xms.gdal.utilities.gdal_wrappers import gdal_translate, gdal_warp
from xms.gdal.utilities.test_utils import linux_not_supported


class GdalRunner:
    """Class that runs GDAL commands and scripts."""

    def __init__(self):
        """Initializes the class."""
        self.temp_file_path = filesystem.temp_filename()
        os.mkdir(self.temp_file_path)
        self._file_count = 0

    def run(self, util_name: str, out_file: str, arguments: Sequence[str], increment: bool = True) -> str:
        """
        Run a GDAL command or Python script.

        Args:
            util_name: The name of the command or Python script.
            out_file: The name of the output file. A prefix number is placed at the beginning.
            arguments: The command arguments.
            increment: Should the prefix number of the output file be incremented. Use False for in-place edits.

        Returns:
            The output file or None if no "$OUT_FILE$" exists in the arguments.
        """
        args = []
        if util_name.endswith('.py'):
            import osgeo_utils
            gdal_script_path = osgeo_utils.__path__[0]
            args.append(sys.executable)
            args.append(os.path.join(gdal_script_path, util_name))
        else:
            if linux_not_supported():
                # The GDAL utilities should be in the path on linux.
                args.append(util_name)
            else:
                import osgeo
                gdal_exe_path = osgeo.__path__[0]
                args.append(os.path.join(gdal_exe_path, util_name))

        out_file_path = None
        for arg in arguments:
            if arg.find('$OUT_FILE$') >= 0:
                out_file_path = self.get_temp_file(out_file, increment)
                arg = arg.replace('$OUT_FILE$', out_file_path)
            args.append(arg)

        return_code, output = shell_command(args, shell=False)
        if return_code != 0:
            error = f'GDAL command failed ({return_code}): {output}'
            raise RuntimeError(error)
        return out_file_path

    def run_wrapper(self, script_name: str,
                    source: Union[str, list[str], None] = None,
                    destination: Optional[str] = None,
                    options: Optional[list[str]] = None,
                    increment: bool = True) -> Optional[str]:
        """
        Runs a GDAL command line wrapper via Python.

        Args:
            script_name: The gdal command line utility to run.
            source: The path to the input file or files.
            destination: The name pattern for the output file.
            options: Additional command line arguments. Defaults to None.
            increment: Should the prefix number of the output file be incremented. Use False for in-place edits.
        """
        gdal_func = None
        match script_name:
            case "gdalwarp":
                gdal_func = gdal_warp
            case "gdal_translate":
                gdal_func = gdal_translate
        if gdal_func is not None:
            gdal.UseExceptions()
            output_file = self.get_temp_file(destination, increment)
            gdal_func(source, output_file, options=options)
            return output_file
        raise NotImplementedError(f'GDAL command {script_name} is not supported by GdalRunner')

    def get_temp_file(self, out_file: str, increment: bool = True):
        """
        Get a temporary file path for tool output.

        Args:
            out_file: The name of the output file. A prefix number is placed at the beginning.
            increment: Should the prefix number incremented. Use False for in-place edits.

        Returns:
            (str): Output file path.
        """
        if increment:
            self._file_count += 1
        out_file = f'{self._file_count}-{out_file}'
        out_file_path = f'{os.path.join(self.temp_file_path, out_file)}'
        return out_file_path


def shell_command(command: Union[str, Sequence[str]], shell: bool = True) -> tuple[int, str]:
    """Run a command and capture its combined stdout and stderr.

    Args:
        command: The command string, or a sequence of program arguments.
        shell: Should command be run in the shell?

    Returns:
        The return code [0] and the output [1]. When the command fails, the output is prefixed with a
        'command: ...' line describing what was run. Output bytes that are not valid UTF-8 are replaced
        with U+FFFD instead of raising.
    """
    try:
        output = subprocess.check_output(command, stderr=subprocess.STDOUT, shell=shell)
    except subprocess.CalledProcessError as err:
        # A plain string must not go through join(), which would put a comma between each of its characters.
        described = command if isinstance(command, str) else ','.join(str(part) for part in command)
        # errors='replace' so tool output in another encoding cannot hide the real failure.
        return err.returncode, f'command: {described}\n' + err.output.decode('utf-8', errors='replace')
    return 0, output.decode('utf-8', errors='replace')
