DataLoader multiprocessing fails when dataset stores a module object (TypeError: cannot pickle 'module' object) #9195

@nehdiii

Description

🐛 Describe the bug

When a custom Dataset stores a Python module object (e.g. self.h5py = h5py) as an attribute, DataLoader with num_workers > 0 fails with
TypeError: cannot pickle 'module' object.

This happens because DataLoader creates worker processes via multiprocessing; under the spawn (and forkserver) start methods the dataset object must be pickled and sent to each worker. Module objects are not picklable, and the resulting error message is cryptic and not clearly tied to the dataset attribute that caused it.
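The failure can be reproduced without DataLoader at all: pickling any object that holds a module attribute raises the same error. A minimal sketch, using `os` so it needs no third-party packages:

```python
import pickle

class BadDataset:
    def __init__(self):
        # Storing a module object on the instance makes it unpicklable.
        self.os = __import__("os")

try:
    pickle.dumps(BadDataset())
except TypeError as e:
    print(e)  # cannot pickle 'module' object
```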


To Reproduce

Code

import os

import torch
from torch.utils.data import Dataset, DataLoader
import h5py  # any module works; h5py is just an example


class MyDataset(Dataset):
    def __init__(self):
        # ⚠ storing a module object on self
        self.h5py = h5py

    def __len__(self):
        return 10

    def __getitem__(self, idx):
        return idx



def distribute_loader(loader):
    return torch.utils.data.DataLoader(
        loader.dataset,
        batch_size=loader.batch_size // torch.distributed.get_world_size(),
        sampler=torch.utils.data.distributed.DistributedSampler(
            loader.dataset,
            num_replicas=torch.distributed.get_world_size(),
            rank=torch.distributed.get_rank(),
        ),
        num_workers=loader.num_workers,
        pin_memory=loader.pin_memory,
    )

def main(rank):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = str(12355)

    # initialize the process group
    torch.distributed.init_process_group(
        "nccl",
        rank=rank,
        world_size=4,
    )
    torch.cuda.set_device(rank)
    torch.distributed.barrier()

    ds = MyDataset()
    loader = DataLoader(ds, batch_size=4, num_workers=2)  # num_workers > 0 triggers the bug
    ddp_loader = distribute_loader(loader)
    for i, batch in enumerate(ddp_loader):
        print(batch)

    torch.distributed.destroy_process_group()

if __name__ == "__main__":
    torch.multiprocessing.spawn(main, nprocs=4)

Error

Traceback (most recent call last):
  File "test_pickle.py", line 19, in <module>
    for batch in loader:
  File ".../torch/utils/data/dataloader.py", line ...
TypeError: cannot pickle 'module' object

Expected behavior

Either:

  1. Provide a clearer error message (e.g. "Dataset objects must not store module objects; they cannot be pickled for multiprocessing"),
    or
  2. Allow safe serialization by ignoring unpicklable module attributes.
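Option 2 can also be emulated on the user side today with `__getstate__`/`__setstate__`. A minimal sketch, using `os` as a stand-in for any module and re-importing it after unpickling:

```python
import pickle
import types

class SafeDataset:
    def __init__(self):
        self.os = __import__("os")  # stand-in for h5py etc.
        self.data = list(range(10))

    def __getstate__(self):
        # Drop module-typed attributes so the dataset can be pickled.
        return {k: v for k, v in self.__dict__.items()
                if not isinstance(v, types.ModuleType)}

    def __setstate__(self, state):
        self.__dict__.update(state)
        # Re-import in the worker process instead of shipping the module.
        self.os = __import__("os")

restored = pickle.loads(pickle.dumps(SafeDataset()))
print(restored.data[:3])  # [0, 1, 2]
```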

Additional context

Workaround: do not store modules on self. For example:

# instead of storing the module on the instance
self.h5py = h5py

# import h5py inside __getitem__ (or other methods) where it is used
import h5py
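A fuller version of this workaround keeps only picklable state on the instance and defers both the import and the file handle to the worker. A sketch (`path` and the `"data"` key are placeholders; in practice the class would subclass `torch.utils.data.Dataset`):

```python
class PickleSafeDataset:
    def __init__(self, path):
        self.path = path   # plain, picklable state only
        self._file = None  # opened lazily in each worker

    def _lazy_init(self):
        if self._file is None:
            import h5py  # imported where used, never stored on self
            self._file = h5py.File(self.path, "r")

    def __getitem__(self, idx):
        self._lazy_init()
        return self._file["data"][idx]

    def __len__(self):
        return 10  # placeholder
```

Because `_lazy_init` only runs inside `__getitem__`, the dataset pickles cleanly before any worker touches it.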

I found this problem in torchvision.datasets.PCAM:

try:
    import h5py

    self.h5py = h5py
except ImportError:
    raise RuntimeError(
        "h5py is not found. This dataset needs to have h5py installed: please run pip install h5py"
    )
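One way a PCAM-style check could keep its helpful error message without storing the module is to verify the import but not keep a reference. A generic sketch (`require_module` is a hypothetical helper, not torchvision API):

```python
import importlib

def require_module(name: str, hint: str) -> None:
    """Check that a dependency is importable without keeping the module object."""
    try:
        importlib.import_module(name)
    except ImportError:
        raise RuntimeError(hint) from None

# e.g. in __init__:
#     require_module("h5py", "h5py is not found: please run pip install h5py")
# ...and then `import h5py` locally inside __getitem__ where it is used.
```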

Versions

python = 3.10.18
torch = 2.7.0+cu128
torchvision = 0.22.0+cu128
