From 21f2646dbc4482ea8a8fbcd82ff0a37d45badbc0 Mon Sep 17 00:00:00 2001
From: Benjamin Feuer
Date: Fri, 20 Mar 2026 09:48:15 -0400
Subject: [PATCH] feat: add multi-node WandB system metrics aggregation
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

In multi-node training, GPU utilization and system metrics are only
captured for the head node. This adds a WandbNodeLogger Ray actor
spawned on each worker node, using wandb mode="shared" to aggregate
system metrics (GPU util, memory) from all nodes into a single run.

Single-node training is unaffected — the WandbNodeLogger only spawns
when Ray detects multiple alive nodes.

Co-Authored-By: Claude Opus 4.6 (1M context)
---
 skyrl/train/utils/tracking.py | 122 +++++++++++++++++++++++++++++++++-
 1 file changed, 120 insertions(+), 2 deletions(-)

diff --git a/skyrl/train/utils/tracking.py b/skyrl/train/utils/tracking.py
index 30da58dc11..620561f418 100644
--- a/skyrl/train/utils/tracking.py
+++ b/skyrl/train/utils/tracking.py
@@ -22,12 +22,49 @@
 from pathlib import Path
 from typing import Any, Dict, List, Optional, Union
 
+import ray
 from loguru import logger
 from omegaconf import DictConfig, OmegaConf
+from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy
 
 from skyrl.train.config import SkyRLTrainConfig, get_config_as_dict
 
 
+@ray.remote
+class WandbNodeLogger:
+    """
+    A Ray actor that initializes wandb on a specific node to capture system metrics.
+
+    Uses wandb's shared mode to aggregate GPU utilization and other system metrics
+    from all nodes into a single wandb run.
+    """
+
+    def __init__(self, project_name, experiment_name, config, run_id, group_name, x_label):
+        import wandb
+
+        run = wandb.init(
+            project=project_name,
+            name=experiment_name,
+            id=run_id,
+            config=config,
+            resume="allow",
+            group=group_name,
+            job_type="worker_monitor",
+            settings=wandb.Settings(
+                mode="shared",
+                x_primary=False,
+                x_update_finish_state=False,
+                x_label=x_label,
+            ),
+        )
+        self.wandb = run
+
+    def finish_run(self):
+        if self.wandb:
+            self.wandb.finish()
+            self.wandb = None
+
+
 # TODO(tgriggs): Test all backends.
 class Tracking:
     supported_backends = ["wandb", "mlflow", "swanlab", "tensorboard", "console"]
@@ -49,8 +86,30 @@ def __init__(
         if "wandb" in backends:
             import wandb
 
-            wandb.init(project=project_name, name=experiment_name, config=get_config_as_dict(config))
-            self.logger["wandb"] = wandb
+            current_node_ip = None
+            if ray.is_initialized():
+                try:
+                    current_node_ip = ray.util.get_node_ip_address()
+                except Exception as e:
+                    logger.warning(f"Failed to get node IP address. Error: {e}. " "Skipping multi-node wandb logging.")
+
+            run = wandb.init(
+                project=project_name,
+                name=experiment_name,
+                config=get_config_as_dict(config),
+                group=experiment_name,
+                resume="allow",
+                settings=wandb.Settings(
+                    mode="shared",
+                    x_primary=True,
+                    x_label=f"node-{current_node_ip or 'head'}",
+                ),
+            )
+            run_id = run.id
+            self.logger["wandb"] = run
+            self._prepare_worker_nodes_systems_logging_wandb(
+                project_name, experiment_name, run_id, config, current_node_ip
+            )
 
         if "mlflow" in backends:
             self.logger["mlflow"] = _MlflowLoggingAdapter(project_name, experiment_name, config)
@@ -81,6 +140,58 @@ def __init__(
         self.console_logger = ConsoleLogger()
         self.logger["console"] = self.console_logger
 
+    def _prepare_worker_nodes_systems_logging_wandb(
+        self, project_name, experiment_name, run_id, config, current_node_ip
+    ):
+        """
+        In multi-node training, spawn WandbNodeLogger actors on each worker node to capture
+        system metrics like GPU utilization. Uses wandb mode="shared" to aggregate system
+        metrics from all nodes into the same wandb run.
+        """
+        self.remote_loggers = []
+
+        if current_node_ip is None:
+            logger.warning("Node IP unknown, skipping multi-node wandb logging")
+            return
+
+        if not ray.is_initialized():
+            logger.warning("Ray is not initialized, skipping distributed wandb logging")
+            return
+
+        try:
+            nodes = ray.nodes()
+
+            for node in nodes:
+                if not node["Alive"]:
+                    continue
+
+                node_ip = node["NodeManagerAddress"]
+                if node_ip == current_node_ip:
+                    continue
+
+                try:
+                    logger_actor = WandbNodeLogger.options(
+                        num_cpus=0.1,
+                        scheduling_strategy=NodeAffinitySchedulingStrategy(
+                            node_id=node["NodeID"],
+                            soft=False,
+                        ),
+                    ).remote(
+                        project_name=project_name,
+                        experiment_name=experiment_name,
+                        config=get_config_as_dict(config),
+                        run_id=run_id,
+                        group_name=experiment_name,
+                        x_label=f"node-{node_ip}",
+                    )
+                    self.remote_loggers.append(logger_actor)
+                    logger.info(f"WandbNodeLogger initialized on 'node-{node_ip}'")
+                except Exception as e:
+                    logger.warning(f"Failed to spawn WandbNodeLogger on {node_ip}: {e}")
+
+        except Exception as e:
+            logger.warning(f"Failed to setup distributed wandb logging: {e}")
+
     def log(self, data, step, commit=False):
         for logger_name, logger_instance in self.logger.items():
             if logger_name == "wandb":
@@ -89,6 +200,13 @@ def log(self, data, step, commit=False):
                 logger_instance.log(data=data, step=step)
 
     def finish(self):
+        # Finish remote wandb loggers on worker nodes first
+        for remote_logger in getattr(self, "remote_loggers", []):
+            try:
+                ray.get(remote_logger.finish_run.remote(), timeout=10)
+            except Exception as e:
+                logger.warning(f"Failed to finish remote WandbNodeLogger: {e}")
+
         for logger_name, logger_instance in self.logger.items():
             # NOTE (sumanthrh): We use a try-except block here while finishing tracking.
             # This is because wandb often errors out with a BrokenPipeError when closing.