"""Module for the PTM model runner."""

__copyright__ = "(C) Copyright Aquaveo 2025"
__license__ = "All rights reserved"
__all__ = ['run_ptm']

# 1. Standard Python modules
from contextlib import suppress
import os
import subprocess
import sys
from typing import Optional

# 2. Third party modules

# 3. Aquaveo modules
from xms.api.dmi import XmsEnvironment as XmEnv

# 4. Local modules


def run_ptm(args: Optional[list[str]] = None) -> int:
    """
    Run PTM.

    Args:
        args: The arguments to use, laid out like sys.argv: the first item is skipped and must be followed by
            exactly two items, the path to the PTM executable and the path to the control file. If not provided
            (or empty), sys.argv will be used instead.

    Returns:
        The exit code that PTM (should have*) returned, normalized to 0 for success and 1 for failure.

        * PTM sometimes returns zero when it actually failed. The runner detects that by inspecting PTM's output
          with stdout_looks_bad and reports failure instead.

    Raises:
        ValueError: If the arguments do not consist of exactly a model path and a control file path.
    """
    args = args or sys.argv
    arguments = args[1:]
    if len(arguments) != 2:
        raise ValueError(
            'Expected exactly 2 arguments (PTM executable path and control file path), '
            f'got {len(arguments)}.'
        )
    model_path, control_path = arguments
    stdout = XmEnv.xms_environ_stdout_file()

    with open(stdout, 'w') as f:
        # PTM's stdout and stderr both go to the file XMS designated for this run's output.
        process = subprocess.Popen([model_path, control_path], stdout=f, stderr=f, stdin=subprocess.PIPE)
        # Let XMS know about the child so it can be cleaned up if the run is cancelled.
        register_subprocess_with_xms(process.pid)
        # Send one newline to PTM's stdin. If PTM is still running after half a second the timeout is ignored
        # and we fall through to the unbounded wait below.
        # NOTE(review): presumably the newline answers a prompt PTM may show -- confirm.
        with suppress(subprocess.TimeoutExpired):
            process.communicate('\n'.encode(), timeout=0.5)
        process.wait()

    # The output file is closed at this point, so it is safe to reopen it for scanning.
    failed = process.returncode != 0 or stdout_looks_bad(stdout)
    return 1 if failed else 0


def register_subprocess_with_xms(pid: int):
    """
    Record a child process so XMS can clean it up along with this script.

    The simulation run queue in XMS has a button for aborting a running simulation. On abort, XMS kills the process
    it launched for the model, which is this Python script rather than the model itself. The script gets no chance
    to stop its own children when that happens, so they would be left running. Registering a child through this
    function lets XMS kill it as well.

    The registration is done by appending the child's process ID, on its own line, to a file named
    '<this script's PID>_child_proc.txt' in the XMS temp directory.

    Args:
        pid: The ID of the process to register.
    """
    temp_dir = XmEnv.xms_environ_temp_directory()
    registry_path = os.path.join(temp_dir, f'{os.getpid()}_child_proc.txt')
    # Append rather than overwrite so several children can be registered by the same script.
    with open(registry_path, 'a') as registry:
        registry.write(f'{pid}\n')


def stdout_looks_bad(stdout: str) -> bool:
    """
    Check whether a captured output file reports an error.

    Args:
        stdout: Path to the text file to scan.

    Returns:
        True if any line in the file starts with 'ERROR:', False otherwise (including for an empty file).
    """
    # Undecodable bytes are replaced instead of raising, so stray binary output cannot crash the scan.
    # The marker itself is plain ASCII, so the replacement does not affect matching.
    with open(stdout, errors='replace') as f:
        return any(line.startswith('ERROR:') for line in f)
