From e3fe296b87903f294ad3e9bf7bd342a62cd225bf Mon Sep 17 00:00:00 2001 From: Jannik Straube Date: Mon, 18 May 2026 17:21:40 -0700 Subject: [PATCH] fix(trainer): suppress duplicated logs on non-zero ranks Every torchrun rank writes to its own stdout, but in k8s all rank stdouts merge into a single Loki stream grouped by role=trainer. The dashboard's Trainer log tab then shows N copies of every line (one per GPU), e.g. 8 identical 'Starting training loop' lines and 8 near- identical 'Step 0 | ...' lines on an 8-GPU pod. The existing torchrun --local-ranks-filter plumbing in entrypoints/rl.py only applies to the launcher-managed single-node path; on k8s the trainer is invoked directly via torchrun, bypassing that filter. Fix it at the logger level: add a rank_zero_only flag to setup_logger that builds a sink-less loguru instance on non-zero global ranks (read from torchrun's RANK env var, set before dist init). The RL and SFT trainer entrypoints opt in. Default behavior is unchanged for all other setup_logger callers (orchestrator, inference, launcher). --- src/prime_rl/trainer/rl/train.py | 4 ++++ src/prime_rl/trainer/sft/train.py | 4 ++++ src/prime_rl/utils/logger.py | 21 ++++++++++++++++----- 3 files changed, 24 insertions(+), 5 deletions(-) diff --git a/src/prime_rl/trainer/rl/train.py b/src/prime_rl/trainer/rl/train.py index fc03e89f3b..2e29bd8bf0 100644 --- a/src/prime_rl/trainer/rl/train.py +++ b/src/prime_rl/trainer/rl/train.py @@ -73,9 +73,13 @@ def train(config: TrainerConfig): # Setup world and logger world = get_world() + # rank_zero_only suppresses logs on non-zero ranks. Every rank's stdout + # is merged in k8s/Loki, so without this each line appears N times in the + # dashboard's trainer log tab (one per GPU). logger = setup_logger( config.log.level, json_logging=config.log.json_logging, + rank_zero_only=True, ) logger.info(f"Starting RL trainer in {world} in {config.output_dir}") diff --git a/src/prime_rl/trainer/sft/train.py b/src/prime_rl/trainer/sft/train.py index 1c12b342ee..d19cd7eb5b 100644 --- a/src/prime_rl/trainer/sft/train.py +++ b/src/prime_rl/trainer/sft/train.py @@ -62,9 +62,13 @@ def train(config: SFTConfig): # Setup world and logger world = get_world() + # rank_zero_only suppresses logs on non-zero ranks. Every rank's stdout + # is merged in k8s/Loki, so without this each line appears N times in the + # dashboard's trainer log tab (one per GPU). logger = setup_logger( config.log.level, json_logging=config.log.json_logging, + rank_zero_only=True, ) logger.info(f"Starting SFT trainer in {world}") diff --git a/src/prime_rl/utils/logger.py b/src/prime_rl/utils/logger.py index 953ddbcb74..43e8d3eecc 100644 --- a/src/prime_rl/utils/logger.py +++ b/src/prime_rl/utils/logger.py @@ -1,5 +1,6 @@ import json as json_module import logging +import os import sys import traceback from typing import Any @@ -88,6 +89,7 @@ def setup_logger( log_level: str = "info", tag: str | None = None, json_logging: bool = False, + rank_zero_only: bool = False, ): global _LOGGER, _JSON_LOGGING _JSON_LOGGING = json_logging @@ -96,6 +98,13 @@ def setup_logger( if _LOGGER is not None: _LOGGER.remove() + # When running under torchrun, every rank writes to its own stdout but all + # stdout streams are merged in k8s/Loki, so each log line shows up N times. + # rank_zero_only=True suppresses output on non-zero ranks. We read the + # global RANK env var (set by torchrun before the process starts) so this + # works even before torch.distributed.init_process_group is called. + is_silent_rank = rank_zero_only and int(os.environ.get("RANK", "0")) != 0 + # Format message with optional tag prefix tag_prefix = f"[{tag}] " if tag else "" message = "".join( @@ -135,11 +144,13 @@ def setup_logger( if json_logging and tag: logger = logger.bind(tag=tag) - # Install console handler (enqueue=True only for JSON mode to avoid blocking in async contexts) - if json_logging: - logger.add(json_sink, level=log_level.upper(), enqueue=True) - else: - logger.add(sys.stdout, format=format, level=log_level.upper(), colorize=True) + # Install console handler (enqueue=True only for JSON mode to avoid blocking in async contexts). + # Silent ranks get a logger with no sinks so all log calls become no-ops. + if not is_silent_rank: + if json_logging: + logger.add(json_sink, level=log_level.upper(), enqueue=True) + else: + logger.add(sys.stdout, format=format, level=log_level.upper(), colorize=True) # Disable critical logging logger.critical = lambda _: None