Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 9 additions & 6 deletions src/prime_rl/entrypoints/rl.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,13 @@
resolve_latest_ckpt_step,
validate_output_dir,
)
from prime_rl.utils.process import cleanup_processes, cleanup_threads, monitor_process, set_proc_title
from prime_rl.utils.process import (
cleanup_processes,
cleanup_threads,
monitor_process,
set_proc_title,
start_tail_processes,
)

RL_TOML = "rl.toml"
RL_SBATCH = "rl.sbatch"
Expand Down Expand Up @@ -321,11 +327,8 @@ def sigterm_handler(signum, frame):
# Monitor all processes for failures
logger.success("Startup complete. Showing orchestrator logs...")

tail_process = Popen(
f"tail -F '{log_dir / 'orchestrator.log'}'",
shell=True,
)
processes.append(tail_process)
tail_processes = start_tail_processes(log_dir / "orchestrator.log")
processes.extend(tail_processes)

# Check for errors from monitor threads
while not (stop_events["orchestrator"].is_set() and stop_events["trainer"].is_set()):
Expand Down
16 changes: 11 additions & 5 deletions src/prime_rl/entrypoints/sft.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,13 @@
from prime_rl.utils.config import cli
from prime_rl.utils.logger import setup_logger
from prime_rl.utils.pathing import format_log_message, get_config_dir, get_log_dir, validate_output_dir
from prime_rl.utils.process import cleanup_processes, cleanup_threads, monitor_process, set_proc_title
from prime_rl.utils.process import (
cleanup_processes,
cleanup_threads,
monitor_process,
set_proc_title,
start_tail_processes,
)

SFT_TOML = "sft.toml"
SFT_SBATCH = "sft.sbatch"
Expand Down Expand Up @@ -161,11 +167,11 @@ def sft_local(config: SFTConfig):
monitor_threads.append(monitor_thread)

