Skip to content

Commit 9a3792e

Browse files
author
The TensorFlow Datasets Authors
committed
Internal change
PiperOrigin-RevId: 721300128
1 parent 9969ce5 commit 9a3792e

File tree

1 file changed

+16
-5
lines changed

1 file changed

+16
-5
lines changed

tensorflow_datasets/core/dataset_builder_test.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -743,29 +743,40 @@ def test_get_file_spec(self):
743743
"dummy_dataset_with_configs/plus1/0.0.1/dummy_dataset_with_configs-test.tfrecord@1",
744744
)
745745

746-
def test_load_as_data_source(self):
746+
@parameterized.parameters(
747+
(
748+
file_adapters.FileFormat.ARRAY_RECORD,
749+
array_record.ArrayRecordDataSource,
750+
),
751+
)
752+
def test_load_as_data_source(self, file_format, data_source_type):
747753
data_dir = self.get_temp_dir()
748754
builder = DummyDatasetWithConfigs(
749755
data_dir=data_dir,
750756
config="plus1",
751-
file_format=file_adapters.FileFormat.ARRAY_RECORD,
757+
file_format=file_format,
752758
)
753759
builder.download_and_prepare()
754760

755761
data_source = builder.as_data_source()
756762
assert isinstance(data_source, dict)
757-
assert isinstance(data_source["train"], array_record.ArrayRecordDataSource)
758-
assert isinstance(data_source["test"], array_record.ArrayRecordDataSource)
763+
assert isinstance(data_source["train"], data_source_type)
764+
assert isinstance(data_source["test"], data_source_type)
759765
assert len(data_source["test"]) == 10
760766
assert data_source["test"][0]["x"] == 28
761767
assert len(data_source["train"]) == 20
762768
assert data_source["train"][0]["x"] == 7
763769

764770
data_source = builder.as_data_source(split="test")
765-
assert isinstance(data_source, array_record.ArrayRecordDataSource)
771+
assert isinstance(data_source, data_source_type)
766772
assert len(data_source) == 10
767773
assert data_source[0]["x"] == 28
768774

775+
data_source = builder.as_data_source(split="all")
776+
assert isinstance(data_source, data_source_type)
777+
assert len(data_source) == 30
778+
assert data_source[0]["x"] == 7
779+
769780
def test_load_as_data_source_alternative_file_format(self):
770781
data_dir = self.get_temp_dir()
771782
builder = DummyDatasetWithConfigs(

0 commit comments

Comments
 (0)