Merged
3 changes: 1 addition & 2 deletions templates/configs/_global.yaml
@@ -21,7 +21,7 @@ hydra:
submitit_folder: ${hydra.sweep.dir}/submitit_logs/%j
nodes: ${oc.select:compute.nodes,null}
gpus_per_node: ${oc.select:compute.slurm.gpus_per_node, ${compute.gpus_per_node}}
tasks_per_node: 1
tasks_per_node: ${oc.select:compute.tasks_per_node, ${compute.gpus_per_node}}
Contributor

The preset compute configs request way too many CPUs. For example, killarney/l40s_2x requests 64 CPUs per task. However, there are 2 GPUs, hence 2 tasks, hence 128 CPUs requested. The L40 nodes only have 64 CPUs total. We should modify all configs such that cpus_per_task = total_cpus * num_requested_gpus / num_gpus_on_node. We should also apply a similar scaling to mem_gb. In general, if we are requesting half the GPUs on a node, then we should use half of all the other resources (memory, CPUs, etc.).
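As a quick sanity check of that scaling (illustrative numbers only; the node totals are the Killarney L40S figures quoted later in this thread):

```python
# Illustrative only: 64 CPUs, 512 GB and 4 GPUs per L40S node; the 2-GPU
# request mirrors an l40s_2x-style preset.
total_cpus, total_mem_gb, gpus_on_node = 64, 512, 4
requested_gpus = 2

# Give the job the same fraction of the node as the fraction of GPUs it requests.
job_cpus = total_cpus * requested_gpus // gpus_on_node      # 32
job_mem_gb = total_mem_gb * requested_gpus // gpus_on_node  # 256

# With one task per requested GPU (tasks_per_node == requested_gpus):
cpus_per_task = job_cpus // requested_gpus                  # 16
print(job_cpus, job_mem_gb, cpus_per_task)
```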

Contributor

So for each L40 we should request 16 CPUs and 128 GB of memory. Now that I think about it, more than 128 GB might be unnecessary, even if there is room for it. I'll leave it up to you whether to scale memory as well or just leave it fixed.

Collaborator Author

Thanks for catching that, you’re right. I’ve updated all compute presets to scale cpus_per_task and mem_gb proportionally to the number of GPUs requested per node.
For example, Killarney L40S nodes have 64 CPUs and 512 GB total memory across 4 GPUs, so each GPU now requests 16 CPUs and 128 GB.

Contributor

Is there a way to fall back to 1 if compute.gpus_per_node is also not specified, say on a CPU-only compute config? Maybe that will never happen, though.
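One possible shape for that fallback, as an untested sketch (it assumes OmegaConf resolves a nested `oc.select` inside the default argument; the key names mirror the config above):

```python
from omegaconf import OmegaConf

# Untested sketch: chain oc.select so tasks_per_node falls back to
# compute.gpus_per_node, and finally to 1 if neither key is defined.
cfg = OmegaConf.create(
    {
        "compute": {},  # no tasks_per_node and no gpus_per_node
        "tasks_per_node": (
            "${oc.select:compute.tasks_per_node,"
            "${oc.select:compute.gpus_per_node,1}}"
        ),
    }
)
print(cfg.tasks_per_node)  # expected to resolve to 1
```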

cpus_per_task: ${compute.cpus_per_task}
mem_gb: ${compute.mem_gb}
timeout_min: ${compute.timeout_min}
@@ -31,4 +31,3 @@ hydra:
account: ${user.slurm.account}
max_num_timeout: 2
additional_parameters: ${oc.select:user.slurm.additional_parameters, {}}

27 changes: 19 additions & 8 deletions templates/src/mlp/ddp/README.md
@@ -1,17 +1,28 @@
# Distributed Data Parallel Example
# MLP Distributed Data Parallel Template

> :warning: WIP: This template is a work in progress and does not use DDP in its current state.
*Data Parallelism* lets you split your data across multiple accelerators so that you can train your model faster!

*Data Parallelism* lets you split your data across multiple accelerators so that you can train your model faster!

Most of the time all your accelerators (gpus) will be on the same machine (node), and that simplifies things. However if you are using a large number of gpus that can't fit on a single machine, then you'll have to use multiple machines (nodes). For example, on the Killarney cluster, L40's have a maximum of 4 per node and H100's have a maximum of 8 per nodes. Data Parallelism across multiple nodes is referred to as *Distributed Data Parallelism* (DDP). By default DDP works for both single node and multi-node settings.
Most of the time all your accelerators (GPUs) will be on the same machine (node), and that simplifies things. However, if you are using a large number of GPUs that can't fit on a single machine, then you'll have to use multiple machines (nodes). For example, on the Killarney cluster, L40s have a maximum of 4 per node and H100s have a maximum of 8 per node. Data Parallelism across multiple nodes is referred to as *Distributed Data Parallelism* (DDP). By default, DDP works for both single-node and multi-node settings.

This example implements a simple MLP using DDP.

## DDP Background

**World Size:** The total number of gpu's across all nodes
**World Size:** The total number of GPUs across all nodes.

**Rank:** Integer ID for a single GPU. Unique across all nodes. (from `0` to `world_size - 1`)

**Local Rank:** Integer ID for a single GPU. Unique only within a node. (from `0` to `num_gpus_per_node - 1`)
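A tiny worked example of how these three quantities relate, using an illustrative 2-node, 4-GPU-per-node layout:

```python
# Illustration only: 2 nodes with 4 GPUs each.
num_nodes, gpus_per_node = 2, 4
world_size = num_nodes * gpus_per_node        # 8

# The 3rd GPU (local index 2) on the 2nd node (node index 1):
node_idx, local_rank = 1, 2
rank = node_idx * gpus_per_node + local_rank  # 6
print(world_size, rank, local_rank)
```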

## DDP Setup

Unlike `torchrun`, Submitit is a **job scheduler integration**, not a distributed orchestrator. It spawns one process per GPU (or per `tasks_per_node`), but it does **not automatically set** the PyTorch environment variables (`RANK`, `LOCAL_RANK`, `WORLD_SIZE`, `MASTER_ADDR`, `MASTER_PORT`) required by `torch.distributed`.

**Rank:** Integer ID for a single gpu. Unique across all nodes. (from `0` to `world_size - 1`)
Therefore, this project explicitly initializes the distributed environment inside the training script using `submitit.JobEnvironment()`.
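In outline, that initialization looks roughly like the following (a condensed sketch of what the training script does, not the full code):

```python
import submitit

# Condensed sketch: JobEnvironment exposes the SLURM-derived values that the
# training script then maps onto the torch.distributed environment variables.
job_env = submitit.JobEnvironment()
rank = job_env.global_rank       # unique across all nodes
local_rank = job_env.local_rank  # unique within this node
world_size = job_env.num_tasks   # total processes across all nodes
print(f"rank={rank} local_rank={local_rank} world_size={world_size}")
```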
Contributor

Maybe elaborate a little bit more here. Although Submitit doesn't set the variables automatically, it does automatically determine the world size, rank, and local rank, as well as a bunch of other useful environment values. The user doesn't have to manually set the local rank for each GPU; they just need to retrieve the environment from JobEnvironment. Additionally, if we switch to using submitit.helpers.TorchDistributedEnvironment then we should document that here. It's worth explaining that the latter is essentially an extension of the former. JobEnvironment has additional info that might be useful in other unique/custom cases, so it's good to make users aware of that as well.
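For reference, a minimal untested sketch of the `submitit.helpers.TorchDistributedEnvironment` approach mentioned above (API names are taken from the Submitit helpers; worth double-checking against the installed version):

```python
import submitit
import torch.distributed as dist

# Untested sketch: export() sets MASTER_ADDR, MASTER_PORT, RANK, LOCAL_RANK and
# WORLD_SIZE from the SLURM job, so torch.distributed can use the env:// method.
dist_env = submitit.helpers.TorchDistributedEnvironment().export()
dist.init_process_group(backend="nccl", init_method="env://")
print(f"rank={dist_env.rank} local_rank={dist_env.local_rank} "
      f"world_size={dist_env.world_size}")
```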

This pattern is the standard way to perform DDP initialization with Submitit when not using `torchrun`
([MosaicML Docs](https://docs.mosaicml.com/projects/composer/en/stable/examples/training_with_submitit.html),
[Hydra Submitit Launcher](https://hydra.cc/docs/plugins/submitit_launcher/),
[PyTorch Forum Discussion](https://discuss.pytorch.org/t/using-submitit-for-distributed-training/121881),
[Fairseq Example](https://github.com/facebookresearch/fairseq/blob/main/examples/language_model/submitit_train.py)).

**Local Rank:** Integer ID for a single gpu. Unique only within a node. (from `0` to `num_gpus_per_node - 1`)
It works for both **single-node** and **multi-node** jobs as long as the `MASTER_ADDR` points to a hostname reachable from all nodes.
110 changes: 64 additions & 46 deletions templates/src/mlp/ddp/train.py
@@ -1,14 +1,14 @@
"""Distributed MLP training using PyTorch DDP."""

import os
import logging
import os

import submitit
import torch
import torch.distributed as dist
from omegaconf import DictConfig, OmegaConf
from torch import nn, optim
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset
from omegaconf import DictConfig, OmegaConf


logger = logging.getLogger(__name__)
@@ -79,6 +79,52 @@ def _setup_distributed(self, rank, world_size):
world_size=world_size,
)

def _wrap_distributed(self, model, world_size, local_rank):
if world_size > 1:
return nn.parallel.DistributedDataParallel(
model,
device_ids=[local_rank] if torch.cuda.is_available() else None,
)
return model

def _configure_training(self, cfg):
lr = OmegaConf.select(cfg, "trainer.learning_rate", default=1e-3)
num_epochs = OmegaConf.select(cfg, "trainer.num_epochs", default=1000)
seed = OmegaConf.select(cfg, "trainer.seed", default=42)
return lr, num_epochs, seed

def _get_distributed_config(self):
job_env = submitit.JobEnvironment()
return job_env, job_env.global_rank, job_env.local_rank, job_env.num_tasks

def _prepare_environment(self, job_env, rank, local_rank, world_size):
os.environ.setdefault("RANK", str(rank))
os.environ.setdefault("LOCAL_RANK", str(local_rank))
os.environ.setdefault("WORLD_SIZE", str(world_size))

if "MASTER_ADDR" not in os.environ:
master_addr = (
job_env.hostnames[0]
if hasattr(job_env, "hostnames")
else job_env.hostname
)
os.environ["MASTER_ADDR"] = str(master_addr)

os.environ.setdefault("MASTER_PORT", "29500")

def _log_run_configuration(self, seed, world_size, local_rank, rank):
Contributor

If you log for all ranks, will the Hydra log file contain duplicated logs (one for each rank)? If so, maybe you can log only on rank 0, but print to stdout on the other ranks? I think this would give better visibility into what's going on. It might also be a good way to teach users to use log and print differently. We can add documentation saying that log sends the output to the "global" Hydra log for the run and the Submitit stdout, while print keeps it out of the "global" Hydra log and writes it only to the stdout that is specific to the process (in this case the rank/GPU).

Collaborator Author

Thanks, agreed. Logging from all ranks would duplicate entries in the Hydra log. I’ve already kept logger calls restricted to rank 0 so the global Hydra log stays clean. I added a short note in the README explaining that logger writes to the global Hydra log while print() can be used for rank-local stdout visibility on other ranks. I also added print() statements during DDP initialization to give per-rank visibility.

Contributor

Just an idea, what if on initialization you had something like:

if rank == 0:
    self.log_fn = logger.info
else:
    self.log_fn = print

Then use self.log_fn throughout the rest of the script. Only rank 0 logs will be sent to Hydra; all other ranks will just print to stdout. The downside is that for the most part the logs will be identical; the upside is greater visibility into what's going on when debugging specific ranks. If you think that's overkill, however, I'm happy to keep the current solution of an initial print statement confirming the rank was initialized.

if rank != 0:
return
logger.info(f"Starting DDP MLP training with seed {seed}")
logger.info(f"World size: {world_size}, Local rank: {local_rank}")
if torch.cuda.is_available():
logger.info(f"Number of available GPUs: {torch.cuda.device_count()}")

def _set_seed(self, seed):
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)

def _initialize_device_and_model(self, cfg, local_rank):
"""Initialize device and model."""
input_dim = OmegaConf.select(cfg, "trainer.input_dim", default=10)
@@ -109,7 +155,7 @@ def _initialize_data_and_loader(self, cfg, world_size, rank):
num_classes = OmegaConf.select(cfg, "trainer.num_classes", default=3)
batch_size = OmegaConf.select(cfg, "trainer.batch_size", default=32)

dataset = create_dummy_data(1000, input_dim, num_classes)
dataset = create_dummy_data(100000, input_dim, num_classes)
sampler = (
DistributedSampler(
dataset, num_replicas=world_size, rank=rank, shuffle=True
@@ -157,7 +203,6 @@ def _train_epoch(
device,
epoch,
world_size,
rank,
):
"""Train for one epoch and return metrics."""
# Set epoch for DistributedSampler to ensure proper shuffling across epochs
@@ -193,64 +238,35 @@ def __call__(self, cfg):

def __call__(self, cfg):
"""Train the MLP model with DDP."""
cfg : DictConfig = OmegaConf.create(cfg) # Ensure cfg is a DictConfig
cfg: DictConfig = OmegaConf.create(cfg)

# Create output directory
out_dir = cfg.paths.out_dir
os.makedirs(out_dir, exist_ok=True)

# Get ckpt dir
self.ckpt_dir = self._latest_checkpoint(out_dir)

# Configuration
lr = OmegaConf.select(cfg, "trainer.learning_rate", default=1e-3)
num_epochs = OmegaConf.select(cfg, "trainer.num_epochs", default=1000)
seed = OmegaConf.select(cfg, "trainer.seed", default=42)
lr, num_epochs, seed = self._configure_training(cfg)
job_env, rank, local_rank, world_size = self._get_distributed_config()

# Get distributed training info from environment
# TODO: None of these env vars are actually set at the moment. Need to fix this example.
rank = int(os.environ.get("RANK", "0"))
local_rank = int(os.environ.get("LOCAL_RANK", "0"))
world_size = int(os.environ.get("WORLD_SIZE", "1"))
self._prepare_environment(job_env, rank, local_rank, world_size)
self._set_seed(seed)
self._log_run_configuration(seed, world_size, local_rank, rank)

if rank == 0:
logger.info(f"Starting DDP MLP training with seed {seed}")
logger.info(f"World size: {world_size}, Local rank: {local_rank}")

# Set seed for reproducibility (same seed on all processes)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
logger.info(f"Number of available GPUs: {torch.cuda.device_count()}")

# Setup distributed training
self._setup_distributed(rank, world_size)

# Setup device and model
device, model = self._initialize_device_and_model(cfg, local_rank)

if rank == 0:
logger.info(f"Using device: {device}")

# Wrap model with DDP
if world_size > 1:
model = nn.parallel.DistributedDataParallel(
model,
device_ids=[local_rank] if torch.cuda.is_available() else None,
)
model = self._wrap_distributed(model, world_size, local_rank)

# Setup data and training
loader, sampler = self._initialize_data_and_loader(cfg, world_size, rank)
optimizer = optim.Adam(model.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss()

# Resume from checkpoint if available
start_epoch = self._load_checkpoint_if_exists(model, optimizer, device, rank)

if rank == 0:
logger.info(f"Training from epoch {start_epoch} to {num_epochs}...")

# Training loop with DDP
for epoch in range(start_epoch, num_epochs):
loss_sum, correct, total = self._train_epoch(
model,
@@ -261,18 +277,20 @@ def __call__(self, cfg):
device,
epoch,
world_size,
rank,
)

avg_loss = loss_sum / (len(loader) * world_size)
acc = 100.0 * correct / total
should_checkpoint = epoch % 100 == 0 or epoch == num_epochs - 1

# Log metrics only on rank 0
if rank == 0:
acc = 100.0 * correct / total
avg_loss = loss_sum / len(loader)
logger.info(f"Epoch {epoch}: loss={avg_loss:.4f} acc={acc:.2f}%")

if epoch % 100 == 0 or epoch == num_epochs - 1:
if world_size > 1:
dist.barrier()
if should_checkpoint:
if world_size > 1:
dist.barrier()
if rank == 0:
self._save_checkpoint(
model, optimizer, epoch, out_dir, avg_loss, acc, rank
)