logger.success("Startup complete. Showing trainer logs...")
tail_process = Popen(
f"tail -F '{log_dir / 'trainer.log'}' | sed -u 's/^\\[[a-zA-Z]*[0-9]*\\]://'",
shell=True,
tail_processes = start_tail_processes(
log_dir / "trainer.log",
strip_torchrun_prefix=True,
)
processes.append(tail_process)
processes.extend(tail_processes)

stop_event.wait()

Expand Down
25 changes: 25 additions & 0 deletions src/prime_rl/utils/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from prime_rl.utils.logger import get_logger

# Prefix constant for PRIME-RL process titles (presumably consumed by set_proc_title — confirm).
PRIME_RL_PROC_PREFIX = "PRIME-RL"
# A complete sed substitution script (not a bare regex): on every line that begins with a
# "[<letters><digits>]:" tag, the tag is replaced with nothing. start_tail_processes hands this
# string to `sed -u` as a single argv entry, so it needs no shell quoting or escaping.
TORCHRUN_LOG_PREFIX_PATTERN = r"s/^\[[a-zA-Z]*[0-9]*\]://"


def set_proc_title(name: str) -> None:
Expand Down Expand Up @@ -62,6 +63,30 @@ def cleanup_processes(processes: list[Popen]):
get_logger().debug(f"Cleaned up process {process.pid}")


def start_tail_processes(
    log_path: str | os.PathLike[str],
    *,
    strip_torchrun_prefix: bool = False,
) -> list[Popen]:
    """Start subprocesses that stream a log file without invoking a shell.

    The log path is passed to ``tail`` as a single argv entry, so characters
    that would be special to a shell (quotes, ``$(...)``, spaces) are treated
    as part of the file name.

    Args:
        log_path: File to follow with ``tail -F``.
        strip_torchrun_prefix: When true, pipe the output of ``tail`` through
            ``sed -u`` with ``TORCHRUN_LOG_PREFIX_PATTERN`` before it is
            printed.

    Returns:
        ``[tail]`` when no filtering is requested, otherwise ``[tail, sed]``.
        The caller owns every returned process and must clean them up.

    Raises:
        RuntimeError: If filtering was requested but the stdout of ``tail``
            was not captured.
        OSError: If ``tail`` or ``sed`` cannot be started.
    """
    tail_process = Popen(
        ["tail", "-F", os.fspath(log_path)],
        stdout=subprocess.PIPE if strip_torchrun_prefix else None,
    )
    if not strip_torchrun_prefix:
        return [tail_process]

    try:
        if tail_process.stdout is None:
            raise RuntimeError("tail process stdout was not captured")

        sed_process = Popen(
            ["sed", "-u", TORCHRUN_LOG_PREFIX_PATTERN],
            stdin=tail_process.stdout,
        )
    except BaseException:
        # The caller never receives the tail handle on this path, so nothing
        # else can clean it up: stop and reap it here before propagating.
        if tail_process.stdout is not None:
            tail_process.stdout.close()
        tail_process.terminate()
        tail_process.wait()
        raise

    # Drop the parent's copy of the pipe's read end so that sed is the only
    # reader; tail then gets SIGPIPE if sed exits instead of blocking forever.
    tail_process.stdout.close()
    return [tail_process, sed_process]


def monitor_process(process: Popen, stop_event: Event, error_queue: list, process_name: str):
"""Monitor a subprocess and signal errors via shared queue."""
process.wait()
Expand Down
61 changes: 61 additions & 0 deletions tests/unit/utils/test_process.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import subprocess

from prime_rl.utils import process


class FakeStdout:
    """Stand-in for a pipe handle that only remembers whether it was closed."""

    # Every instance starts out open; close() overrides this per instance.
    closed: bool = False

    def close(self) -> None:
        """Record that the handle has been closed."""
        self.closed = True


class FakePopen:
    """Records the arguments of a would-be ``Popen`` call without spawning anything."""

    def __init__(self, cmd, **kwargs) -> None:
        self.cmd = cmd
        self.kwargs = kwargs
        # A stdout handle exists only when the caller asked for a pipe;
        # any other stdout setting (or none at all) leaves it as None.
        wants_pipe = kwargs.get("stdout") is subprocess.PIPE
        if wants_pipe:
            self.stdout = FakeStdout()
        else:
            self.stdout = None


def test_start_tail_processes_uses_argv_without_shell(tmp_path, monkeypatch):
    """A plain tail is spawned from an argv list, with no ``shell`` keyword at all."""
    spawned = []

    def record_popen(cmd, **kwargs):
        fake = FakePopen(cmd, **kwargs)
        spawned.append(fake)
        return fake

    monkeypatch.setattr(process, "Popen", record_popen)

    # The file name carries quotes and a command substitution; it must reach
    # tail verbatim as one argument.
    hostile_path = tmp_path / "trainer '$(touch injected)'.log"

    result = process.start_tail_processes(hostile_path)

    assert result == spawned[:1]
    tail_call = spawned[0]
    assert tail_call.cmd == ["tail", "-F", str(hostile_path)]
    assert "shell" not in tail_call.kwargs


def test_start_tail_processes_can_strip_torchrun_prefix_without_shell(tmp_path, monkeypatch):
    """With prefix stripping on, tail is piped into sed and neither call uses a shell."""
    spawned = []

    def record_popen(cmd, **kwargs):
        fake = FakePopen(cmd, **kwargs)
        spawned.append(fake)
        return fake

    monkeypatch.setattr(process, "Popen", record_popen)

    trainer_log = tmp_path / "trainer.log"

    result = process.start_tail_processes(trainer_log, strip_torchrun_prefix=True)

    assert result == spawned

    # First spawn: tail with its stdout captured, and the parent's copy of
    # that pipe closed once sed has been started.
    tail_call = spawned[0]
    assert tail_call.cmd == ["tail", "-F", str(trainer_log)]
    assert tail_call.kwargs == {"stdout": subprocess.PIPE}
    assert tail_call.stdout is not None
    assert tail_call.stdout.closed is True

    # Second spawn: sed reading directly from tail's stdout handle.
    sed_call = spawned[1]
    assert sed_call.cmd == ["sed", "-u", process.TORCHRUN_LOG_PREFIX_PATTERN]
    assert sed_call.kwargs == {"stdin": tail_call.stdout}
    assert "shell" not in sed_call.kwargs