diff --git a/src/prime_rl/entrypoints/rl.py b/src/prime_rl/entrypoints/rl.py index 582d17116e..dae13df723 100644 --- a/src/prime_rl/entrypoints/rl.py +++ b/src/prime_rl/entrypoints/rl.py @@ -23,7 +23,13 @@ resolve_latest_ckpt_step, validate_output_dir, ) -from prime_rl.utils.process import cleanup_processes, cleanup_threads, monitor_process, set_proc_title +from prime_rl.utils.process import ( + cleanup_processes, + cleanup_threads, + monitor_process, + set_proc_title, + start_tail_processes, +) RL_TOML = "rl.toml" RL_SBATCH = "rl.sbatch" @@ -321,11 +327,8 @@ def sigterm_handler(signum, frame): # Monitor all processes for failures logger.success("Startup complete. Showing orchestrator logs...") - tail_process = Popen( - f"tail -F '{log_dir / 'orchestrator.log'}'", - shell=True, - ) - processes.append(tail_process) + tail_processes = start_tail_processes(log_dir / "orchestrator.log") + processes.extend(tail_processes) # Check for errors from monitor threads while not (stop_events["orchestrator"].is_set() and stop_events["trainer"].is_set()): diff --git a/src/prime_rl/entrypoints/sft.py b/src/prime_rl/entrypoints/sft.py index 1429e5942c..3e3fb0fa7e 100644 --- a/src/prime_rl/entrypoints/sft.py +++ b/src/prime_rl/entrypoints/sft.py @@ -12,7 +12,13 @@ from prime_rl.utils.config import cli from prime_rl.utils.logger import setup_logger from prime_rl.utils.pathing import format_log_message, get_config_dir, get_log_dir, validate_output_dir -from prime_rl.utils.process import cleanup_processes, cleanup_threads, monitor_process, set_proc_title +from prime_rl.utils.process import ( + cleanup_processes, + cleanup_threads, + monitor_process, + set_proc_title, + start_tail_processes, +) SFT_TOML = "sft.toml" SFT_SBATCH = "sft.sbatch" @@ -161,11 +167,11 @@ def sft_local(config: SFTConfig): monitor_threads.append(monitor_thread) logger.success("Startup complete. Showing trainer logs...") - tail_process = Popen( - f"tail -F '{log_dir / 'trainer.log'}' | sed -u 's/^\\[[a-zA-Z]*[0-9]*\\]://'", - shell=True, + tail_processes = start_tail_processes( + log_dir / "trainer.log", + strip_torchrun_prefix=True, ) - processes.append(tail_process) + processes.extend(tail_processes) stop_event.wait() diff --git a/src/prime_rl/utils/process.py b/src/prime_rl/utils/process.py index ab14813f62..3329454b85 100644 --- a/src/prime_rl/utils/process.py +++ b/src/prime_rl/utils/process.py @@ -11,6 +11,7 @@ from prime_rl.utils.logger import get_logger PRIME_RL_PROC_PREFIX = "PRIME-RL" +TORCHRUN_LOG_PREFIX_PATTERN = r"s/^\[[a-zA-Z]*[0-9]*\]://" def set_proc_title(name: str) -> None: @@ -62,6 +63,30 @@ def cleanup_processes(processes: list[Popen]): get_logger().debug(f"Cleaned up process {process.pid}") +def start_tail_processes( + log_path: str | os.PathLike[str], + *, + strip_torchrun_prefix: bool = False, +) -> list[Popen]: + """Start subprocesses that stream a log file without invoking a shell.""" + tail_process = Popen( + ["tail", "-F", os.fspath(log_path)], + stdout=subprocess.PIPE if strip_torchrun_prefix else None, + ) + if not strip_torchrun_prefix: + return [tail_process] + + if tail_process.stdout is None: + raise RuntimeError("tail process stdout was not captured") + + sed_process = Popen( + ["sed", "-u", TORCHRUN_LOG_PREFIX_PATTERN], + stdin=tail_process.stdout, + ) + tail_process.stdout.close() + return [tail_process, sed_process] + + def monitor_process(process: Popen, stop_event: Event, error_queue: list, process_name: str): """Monitor a subprocess and signal errors via shared queue.""" process.wait() diff --git a/tests/unit/utils/test_process.py b/tests/unit/utils/test_process.py new file mode 100644 index 0000000000..d5875b1c9d --- /dev/null +++ b/tests/unit/utils/test_process.py @@ -0,0 +1,61 @@ +import subprocess + +from prime_rl.utils import process + + +class FakeStdout: + def __init__(self) -> None: + self.closed = False + + def close(self) -> None: + self.closed = True + + +class FakePopen: + def __init__(self, cmd, **kwargs) -> None: + self.cmd = cmd + self.kwargs = kwargs + self.stdout = FakeStdout() if kwargs.get("stdout") is subprocess.PIPE else None + + +def test_start_tail_processes_uses_argv_without_shell(tmp_path, monkeypatch): + calls = [] + + def fake_popen(cmd, **kwargs): + popen = FakePopen(cmd, **kwargs) + calls.append(popen) + return popen + + monkeypatch.setattr(process, "Popen", fake_popen) + + log_path = tmp_path / "trainer '$(touch injected)'.log" + + started = process.start_tail_processes(log_path) + + assert started == calls[:1] + assert calls[0].cmd == ["tail", "-F", str(log_path)] + assert "shell" not in calls[0].kwargs + + +def test_start_tail_processes_can_strip_torchrun_prefix_without_shell(tmp_path, monkeypatch): + calls = [] + + def fake_popen(cmd, **kwargs): + popen = FakePopen(cmd, **kwargs) + calls.append(popen) + return popen + + monkeypatch.setattr(process, "Popen", fake_popen) + + log_path = tmp_path / "trainer.log" + + started = process.start_tail_processes(log_path, strip_torchrun_prefix=True) + + assert started == calls + assert calls[0].cmd == ["tail", "-F", str(log_path)] + assert calls[0].kwargs == {"stdout": subprocess.PIPE} + assert calls[0].stdout is not None + assert calls[0].stdout.closed is True + assert calls[1].cmd == ["sed", "-u", process.TORCHRUN_LOG_PREFIX_PATTERN] + assert calls[1].kwargs == {"stdin": calls[0].stdout} + assert "shell" not in calls[1].kwargs