### 🐛 Describe the bug
When a custom `Dataset` stores a Python module object (e.g. `self.h5py = h5py`) as an attribute, `DataLoader` with `num_workers > 0` fails with `TypeError: cannot pickle 'module' object`.

This happens because `DataLoader` starts its workers via Python multiprocessing; with the `spawn` (or `forkserver`) start method the dataset object must be pickled and sent to each worker process. Module objects are not picklable, but the error message is cryptic and not clearly tied to the dataset.
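The root cause can be shown with `pickle` alone, without any `DataLoader`; this is a minimal sketch (class name is illustrative) of why any object holding a module attribute fails:

```python
import os  # any module works; os avoids the h5py dependency
import pickle


class HoldsModule:
    def __init__(self):
        # same pattern as the dataset below: a module stored on self
        self.os = os


try:
    pickle.dumps(HoldsModule())
except TypeError as exc:
    # prints: cannot pickle 'module' object
    print(exc)
```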
### To Reproduce

#### Code
```python
import os

import h5py  # any module works, h5py is just an example
import torch
from torch.utils.data import Dataset, DataLoader


class MyDataset(Dataset):
    def __init__(self):
        # ⚠ storing a module object on self
        self.h5py = h5py

    def __len__(self):
        return 10

    def __getitem__(self, idx):
        return idx


def distribute_loader(loader):
    return torch.utils.data.DataLoader(
        loader.dataset,
        batch_size=loader.batch_size // torch.distributed.get_world_size(),
        sampler=torch.utils.data.distributed.DistributedSampler(
            loader.dataset,
            num_replicas=torch.distributed.get_world_size(),
            rank=torch.distributed.get_rank(),
        ),
        num_workers=loader.num_workers,
        pin_memory=loader.pin_memory,
    )


def main(rank):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = str(12355)
    # initialize the process group
    torch.distributed.init_process_group(
        "nccl",
        rank=rank,
        world_size=4,
    )
    torch.cuda.set_device(rank)
    torch.distributed.barrier()

    ds = MyDataset()
    loader = DataLoader(ds, batch_size=4, num_workers=2)  # num_workers > 0 triggers the bug
    ddp_loader = distribute_loader(loader)
    for i, batch in enumerate(ddp_loader):
        print(batch)

    torch.distributed.destroy_process_group()


if __name__ == "__main__":
    torch.multiprocessing.spawn(main, nprocs=4)
```
#### Error
```text
Traceback (most recent call last):
  File "test_pickle.py", line 19, in <module>
    for batch in loader:
  File ".../torch/utils/data/dataloader.py", line ...
TypeError: cannot pickle 'module' object
```
### Expected behavior

Either:

- Provide a clearer error message (e.g. "Dataset objects must not store module objects; they cannot be pickled for multiprocessing"), or
- Allow safe serialization by ignoring unpicklable module attributes.
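For the second option, a dataset can already opt into this itself by dropping module attributes during pickling and re-importing them in the worker. This is only a sketch of the idea, using the standard pickle protocol hooks (`__getstate__` / `__setstate__`), not an existing torch API:

```python
import types

from torch.utils.data import Dataset


class ModuleSafeDataset(Dataset):
    """Example dataset that excludes module attributes from pickling."""

    def __init__(self):
        import h5py

        self.h5py = h5py

    def __getstate__(self):
        # copy the instance dict, but leave module objects out
        return {k: v for k, v in self.__dict__.items() if not isinstance(v, types.ModuleType)}

    def __setstate__(self, state):
        self.__dict__.update(state)
        # re-import the module in the worker process
        import h5py

        self.h5py = h5py

    def __len__(self):
        return 10

    def __getitem__(self, idx):
        return idx
```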
### Additional context

Workaround: do not store modules on `self`. For example:

```python
# instead of
self.h5py = h5py

# just import h5py inside __getitem__ (or other methods)
import h5py
```
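Applied to the repro above, a corrected dataset might look like this (a sketch of the workaround, not the torchvision fix):

```python
from torch.utils.data import Dataset


class MyDataset(Dataset):
    def __len__(self):
        return 10

    def __getitem__(self, idx):
        # import lazily inside the method instead of storing the module on self;
        # Python caches imports, so repeated calls are cheap
        import h5py  # noqa: F401  (would be used to open the HDF5 file here)

        return idx
```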
Note: the repro does not really depend on h5py; a shorter minimal version that stores any module (e.g. `self.os = os`) and forces the `spawn` start method reproduces the same error without the distributed setup, as shown below.
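For instance, something along these lines (forcing `multiprocessing_context="spawn"` so the dataset must be pickled regardless of the platform default):

```python
import os

from torch.utils.data import Dataset, DataLoader


class MyDataset(Dataset):
    def __init__(self):
        self.os = os  # any module attribute triggers the error

    def __len__(self):
        return 10

    def __getitem__(self, idx):
        return idx


if __name__ == "__main__":
    loader = DataLoader(
        MyDataset(),
        batch_size=4,
        num_workers=2,
        multiprocessing_context="spawn",  # workers must pickle the dataset
    )
    for batch in loader:
        print(batch)
    # -> TypeError: cannot pickle 'module' object
```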
I found this problem in `torchvision.datasets.PCAM`, which stores the module in `__init__`:

```python
try:
    import h5py

    self.h5py = h5py
except ImportError:
    raise RuntimeError(
        "h5py is not found. This dataset needs to have h5py installed: please run pip install h5py"
    )
```
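One possible direction for PCAM (my assumption, not a reviewed patch) would be to only check in `__init__` that h5py is importable and import it again where it is actually used, so the dataset instance stays picklable:

```python
# sketch of a picklable variant of the PCAM pattern
def _check_h5py_available():
    try:
        import h5py  # noqa: F401  -- availability check only, the module is not stored
    except ImportError:
        raise RuntimeError(
            "h5py is not found. This dataset needs to have h5py installed: please run pip install h5py"
        )


class _PCAMLikeDataset:
    def __init__(self):
        _check_h5py_available()  # fail early, but keep the instance free of module attributes

    def _open(self, path):
        import h5py  # imported where needed; cached by Python after the first call

        return h5py.File(path, "r")
```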
### Versions
python = 3.10.18
torch = 2.7.0+cu128
torchvision = 0.22.0+cu128