
Conversation

@kohankhaki
Collaborator

This pull request updates the distributed training example for DDP with Submitit, clarifying how environment variables are set and refactoring the training script. The changes ensure that distributed environment variables are properly initialized, configuration is cleaner, and logging/checkpointing are more robust.

Distributed Training Initialization and Environment Setup:

  • Added a detailed explanation in README.md about how Submitit does not automatically set PyTorch DDP environment variables, and clarified the standard pattern for initializing distributed environments using submitit.JobEnvironment() for both single-node and multi-node jobs.
  • Refactored the training script (train.py) to modularize distributed environment setup: added methods for wrapping models with DDP, configuring training, extracting distributed config from Submitit, preparing the environment (setting RANK, LOCAL_RANK, WORLD_SIZE, MASTER_ADDR, MASTER_PORT), logging run configuration, and setting seeds (a minimal sketch of this pattern is shown below).
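
A minimal sketch of the environment-setup pattern described above, based on the submitit JobEnvironment API. The helper name _setup_distributed_env and its exact structure are illustrative, not necessarily the code in train.py:

import os
import submitit

def _setup_distributed_env(default_port: str = "29500") -> None:
    # Submitit has already determined the job layout; expose it to
    # torch.distributed through the standard environment variables.
    job_env = submitit.JobEnvironment()
    os.environ["RANK"] = str(job_env.global_rank)
    os.environ["LOCAL_RANK"] = str(job_env.local_rank)
    os.environ["WORLD_SIZE"] = str(job_env.num_tasks)
    # The first node in the allocation acts as the rendezvous host.
    os.environ.setdefault("MASTER_ADDR", job_env.hostnames[0])
    os.environ.setdefault("MASTER_PORT", default_port)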

Configuration and Resource Allocation:

  • Updated the default for tasks_per_node in _global.yaml to select from compute.tasks_per_node or fall back to compute.gpus_per_node, improving resource allocation flexibility for distributed jobs.

Training Script Improvements:

  • Refactored the main training loop to use the new helper methods for configuration, environment setup, logging, and checkpointing. This improves readability, maintainability, and correctness when running under Submitit.
  • Fixed average loss calculation to account for all processes in DDP, and improved checkpointing logic to ensure synchronization and correct saving in distributed settings (a sketch of the loss-averaging fix is shown below).
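
A sketch of the loss-averaging fix, under the assumption that each rank tracks a local loss sum and sample count. The helper name average_loss_across_ranks is hypothetical, not necessarily what train.py uses:

import torch
import torch.distributed as dist

def average_loss_across_ranks(loss_sum: float, num_samples: int, device: torch.device) -> float:
    # Sum the local loss totals and sample counts over every DDP process,
    # then divide, so the reported average reflects all processes rather
    # than only rank 0.
    totals = torch.tensor([loss_sum, float(num_samples)], device=device)
    dist.all_reduce(totals, op=dist.ReduceOp.SUM)
    return (totals[0] / totals[1]).item()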

@kohankhaki marked this pull request as draft October 15, 2025 21:01
@kohankhaki requested a review from scarere October 15, 2025 21:01
@kohankhaki marked this pull request as ready for review October 16, 2025 17:38
Contributor

@scarere left a comment


Overall looks good and seems to be working. Mostly minor comments about adding more documentation and logging, as well as switching to TorchDistributedEnvironment to help simplify the script. An example can be found here.

nodes: ${oc.select:compute.nodes,null}
gpus_per_node: ${oc.select:compute.slurm.gpus_per_node, ${compute.gpus_per_node}}
tasks_per_node: 1
tasks_per_node: ${oc.select:compute.tasks_per_node, ${compute.gpus_per_node}}
Contributor


The preset compute configs request way too many CPUs. For example, killarney/l40s_2x requests 64 CPUs per task. However, there are 2 GPUs, hence 2 tasks, hence 128 CPUs requested, and the L40 nodes only have 64 CPUs total. We should modify all configs such that cpus_per_task = total_cpus / num_gpus_on_node, so a job requesting N GPUs gets total_cpus * N / num_gpus_on_node CPUs overall. We should also apply similar scaling to mem_gb. In general, if we are requesting half the GPUs on a node, then we should use half of all the other resources (memory, CPUs, etc.).

Contributor


So for each L40S GPU we should request 16 CPUs and 128 GB of memory. Now, thinking about this, more than 128 GB might be unnecessary, even if there is room for it. I'll leave it up to you whether to scale memory as well or just leave it fixed.

Collaborator Author


Thanks for catching that, you’re right. I’ve updated all compute presets to scale cpus_per_task and mem_gb proportionally to the number of GPUs requested per node.
For example, Killarney L40S nodes have 64 CPUs and 512 GB total memory across 4 GPUs, so each GPU now requests 16 CPUs and 128 GB.

Contributor


Is there a way to fall back to 1 if compute.gpus_per_node is also not specified? Say, on a CPU-only compute config? Maybe that will never happen, though.
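
One possible pattern (shown with OmegaConf in Python; a sketch of the idea, not necessarily how the config ended up being written) is to nest a second oc.select as the default, so resolution falls through compute.tasks_per_node, then compute.gpus_per_node, then 1:

from omegaconf import OmegaConf

# Hypothetical minimal config: neither compute.tasks_per_node nor
# compute.gpus_per_node is defined, so both selects fall through to 1.
cfg = OmegaConf.create(
    {
        "compute": {},
        "tasks_per_node": "${oc.select:compute.tasks_per_node, ${oc.select:compute.gpus_per_node, 1}}",
    }
)
print(cfg.tasks_per_node)  # -> 1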

Unlike `torchrun`, Submitit is a **job scheduler integration**, not a distributed orchestrator. It spawns one process per GPU (or per `tasks_per_node`), but it does **not automatically set** the PyTorch environment variables (`RANK`, `LOCAL_RANK`, `WORLD_SIZE`, `MASTER_ADDR`, `MASTER_PORT`) required by `torch.distributed`.

**Rank:** Integer ID for a single GPU, unique across all nodes (from `0` to `world_size - 1`).
Therefore, this project explicitly initializes the distributed environment inside the training script using `submitit.JobEnvironment()`.
Contributor


Maybe elaborate a little bit more here. Although submitit doesn't set the variables automatically, it does automatically determine the world size, rank, and local rank, as well as a bunch of other useful environment variables. The user doesn't actually have to set the local rank for each GPU manually; they just need to retrieve the environment from JobEnvironment. Additionally, if we switch to using submitit.helpers.TorchDistributedEnvironment, then we should document that here. It's worth explaining that the latter is essentially an extension of the former; JobEnvironment has additional info that might be useful in other unique/custom cases, so it's good to make users aware of that as well.
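
For context, a rough sketch of how submitit.helpers.TorchDistributedEnvironment would replace the manual setup. Attribute and method names follow the submitit helpers API as documented, but are worth double-checking against the installed version; this is a sketch, not the code in this PR:

import torch.distributed as dist
from submitit.helpers import TorchDistributedEnvironment

def init_distributed() -> None:
    # export() fills in the standard torch.distributed environment variables
    # (MASTER_ADDR, MASTER_PORT, RANK, WORLD_SIZE, ...) from the Slurm job.
    dist_env = TorchDistributedEnvironment().export()
    dist.init_process_group(
        backend="nccl",
        world_size=dist_env.world_size,
        rank=dist_env.rank,
    )
    # Per-rank confirmation goes to that rank's stdout rather than the Hydra log.
    print(
        f"[Rank {dist_env.rank}/{dist_env.world_size}] "
        f"local_rank={dist_env.local_rank}, "
        f"master={dist_env.master_addr}:{dist_env.master_port}"
    )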


os.environ.setdefault("MASTER_PORT", "29500")

def _log_run_configuration(self, seed, world_size, local_rank, rank):
Contributor


If you log for all ranks, will the hydra log file contain duplicated logs (one for each rank)? If so, maybe you can log only on rank 0 but print to stdout on the other ranks? I think this would give better visibility into what's going on. It might also be a good way to teach users to use log and print differently. We can add documentation saying that log writes the output to the "global" hydra log for the run as well as the submitit stdout, while print keeps it out of the "global" hydra log and sends it only to the stdout specific to the process (in this case the rank/GPU).

Collaborator Author


Thanks, agreed. Logging from all ranks would duplicate entries in the Hydra log. I’ve already kept logger calls restricted to rank 0 so the global Hydra log stays clean. I added a short note in the README explaining that logger writes to the global Hydra log while print() can be used for rank-local stdout visibility on other ranks. I also added print() statements during DDP initialization to give per-rank visibility.

Contributor


Just an idea, what if on initialization you had something like:

if rank == 0:
    self.log_fn = logger.info
else:
    self.log_fn = print

Then use self.log_fn throughout the rest of the script. Only rank 0 logs will be sent to hydra; all other ranks will just print to stdout. The downside is that for the most part the logs will be identical; the upside is greater visibility into what's going on when debugging specific ranks. If you think that's overkill, however, I'm happy to keep the current solution of an initial print statement confirming the rank was initialized.

Contributor

@scarere left a comment


Couple final remarks. Feel free to address what you think is relevant and merge when you are done.


if rank == 0:
    logger.info(f"[Rank {rank}] Initialized on device: {device}")
else:
    print(f"[Rank {rank}] Initialized on device: {device}")
Contributor


The if rank == 0 conditional happens basically twice in a row here. You are also already stating which device is being used for rank 0. I think you can remove the second if statement.

Collaborator Author


Fixed this one. I also addressed your earlier comment about the fallback for tasks_per_node.
As for the log_fn idea: I like this suggestion; it's a neat, low-friction way to get per-rank visibility when you need it. My only concern is noise: for normal runs, the duplicated logs from every rank can overwhelm the hydra logs and make analysis harder.

@kohankhaki merged commit e82c6da into main Oct 23, 2025
@kohankhaki deleted the fix-mlp-ddp branch October 23, 2025 00:11