1212from lightning .pytorch .demos import WikiText2
1313from lightning .pytorch .plugins import AsyncCheckpointIO
1414from torch .utils .data import DataLoader
15+ import torch .multiprocessing as mp
1516
1617from s3torchconnector import S3Checkpoint
1718from s3torchconnector ._s3client import S3Client
@@ -79,7 +80,9 @@ def test_delete_checkpoint(checkpoint_directory):
7980def test_load_trained_checkpoint (checkpoint_directory ):
8081 nonce = random .randrange (2 ** 64 )
8182 dataset = WikiText2 (data_dir = Path (f"/tmp/data/{ nonce } " ))
82- dataloader = DataLoader (dataset , num_workers = 3 )
83+ dataloader = DataLoader (
84+ dataset , num_workers = 3 , multiprocessing_context = mp .get_context ()
85+ )
8386 model = LightningTransformer (vocab_size = dataset .vocab_size )
8487 trainer = L .Trainer (accelerator = LIGHTNING_ACCELERATOR , fast_dev_run = 2 )
8588 trainer .fit (model = model , train_dataloaders = dataloader )
@@ -95,7 +98,9 @@ def test_load_trained_checkpoint(checkpoint_directory):
9598def test_compatibility_with_trainer_plugins (checkpoint_directory ):
9699 nonce = random .randrange (2 ** 64 )
97100 dataset = WikiText2 (data_dir = Path (f"/tmp/data/{ nonce } " ))
98- dataloader = DataLoader (dataset , num_workers = 3 )
101+ dataloader = DataLoader (
102+ dataset , num_workers = 3 , multiprocessing_context = mp .get_context ()
103+ )
99104 model = LightningTransformer (vocab_size = dataset .vocab_size )
100105 s3_lightning_checkpoint = S3LightningCheckpoint (region = checkpoint_directory .region )
101106 _verify_user_agent (s3_lightning_checkpoint )
@@ -121,7 +126,9 @@ def test_compatibility_with_trainer_plugins(checkpoint_directory):
121126def test_compatibility_with_checkpoint_callback (checkpoint_directory ):
122127 nonce = random .randrange (2 ** 64 )
123128 dataset = WikiText2 (data_dir = Path (f"/tmp/data/{ nonce } " ))
124- dataloader = DataLoader (dataset , num_workers = 3 )
129+ dataloader = DataLoader (
130+ dataset , num_workers = 3 , multiprocessing_context = mp .get_context ()
131+ )
125132
126133 model = LightningTransformer (vocab_size = dataset .vocab_size )
127134 s3_lightning_checkpoint = S3LightningCheckpoint (checkpoint_directory .region )
@@ -161,7 +168,9 @@ def test_compatibility_with_checkpoint_callback(checkpoint_directory):
161168def test_compatibility_with_async_checkpoint_io (checkpoint_directory ):
162169 nonce = random .randrange (2 ** 64 )
163170 dataset = WikiText2 (data_dir = Path (f"/tmp/data/{ nonce } " ))
164- dataloader = DataLoader (dataset , num_workers = 3 )
171+ dataloader = DataLoader (
172+ dataset , num_workers = 3 , multiprocessing_context = mp .get_context ()
173+ )
165174
166175 model = LightningTransformer (vocab_size = dataset .vocab_size )
167176 s3_lightning_checkpoint = S3LightningCheckpoint (checkpoint_directory .region )
@@ -192,7 +201,9 @@ def test_compatibility_with_async_checkpoint_io(checkpoint_directory):
192201def test_compatibility_with_lightning_checkpoint_load (checkpoint_directory ):
193202 nonce = random .randrange (2 ** 64 )
194203 dataset = WikiText2 (data_dir = Path (f"/tmp/data/{ nonce } " ))
195- dataloader = DataLoader (dataset , num_workers = 3 )
204+ dataloader = DataLoader (
205+ dataset , num_workers = 3 , multiprocessing_context = mp .get_context ()
206+ )
196207 model = LightningTransformer (vocab_size = dataset .vocab_size )
197208 s3_lightning_checkpoint = S3LightningCheckpoint (region = checkpoint_directory .region )
198209 trainer = L .Trainer (
0 commit comments