Skip to content

Commit 7337692

Browse files
authored
feat(dcp): Add thread_count parameter to S3StorageWriter (#370)
* The thread_count parameter is used and defaulted to 1 within FileSystemWriter (parent class). This change is to expose this parameter to our users through our S3StorageWriter. We currently use it already within e2e filesystem integration tests and DCP benchmarks. We have extended our unit tests and added docstrings to give better visibility to this parameter.
1 parent ffc6302 commit 7337692

File tree

4 files changed

+34
-1
lines changed

4 files changed

+34
-1
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,9 @@
55
* Override S3Writer closed property and block writes after close (#360)
66
* Fix SequentialS3Reader seek beyond EOF to clamp position to object size (#362)
77

8+
### Other changes
9+
* Added thread_count parameter to S3StorageWriter
10+
811
## v1.4.3 (July 25, 2025)
912

1013
### New features

README.md

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,11 @@ REGION = "us-east-1"
162162
model = torchvision.models.resnet18()
163163

164164
# Save distributed checkpoint to S3
165-
s3_storage_writer = S3StorageWriter(region=REGION, path=CHECKPOINT_URI)
165+
s3_storage_writer = S3StorageWriter(
166+
region=REGION,
167+
path=CHECKPOINT_URI,
168+
thread_count=8, # optional; number of IO threads to use to write
169+
)
166170
DCP.save(
167171
state_dict=model.state_dict(),
168172
storage_writer=s3_storage_writer,

s3torchconnector/src/s3torchconnector/dcp/s3_file_system.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -271,6 +271,7 @@ def __init__(
271271
path: str,
272272
s3client_config: Optional[S3ClientConfig] = None,
273273
prefix_strategy: Optional[S3PrefixStrategyBase] = None,
274+
thread_count: int = 1,
274275
**kwargs,
275276
) -> None:
276277
"""
@@ -282,11 +283,13 @@ def __init__(
282283
s3client_config (Optional[S3ClientConfig]): Optional S3ClientConfig with parameters for S3 client.
283284
prefix_strategy (Optional[S3PrefixStrategyBase]): Optional strategy for generating S3 prefixes to
284285
optimize checkpoint organization and prevent throttling.
286+
thread_count (int): Number of IO threads to use to write. Defaults to 1 (Pytorch Default)
285287
kwargs (dict): Keyword arguments to pass to the parent :class:`FileSystemWriter`.
286288
"""
287289
super().__init__(
288290
path=path,
289291
sync_files=False, # FIXME: setting this to True makes the run to fail (L#333: `os.fsync(stream.fileno())`)
292+
thread_count=thread_count,
290293
**kwargs,
291294
)
292295
self.fs = S3FileSystem(region, s3client_config=s3client_config) # type: ignore
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
# // SPDX-License-Identifier: BSD
3+
4+
import pytest
5+
from s3torchconnector.dcp import S3StorageWriter
6+
7+
TEST_REGION = "eu-east-1"
8+
TEST_BUCKET = "test-bucket"
9+
TEST_KEY = "test-key.txt"
10+
TEST_PATH = f"s3://{TEST_BUCKET}/{TEST_KEY}"
11+
12+
13+
@pytest.mark.parametrize("thread_count", [1, 2, 4, 8, 16])
14+
def test_s3storage_writer_thread_count(thread_count):
15+
storage_writer = S3StorageWriter(
16+
region=TEST_REGION, path=TEST_PATH, thread_count=thread_count
17+
)
18+
assert storage_writer.thread_count == thread_count
19+
20+
21+
def test_s3storage_writer_thread_count_defaults_to_one():
22+
storage_writer = S3StorageWriter(region=TEST_REGION, path=TEST_PATH)
23+
assert storage_writer.thread_count == 1

0 commit comments

Comments
 (0)