
Commit 6b1c431

Add DCP load benchmarking support (#357)
Extended the distributed checkpointing benchmarks to support both save and load operations for DDP and FSDP training scenarios. Previously only checkpoint saving was supported.

The implementation adds load benchmark scripts for the DDP and FSDP scenarios, extends the run scripts with --load/--save flags for operation selection, and creates dedicated load configuration files with a checkpoint suffix parameter. The common utilities now include reader support functions while maintaining backward compatibility with existing save operations.

This lets users benchmark checkpoint loading performance alongside the existing save benchmarks, providing a complete picture of checkpoint operation performance.
1 parent 7337692 commit 6b1c431

File tree

16 files changed, +404 -30 lines changed


CHANGELOG.md

Lines changed: 2 additions & 1 deletion
@@ -6,7 +6,8 @@
 * Fix SequentialS3Reader seek beyond EOF to clamp position to object size (#362)

 ### Other changes
-* Added thread_count parameter to S3StorageWriter
+* Add benchmark to run DCP Loading Workloads (#357)
+* Add thread_count parameter to S3StorageWriter (#370)

 ## v1.4.3 (July 25, 2025)

s3torchbenchmarking/README.md

Lines changed: 9 additions & 3 deletions
@@ -112,10 +112,16 @@ vim ./conf/lightning_checkpointing.yaml # 1. edit config
 ./utils/run_lightning_benchmarks.sh # 2. run scenario

 # PyTorch’s Distributed Checkpointing (DCP) benchmarks
-vim ./conf/dcp_ddp.yaml # 1. edit config
-vim ./conf/dcp_fsdp.yaml
-./utils/run_dcp_ddp_benchmarks.sh # 2. run scenario
+vim ./conf/dcp_ddp_load.yaml # 1. edit config
+vim ./conf/dcp_fsdp_load.yaml
+vim ./conf/dcp_ddp_save.yaml
+vim ./conf/dcp_fsdp_save.yaml
+# Saving Checkpoint
+./utils/run_dcp_ddp_benchmarks.sh # 2. run scenario for saving checkpoint
 ./utils/run_dcp_fsdp_benchmarks.sh
+# Loading Checkpoint
+./utils/run_dcp_ddp_benchmarks.sh --load # 3. run scenario for loading checkpoint after saving
+./utils/run_dcp_fsdp_benchmarks.sh --load
 ```

 > [!NOTE]

s3torchbenchmarking/conf/dcp_ddp_load.yaml

Lines changed: 31 additions & 0 deletions

@@ -0,0 +1,31 @@
defaults:
  - hydra/callbacks/collate_results
  - aws/dynamodb # save run results to DynamoDB -- comment me if not required
  - _self_

# S3 bucket to use to save checkpoints.
# NOTE: a non-existing bucket will fail the benchmarks.
s3:
  region: ??? # e.g., eu-west-1
  uri: ??? # e.g., s3://my-bucket/
# Number of iterations for "saving" a model's checkpoint.
# NOTE: this does not affect model training, as no actual training occurs in these benchmarks.
epochs: 4

hydra:
  mode: MULTIRUN
  sweep:
    dir: multirun/${hydra.job.config_name}/${now:%Y-%m-%d_%H-%M-%S}
  sweeper:
    params:
      # Short name of a pre-trained model (from Hugging Face), listed in `models.py`.
      +model: ???
      # Type of Torch distributed backend (valid options: "nccl", "gloo").
      +backend: nccl
      # Number of workers.
      +world_size: 8
      # Checkpoint storage location (valid options: "disk", "s3").
      +checkpoint.storage: disk, s3
      # Checkpoint storage suffix location generated by save benchmarks, e.g., 2025-09-23-11-05-zmuZ/
      +checkpoint.suffix: ???
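The `checkpoint.suffix` value ties a load run to the output of a previous save run. Below is a minimal sketch of how the suffix is expected to combine with `s3.uri` via the `build_checkpoint_uri` helper that `dcp_common.py` imports; the exact joining behaviour is an assumption here, not shown in this diff.

```python
# Sketch only: how checkpoint.suffix is expected to resolve into a checkpoint
# location. The exact behaviour of build_checkpoint_uri is assumed, not shown here.
from s3torchbenchmarking.benchmark_utils import build_checkpoint_uri

s3_uri = "s3://my-bucket/"          # from `s3.uri` in dcp_ddp_load.yaml
suffix = "2025-09-23-11-05-zmuZ/"   # from `checkpoint.suffix`, produced by a save run

# Expected to yield something like "s3://my-bucket/2025-09-23-11-05-zmuZ/".
print(build_checkpoint_uri(s3_uri, suffix))
```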
s3torchbenchmarking/conf/dcp_ddp.yaml renamed to s3torchbenchmarking/conf/dcp_ddp_save.yaml

Lines changed: 4 additions & 3 deletions
@@ -23,8 +23,9 @@ hydra:
       # Type of Torch distributed backend (valid options: "nccl", "gloo").
       +backend: nccl
       # Number of workers.
-      +world_size: 4
+      +world_size: 8
       # Number of threads to use for saving the checkpoints.
-      +thread_count: 4
+      +thread_count: 8
       # Checkpoint storage location (valid options: "disk", "s3").
-      +checkpoint.storage: disk, s3
+      +checkpoint.storage: disk, s3
+
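As context for the `thread_count` sweep above: the save path hands this value to the checkpoint writer. The following is a hypothetical sketch only; it assumes `S3StorageWriter` mirrors `S3StorageReader`'s `(region, uri)` positional arguments and accepts the `thread_count` keyword referenced in #370, since `get_writer` itself is not part of this diff.

```python
# Hypothetical sketch -- not code from this commit. Assumes S3StorageWriter takes
# (region, uri) positionally, like S3StorageReader, plus the thread_count keyword
# added in #370.
from s3torchconnector.dcp import S3StorageWriter


def make_s3_writer(region: str, uri: str, thread_count: int) -> S3StorageWriter:
    # A higher thread_count increases upload parallelism when saving large checkpoints.
    return S3StorageWriter(region, uri, thread_count=thread_count)
```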

s3torchbenchmarking/conf/dcp_fsdp_load.yaml

Lines changed: 38 additions & 0 deletions

@@ -0,0 +1,38 @@
defaults:
  - hydra/callbacks/collate_results
  - aws/dynamodb # save run results to DynamoDB -- comment me if not required
  - _self_

# S3 bucket to use to save checkpoints.
# NOTE: a non-existing bucket will fail the benchmarks.
s3:
  region: ??? # e.g., eu-west-1
  uri: ??? # e.g., s3://my-bucket/
# Number of iterations for "saving" a model's checkpoint.
# NOTE: this does not affect model training, as no actual training occurs in these benchmarks.
epochs: 4

hydra:
  mode: MULTIRUN
  sweep:
    dir: multirun/${hydra.job.config_name}/${now:%Y-%m-%d_%H-%M-%S}
  sweeper:
    params:
      # Short name of a pre-trained llama v2 model (valid options: "L7b", "L13b", "L30b", "L65b", "L70b").
      +model: ???
      # Type of Torch distributed backend (valid options: "nccl", "gloo").
      +backend: nccl
      # Number of workers.
      +world_size: 8
      # Checkpoint storage location (valid options: "disk", "s3").
      +checkpoint.storage: disk, s3
      # Sharding strategy (valid options: "full", "hybrid").
      +checkpoint.sharding_strategy: full
      # Controls whether files are forcibly synced to disk (only relevant for "disk" storage).
      # NOTE: We disabled this option to improve performance since FSDP checkpointing with
      # forced syncing (maximum durability) was significantly slower than storage throughput.
      # This setting has no effect when using "s3" storage.
      +checkpoint.sync_files: false
      # Checkpoint storage suffix location generated by save benchmarks, e.g., 2025-09-23-11-05-zmuZ/
      +checkpoint.suffix: ???

s3torchbenchmarking/src/s3torchbenchmarking/dcp_common.py

Lines changed: 17 additions & 2 deletions
@@ -14,13 +14,14 @@
 import torch.distributed.checkpoint as dcp
 from omegaconf import DictConfig
 from torch import multiprocessing as mp
-from torch.distributed.checkpoint import FileSystemWriter
+from torch.distributed.checkpoint import FileSystemWriter, FileSystemReader
+

 from s3torchbenchmarking.benchmark_utils import (
     build_random_suffix,
     build_checkpoint_uri,
 )
-from s3torchconnector.dcp import S3StorageWriter
+from s3torchconnector.dcp import S3StorageWriter, S3StorageReader

 Timestamps = Tuple[float, float]
 logger = logging.getLogger(__name__)

@@ -49,6 +50,20 @@ def get_writer(cfg: DictConfig, suffix: str) -> FileSystemWriter:
     raise ValueError(f"Storage writer {cfg.checkpoint.storage} not supported")


+def get_reader(cfg: DictConfig) -> FileSystemReader:
+    """Instantiate a checkpoint reader based on the input config."""
+    suffix = cfg.checkpoint.suffix
+    if cfg.checkpoint.storage == "disk":
+        local_path = Path(cfg.path) / suffix
+        logger.info("Loading checkpoint from %s (disk)...", local_path)
+        return dcp.FileSystemReader(local_path)
+    elif cfg.checkpoint.storage == "s3":
+        uri = build_checkpoint_uri(cfg.s3.uri, suffix)
+        logger.info("Loading checkpoint from %s (S3)...", uri)
+        return S3StorageReader(cfg.s3.region, uri)
+    raise ValueError(f"Storage reader {cfg.checkpoint.storage} not supported")
+
+
 def benchmark_common_runner(
     cfg: DictConfig,
     run_fn,
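A minimal usage sketch (not part of the diff) of how the new `get_reader` helper pairs with `dcp.load`, assuming `cfg` is a Hydra config resolved from one of the load configuration files above:

```python
# Minimal sketch: wiring get_reader into a load, assuming `cfg` comes from
# dcp_ddp_load.yaml (checkpoint.storage and checkpoint.suffix must be set).
import torch.distributed.checkpoint as dcp

from s3torchbenchmarking.dcp_common import get_reader


def load_into(model, cfg):
    # get_reader returns dcp.FileSystemReader for "disk" or S3StorageReader for "s3",
    # pointing at the location written by an earlier save run (cfg.checkpoint.suffix).
    storage_reader = get_reader(cfg)
    state_dict = model.state_dict()
    # dcp.load restores the checkpoint in place into the provided state_dict.
    dcp.load(state_dict, storage_reader=storage_reader)
    return state_dict
```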

s3torchbenchmarking/src/s3torchbenchmarking/dcp_ddp/README.md

Lines changed: 27 additions & 6 deletions
@@ -10,21 +10,42 @@ where memory requirements per GPU are manageable.

 ### Purpose

-These benchmarks focus on testing the "save" mechanism of PyTorch DCP (`torch.distributed.checkpoint.save`). The primary
-objectives are to evaluate the `s3torchconnector` library's performance against other libraries and local storage
-options, by measuring the following metrics:
+These benchmarks test both the "save" and "load" mechanisms of PyTorch DCP (`torch.distributed.checkpoint.save` and `torch.distributed.checkpoint.load`). The primary objective is to evaluate the `s3torchconnector` library's performance against other libraries and local storage options, by measuring the following metrics:

-- Checkpoint saving throughput (in MiB/s);
-- Checkpoint "corrected" save durations (in seconds), which exclude the influence of model load duration on the device.
+**Save Benchmarks:**
+- Checkpoint saving throughput (in MiB/s)
+- Checkpoint "corrected" save durations (in seconds), which exclude the influence of model load duration on the device
+
+**Load Benchmarks:**
+- Checkpoint loading throughput (in MiB/s)
+- Checkpoint "corrected" load durations (in seconds), which exclude the influence of process setup and of loading the model to the device

 ### Configuration

-The benchmark runs can be customized through the [`dcp_ddp.yaml`](../../../conf/dcp_ddp.yaml) file.
+The benchmark runs can be customized through configuration files:
+
+- **Save benchmarks**: [`dcp_ddp_save.yaml`](../../../conf/dcp_ddp_save.yaml)
+- **Load benchmarks**: [`dcp_ddp_load.yaml`](../../../conf/dcp_ddp_load.yaml)
+
+The load configuration includes a `checkpoint.suffix` parameter that specifies which saved checkpoint to load.

 > [!IMPORTANT]
 > A `+path` option is passed to the running script ([`run_dcp_ddp_benchmarks.sh`](../../../utils/run_dcp_ddp_benchmarks.sh)),
 > and will be used only if `checkpoint.storage` key includes `disk`.

+### Usage
+
+**Save benchmarks (default):**
+```bash
+./utils/run_dcp_ddp_benchmarks.sh
+./utils/run_dcp_ddp_benchmarks.sh --save
+```
+
+**Load benchmarks:**
+```bash
+./utils/run_dcp_ddp_benchmarks.sh --load
+```
+
 ### References

 - https://pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html
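For illustration only, a sketch of how the "corrected" duration and throughput metrics could be derived from the `(begin, corrected_end, model_size)` tuples the benchmark workers put on the results queue; the commit's actual collation lives in the shared benchmark utilities and may differ, and model sizes are assumed here to be expressed in MiB.

```python
# Assumption-labelled sketch: aggregating worker timestamps into the reported metrics.
# Not the commit's collation code; model_size is assumed to be expressed in MiB.
from typing import List, Tuple


def summarize(results: List[Tuple[float, float, float]]) -> Tuple[float, float]:
    """Return (corrected duration in seconds, throughput in MiB/s).

    Each tuple is (begin, corrected_end, model_size); workers already subtract
    process setup and model-to-device time from corrected_end.
    """
    begin = min(r[0] for r in results)
    end = max(r[1] for r in results)
    total_mib = sum(r[2] for r in results)
    corrected_duration = end - begin
    return corrected_duration, total_mib / corrected_duration
```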
Lines changed: 75 additions & 0 deletions
@@ -0,0 +1,75 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# // SPDX-License-Identifier: BSD

import logging
from multiprocessing.queues import Queue
from time import perf_counter
from typing import Tuple

import hydra
import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
from omegaconf import DictConfig
from torch.nn.parallel import DistributedDataParallel

from s3torchbenchmarking.dcp_common import setup, get_reader, benchmark_common_runner
from s3torchbenchmarking.models import get_benchmark_model, BenchmarkModel

Timestamps = Tuple[float, float]
logger = logging.getLogger(__name__)


# TODO: add Structured Config (https://hydra.cc/docs/tutorials/structured_config/intro/)
@hydra.main(version_base=None)
def run_benchmark(cfg: DictConfig) -> dict:
    """DCP benchmarks entry point."""
    benchmark_model = get_benchmark_model(cfg.model)

    return benchmark_common_runner(cfg, run_ddp_load, (cfg, benchmark_model))


def run_ddp_load(
    rank: int,  # needs to be passed first (provided by `multiprocessing.spawn` automatically)
    cfg: DictConfig,
    proxy_model: BenchmarkModel,
    suffix: str,
    load_timestamps: Queue,
) -> None:
    """Execute the actual code for checkpoint loading.

    This function is meant to be executed in subprocesses."""
    begin_process = perf_counter()
    # Override random suffix with suffix from config
    storage_reader = get_reader(cfg)
    model_size = proxy_model.size
    model = proxy_model.model

    setup(cfg.backend, world_size=cfg.world_size, rank=rank)
    if cfg.backend == "nccl":
        device_id = rank % torch.cuda.device_count()
        torch.cuda.set_device(device_id)
        model.to(device_id)
        model = DistributedDataParallel(model, device_ids=[device_id])
    else:
        device_id = rank % torch.cpu.device_count()
        torch.cpu.set_device(device_id)
        model.to(device=torch.device("cpu"))
        model = DistributedDataParallel(model)

    state_dict = model.state_dict()

    begin_load = perf_counter()  # also "end_process"
    dcp.load(state_dict, storage_reader=storage_reader)
    end_load = perf_counter()

    # Record the load times excluding the influence of the process setup and model loading to device.
    load_timestamps.put(
        (begin_process, end_load - (begin_load - begin_process), model_size)
    )

    dist.destroy_process_group()


if __name__ == "__main__":
    run_benchmark()

s3torchbenchmarking/src/s3torchbenchmarking/dcp_ddp/benchmark.py renamed to s3torchbenchmarking/src/s3torchbenchmarking/dcp_ddp/save_benchmark.py

Lines changed: 2 additions & 2 deletions
@@ -26,10 +26,10 @@ def run_benchmark(cfg: DictConfig) -> dict:
     """DCP benchmarks entry point."""
     benchmark_model = get_benchmark_model(cfg.model)

-    return benchmark_common_runner(cfg, run_ddp, (cfg, benchmark_model))
+    return benchmark_common_runner(cfg, run_ddp_save, (cfg, benchmark_model))


-def run_ddp(
+def run_ddp_save(
     rank: int,  # needs to be passed first (provided by `multiprocessing.spawn` automatically)
     cfg: DictConfig,
     proxy_model: BenchmarkModel,
