Commit c20d719
fix(test): fix AWS_IO_DNS_INVALID_NAME in multiprocess tests
Add multiprocessing_context=mp.get_context() to DataLoader calls to ensure the spawn start method is used instead of fork on Darwin, preventing the S3 client's fork handlers from corrupting AWS CRT DNS resolver threads on macOS GitHub runners.
1 parent 6146273 · commit c20d719
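For readers unfamiliar with the knob, here is a minimal, self-contained sketch of the pattern this commit applies; ToyDataset is a hypothetical stand-in, not one of the S3 datasets in these tests. DataLoader's multiprocessing_context parameter controls how worker processes are started, and mp.get_context() returns the platform's default context ("spawn" on macOS), so workers begin in a fresh interpreter instead of inheriting a forked copy of the parent, whose native background threads (such as the AWS CRT DNS resolver's) can be left in an inconsistent state by fork.

    import torch.multiprocessing as mp
    from torch.utils.data import DataLoader, Dataset


    class ToyDataset(Dataset):
        """Hypothetical stand-in for the S3-backed datasets in these tests."""

        def __len__(self):
            return 8

        def __getitem__(self, idx):
            return idx


    if __name__ == "__main__":
        # mp.get_context() returns the platform default multiprocessing
        # context: "spawn" on macOS, so each worker starts in a fresh
        # interpreter rather than inheriting forked (and possibly corrupt)
        # native threads from the parent process.
        dataloader = DataLoader(
            ToyDataset(),
            batch_size=2,
            num_workers=2,
            multiprocessing_context=mp.get_context(),
        )
        for batch in dataloader:
            print(batch)

Note that spawn requires the dataset class to be picklable and the worker-creating code to sit behind an __main__ guard, which is why the guard appears above even in a toy script.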

3 files changed: +40, -10 lines


s3torchconnector/tst/e2e/test_distributed_training.py

Lines changed: 11 additions & 2 deletions
@@ -79,7 +79,11 @@ def dataloader_for_map(
     )
     sampler = DistributedSampler(dataset)
     dataloader = DataLoader(
-        dataset, batch_size=batch_size, num_workers=num_workers, sampler=sampler
+        dataset,
+        batch_size=batch_size,
+        num_workers=num_workers,
+        sampler=sampler,
+        multiprocessing_context=mp.get_context(),
     )
     return dataloader

@@ -93,7 +97,12 @@ def dataloader_for_iterable(
         enable_sharding=True,
         reader_constructor=reader_constructor,
     )
-    dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers)
+    dataloader = DataLoader(
+        dataset,
+        batch_size=batch_size,
+        num_workers=num_workers,
+        multiprocessing_context=mp.get_context(),
+    )
     return dataloader

s3torchconnector/tst/e2e/test_e2e_s3_lightning_checkpoint.py

Lines changed: 16 additions & 5 deletions
@@ -12,6 +12,7 @@
 from lightning.pytorch.demos import WikiText2
 from lightning.pytorch.plugins import AsyncCheckpointIO
 from torch.utils.data import DataLoader
+import torch.multiprocessing as mp

 from s3torchconnector import S3Checkpoint
 from s3torchconnector._s3client import S3Client
@@ -79,7 +80,9 @@ def test_delete_checkpoint(checkpoint_directory):
 def test_load_trained_checkpoint(checkpoint_directory):
     nonce = random.randrange(2**64)
     dataset = WikiText2(data_dir=Path(f"/tmp/data/{nonce}"))
-    dataloader = DataLoader(dataset, num_workers=3)
+    dataloader = DataLoader(
+        dataset, num_workers=3, multiprocessing_context=mp.get_context()
+    )
     model = LightningTransformer(vocab_size=dataset.vocab_size)
     trainer = L.Trainer(accelerator=LIGHTNING_ACCELERATOR, fast_dev_run=2)
     trainer.fit(model=model, train_dataloaders=dataloader)
@@ -95,7 +98,9 @@ def test_load_trained_checkpoint(checkpoint_directory):
 def test_compatibility_with_trainer_plugins(checkpoint_directory):
     nonce = random.randrange(2**64)
     dataset = WikiText2(data_dir=Path(f"/tmp/data/{nonce}"))
-    dataloader = DataLoader(dataset, num_workers=3)
+    dataloader = DataLoader(
+        dataset, num_workers=3, multiprocessing_context=mp.get_context()
+    )
     model = LightningTransformer(vocab_size=dataset.vocab_size)
     s3_lightning_checkpoint = S3LightningCheckpoint(region=checkpoint_directory.region)
     _verify_user_agent(s3_lightning_checkpoint)
@@ -121,7 +126,9 @@ def test_compatibility_with_trainer_plugins(checkpoint_directory):
 def test_compatibility_with_checkpoint_callback(checkpoint_directory):
     nonce = random.randrange(2**64)
     dataset = WikiText2(data_dir=Path(f"/tmp/data/{nonce}"))
-    dataloader = DataLoader(dataset, num_workers=3)
+    dataloader = DataLoader(
+        dataset, num_workers=3, multiprocessing_context=mp.get_context()
+    )

     model = LightningTransformer(vocab_size=dataset.vocab_size)
     s3_lightning_checkpoint = S3LightningCheckpoint(checkpoint_directory.region)
@@ -161,7 +168,9 @@ def test_compatibility_with_checkpoint_callback(checkpoint_directory):
 def test_compatibility_with_async_checkpoint_io(checkpoint_directory):
     nonce = random.randrange(2**64)
     dataset = WikiText2(data_dir=Path(f"/tmp/data/{nonce}"))
-    dataloader = DataLoader(dataset, num_workers=3)
+    dataloader = DataLoader(
+        dataset, num_workers=3, multiprocessing_context=mp.get_context()
+    )

     model = LightningTransformer(vocab_size=dataset.vocab_size)
     s3_lightning_checkpoint = S3LightningCheckpoint(checkpoint_directory.region)
@@ -192,7 +201,9 @@ def test_compatibility_with_async_checkpoint_io(checkpoint_directory):
 def test_compatibility_with_lightning_checkpoint_load(checkpoint_directory):
     nonce = random.randrange(2**64)
     dataset = WikiText2(data_dir=Path(f"/tmp/data/{nonce}"))
-    dataloader = DataLoader(dataset, num_workers=3)
+    dataloader = DataLoader(
+        dataset, num_workers=3, multiprocessing_context=mp.get_context()
+    )
     model = LightningTransformer(vocab_size=dataset.vocab_size)
     s3_lightning_checkpoint = S3LightningCheckpoint(region=checkpoint_directory.region)
     trainer = L.Trainer(

s3torchconnector/tst/e2e/test_multiprocess_dataloading.py

Lines changed: 13 additions & 3 deletions
@@ -10,6 +10,7 @@
 import pytest
 from torch.utils.data import DataLoader, get_worker_info
 from torchdata.datapipes.iter import IterableWrapper
+import torch.multiprocessing as mp

 from s3torchconnector import (
     S3IterableDataset,
@@ -85,7 +86,12 @@ def test_s3iterable_dataset_multiprocess_torchdata(
     batch_size = 2
     num_workers = 3

-    dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers)
+    dataloader = DataLoader(
+        dataset,
+        batch_size=batch_size,
+        num_workers=num_workers,
+        multiprocessing_context=mp.get_context(),
+    )

     total_objects = 0
     uris_seen = Counter()
@@ -123,7 +129,9 @@ def test_s3iterable_dataset_multiprocess(
     num_epochs = 2
     num_images = len(image_directory.contents)

-    dataloader = DataLoader(dataset, num_workers=num_workers)
+    dataloader = DataLoader(
+        dataset, num_workers=num_workers, multiprocessing_context=mp.get_context()
+    )
     counter = 0
     for epoch in range(num_epochs):
         s3keys = Counter()
@@ -160,7 +168,9 @@ def test_s3mapdataset_multiprocess(
     num_epochs = 2
     num_images = len(image_directory.contents)

-    dataloader = DataLoader(dataset, num_workers=num_workers)
+    dataloader = DataLoader(
+        dataset, num_workers=num_workers, multiprocessing_context=mp.get_context()
+    )

     for epoch in range(num_epochs):
         s3keys = Counter()
