Skip to content

Commit 496c2d4

Browse files
marcenacp authored and
The TensorFlow Datasets Authors
committed
Make DatasetInfo picklable so that tfds.data_source works in external multiprocessing.
PiperOrigin-RevId: 638272983
1 parent 4fd9a5d commit 496c2d4

File tree

3 files changed

+57
-0
lines changed

3 files changed

+57
-0
lines changed

tensorflow_datasets/core/data_sources/base_test.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,10 @@
1515

1616
"""Tests for all data sources."""
1717

18+
import pickle
1819
from unittest import mock
1920

21+
import cloudpickle
2022
from etils import epath
2123
import pytest
2224
import tensorflow_datasets as tfds
@@ -181,3 +183,18 @@ def test_data_source_is_sliceable():
181183
file_instructions = mock_array_record_data_source.call_args_list[1].args[0]
182184
assert file_instructions[0].skip == 0
183185
assert file_instructions[0].take == 30000
186+
187+
188+
# PyGrain requires that data sources are picklable.
189+
@pytest.mark.parametrize(
190+
'file_format',
191+
file_adapters.FileFormat.with_random_access(),
192+
)
193+
@pytest.mark.parametrize('pickle_module', [pickle, cloudpickle])
194+
def test_data_source_is_picklable_after_use(file_format, pickle_module):
195+
with tfds.testing.tmp_dir() as data_dir:
196+
builder = tfds.testing.DummyDataset(data_dir=data_dir)
197+
builder.download_and_prepare(file_format=file_format)
198+
data_source = builder.as_data_source(split='train')
199+
assert data_source[0] == {'id': 0}
200+
assert pickle_module.loads(pickle_module.dumps(data_source))[0] == {'id': 0}

tensorflow_datasets/core/dataset_info.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,7 @@ class DatasetInfo(object):
173173
"""
174174

175175
def __init__(
176+
# LINT.IfChange(dataset_info_args)
176177
self,
177178
*,
178179
builder: Union[DatasetIdentity, Any],
@@ -186,6 +187,7 @@ def __init__(
186187
license: Optional[str] = None, # pylint: disable=redefined-builtin
187188
redistribution_info: Optional[Dict[str, str]] = None,
188189
split_dict: Optional[splits_lib.SplitDict] = None,
190+
# LINT.ThenChange(:setstate)
189191
):
190192
# pyformat: disable
191193
"""Constructs DatasetInfo.
@@ -864,6 +866,35 @@ def __repr__(self):
864866
lines.append(")")
865867
return "\n".join(lines)
866868

869+
def __getstate__(self):
870+
return {
871+
"builder": self._builder_or_identity,
872+
"description": self.description,
873+
"features": self.features,
874+
"supervised_keys": self.supervised_keys,
875+
"disable_shuffling": self.disable_shuffling,
876+
"homepage": self.homepage,
877+
"citation": self.citation,
878+
"metadata": self.metadata,
879+
"license": self.redistribution_info.license,
880+
"split_dict": self.splits,
881+
}
882+
def __setstate__(self, state):
883+
# LINT.IfChange(setstate)
884+
self.__init__(
885+
builder=state["builder"],
886+
description=state["description"],
887+
features=state["features"],
888+
supervised_keys=state["supervised_keys"],
889+
disable_shuffling=state["disable_shuffling"],
890+
homepage=state["homepage"],
891+
citation=state["citation"],
892+
metadata=state["metadata"],
893+
license=state["license"],
894+
split_dict=state["split_dict"],
895+
)
896+
# LINT.ThenChange(:dataset_info_args)
897+
867898

868899
def _nest_to_proto(nest: Nest) -> dataset_info_pb2.SupervisedKeys.Nest:
869900
"""Creates a `SupervisedKeys.Nest` from a limited `tf.nest` style structure.

tensorflow_datasets/testing/mocking_test.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -392,3 +392,12 @@ def test_as_data_source_fn():
392392
assert imagenet[0] == 'foo'
393393
assert imagenet[1] == 'bar'
394394
assert imagenet[2] == 'baz'
395+
396+
397+
# PyGrain requires that data sources are picklable.
398+
def test_mocked_data_source_is_pickable():
399+
with tfds.testing.mock_data(num_examples=2):
400+
data_source = tfds.data_source('imagenet2012', split='train')
401+
pickled_and_unpickled_data_source = pickle.loads(pickle.dumps(data_source))
402+
assert len(pickled_and_unpickled_data_source) == 2
403+
assert isinstance(pickled_and_unpickled_data_source[0]['image'], np.ndarray)

0 commit comments

Comments (0)