Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 120 additions & 2 deletions skyrl/train/utils/tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,49 @@
from pathlib import Path
from typing import Any, Dict, List, Optional, Union

import ray
from loguru import logger
from omegaconf import DictConfig, OmegaConf
from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy

from skyrl.train.config import SkyRLTrainConfig, get_config_as_dict


@ray.remote
class WandbNodeLogger:
    """
    A Ray actor that initializes wandb on a specific node to capture system metrics.

    Uses wandb's shared mode to aggregate GPU utilization and other system metrics
    from all nodes into a single wandb run.

    Args:
        project_name: wandb project to attach to.
        experiment_name: display name of the run.
        config: run configuration dict, recorded on the shared run.
        run_id: id of the primary run created on the head node; attaching with the
            same id is what merges this node's metrics into that run.
        group_name: wandb group for the run.
        x_label: per-node label (e.g. ``node-<ip>``) distinguishing this node's
            system metrics inside the shared run.
    """

    def __init__(self, project_name, experiment_name, config, run_id, group_name, x_label):
        # Imported lazily so only nodes that actually run this actor need wandb,
        # and the import cost is paid in the actor process, not the driver.
        import wandb

        run = wandb.init(
            project=project_name,
            name=experiment_name,
            id=run_id,  # attach to the primary run so metrics aggregate into one run
            config=config,
            resume="allow",
            group=group_name,
            job_type="worker_monitor",
            settings=wandb.Settings(
                mode="shared",
                x_primary=False,  # the head node owns the run; this node only contributes
                x_update_finish_state=False,  # a worker must not mark the shared run finished
                x_label=x_label,
            ),
        )
        self.wandb = run

    def finish_run(self):
        """Finish this node's wandb session, syncing any buffered metrics.

        Idempotent: ``self.wandb`` is cleared so a second call is a no-op.
        """
        if self.wandb:
            self.wandb.finish()
            self.wandb = None


# TODO(tgriggs): Test all backends.
class Tracking:
supported_backends = ["wandb", "mlflow", "swanlab", "tensorboard", "console"]
Expand All @@ -49,8 +86,30 @@ def __init__(
if "wandb" in backends:
import wandb

wandb.init(project=project_name, name=experiment_name, config=get_config_as_dict(config))
self.logger["wandb"] = wandb
current_node_ip = None
if ray.is_initialized():
try:
current_node_ip = ray.util.get_node_ip_address()
except Exception as e:
logger.warning(f"Failed to get node IP address. Error: {e}. " "Skipping multi-node wandb logging.")

run = wandb.init(
project=project_name,
name=experiment_name,
config=get_config_as_dict(config),
group=experiment_name,
resume="allow",
settings=wandb.Settings(
mode="shared",
x_primary=True,
x_label=f"node-{current_node_ip or 'head'}",
),
)
run_id = run.id
self.logger["wandb"] = run
self._prepare_worker_nodes_systems_logging_wandb(
project_name, experiment_name, run_id, config, current_node_ip
)

if "mlflow" in backends:
self.logger["mlflow"] = _MlflowLoggingAdapter(project_name, experiment_name, config)
Expand Down Expand Up @@ -81,6 +140,58 @@ def __init__(
self.console_logger = ConsoleLogger()
self.logger["console"] = self.console_logger

def _prepare_worker_nodes_systems_logging_wandb(
    self, project_name, experiment_name, run_id, config, current_node_ip
):
    """
    Spawn a ``WandbNodeLogger`` actor on every alive worker node so that system
    metrics (e.g. GPU utilization) from all nodes are aggregated into the single
    shared wandb run identified by ``run_id``.

    The head node (``current_node_ip``) is skipped — it already logs through the
    primary run. Failures are logged and swallowed: monitoring must never take
    down training.
    """
    self.remote_loggers = []

    # Guard clauses: without a known head IP or a live Ray cluster there is
    # nothing sensible to schedule.
    if current_node_ip is None:
        logger.warning("Node IP unknown, skipping multi-node wandb logging")
        return
    if not ray.is_initialized():
        logger.warning("Ray is not initialized, skipping distributed wandb logging")
        return

    try:
        alive_nodes = (n for n in ray.nodes() if n["Alive"])
        for node_info in alive_nodes:
            node_ip = node_info["NodeManagerAddress"]
            if node_ip == current_node_ip:
                continue  # head node is covered by the primary wandb run

            try:
                # Pin the actor to this exact node (soft=False) so the system
                # metrics it reports really come from that machine.
                placement = NodeAffinitySchedulingStrategy(
                    node_id=node_info["NodeID"],
                    soft=False,
                )
                actor_handle = WandbNodeLogger.options(
                    num_cpus=0.1,
                    scheduling_strategy=placement,
                ).remote(
                    project_name=project_name,
                    experiment_name=experiment_name,
                    config=get_config_as_dict(config),
                    run_id=run_id,
                    group_name=experiment_name,
                    x_label=f"node-{node_ip}",
                )
            except Exception as e:
                logger.warning(f"Failed to spawn WandbNodeLogger on {node_ip}: {e}")
            else:
                self.remote_loggers.append(actor_handle)
                logger.info(f"WandbNodeLogger initialized on 'node-{node_ip}'")

    except Exception as e:
        logger.warning(f"Failed to setup distributed wandb logging: {e}")

def log(self, data, step, commit=False):
for logger_name, logger_instance in self.logger.items():
if logger_name == "wandb":
Expand All @@ -89,6 +200,13 @@ def log(self, data, step, commit=False):
logger_instance.log(data=data, step=step)

def finish(self):
# Finish remote wandb loggers on worker nodes first
for remote_logger in getattr(self, "remote_loggers", []):
try:
ray.get(remote_logger.finish_run.remote(), timeout=10)
except Exception as e:
logger.warning(f"Failed to finish remote WandbNodeLogger: {e}")

for logger_name, logger_instance in self.logger.items():
# NOTE (sumanthrh): We use a try-except block here while finishing tracking.
# This is because wandb often errors out with a BrokenPipeError when closing.
Expand Down
Loading