@@ -743,29 +743,40 @@ def test_get_file_spec(self):
743
743
"dummy_dataset_with_configs/plus1/0.0.1/dummy_dataset_with_configs-test.tfrecord@1" ,
744
744
)
745
745
746
- def test_load_as_data_source (self ):
746
+ @parameterized .parameters (
747
+ (
748
+ file_adapters .FileFormat .ARRAY_RECORD ,
749
+ array_record .ArrayRecordDataSource ,
750
+ ),
751
+ )
752
+ def test_load_as_data_source (self , file_format , data_source_type ):
747
753
data_dir = self .get_temp_dir ()
748
754
builder = DummyDatasetWithConfigs (
749
755
data_dir = data_dir ,
750
756
config = "plus1" ,
751
- file_format = file_adapters . FileFormat . ARRAY_RECORD ,
757
+ file_format = file_format ,
752
758
)
753
759
builder .download_and_prepare ()
754
760
755
761
data_source = builder .as_data_source ()
756
762
assert isinstance (data_source , dict )
757
- assert isinstance (data_source ["train" ], array_record . ArrayRecordDataSource )
758
- assert isinstance (data_source ["test" ], array_record . ArrayRecordDataSource )
763
+ assert isinstance (data_source ["train" ], data_source_type )
764
+ assert isinstance (data_source ["test" ], data_source_type )
759
765
assert len (data_source ["test" ]) == 10
760
766
assert data_source ["test" ][0 ]["x" ] == 28
761
767
assert len (data_source ["train" ]) == 20
762
768
assert data_source ["train" ][0 ]["x" ] == 7
763
769
764
770
data_source = builder .as_data_source (split = "test" )
765
- assert isinstance (data_source , array_record . ArrayRecordDataSource )
771
+ assert isinstance (data_source , data_source_type )
766
772
assert len (data_source ) == 10
767
773
assert data_source [0 ]["x" ] == 28
768
774
775
+ data_source = builder .as_data_source (split = "all" )
776
+ assert isinstance (data_source , data_source_type )
777
+ assert len (data_source ) == 30
778
+ assert data_source [0 ]["x" ] == 7
779
+
769
780
def test_load_as_data_source_alternative_file_format (self ):
770
781
data_dir = self .get_temp_dir ()
771
782
builder = DummyDatasetWithConfigs (
0 commit comments