@@ -93,17 +93,21 @@ def _model(self):
9393
9494 def test_default_job_spec (self ):
9595 self .assertStartsWith (self ._job_spec ["job_id" ], "cloud_fit_" )
96- self . assertDictContainsSubset (
97- {
98- "masterConfig " : { "imageUri" : self . _image_uri ,},
99- "args" : [
100- "--remote_dir" ,
101- self . _remote_dir ,
102- "--distribution_strategy" ,
103- MULTI_WORKER_MIRRORED_STRATEGY_NAME ,
104- ],
105- },
96+ expected = {
97+ "masterConfig" : { "imageUri" : self . _image_uri ,},
98+ "args " : [
99+ "--remote_dir" ,
100+ self . _remote_dir ,
101+ "--distribution_strategy" ,
102+ MULTI_WORKER_MIRRORED_STRATEGY_NAME ,
103+ ] ,
104+ }
105+ self . assertEqual (
106106 self ._job_spec ["trainingInput" ],
107+ {
108+ ** self ._job_spec ["trainingInput" ],
109+ ** expected ,
110+ }
107111 )
108112
109113 @mock .patch .object (discovery , "build" , autospec = True )
@@ -125,17 +129,21 @@ def test_submit_job(self, mock_discovery_build):
125129
126130 _ , fit_kwargs = list (self ._mock_create .call_args )
127131 body = fit_kwargs ["body" ]
128- self . assertDictContainsSubset (
129- {
130- "masterConfig " : { "imageUri" : self . _image_uri ,},
131- "args" : [
132- "--remote_dir" ,
133- self . _remote_dir ,
134- "--distribution_strategy" ,
135- MULTI_WORKER_MIRRORED_STRATEGY_NAME ,
136- ],
137- },
132+ expected = {
133+ "masterConfig" : { "imageUri" : self . _image_uri ,},
134+ "args " : [
135+ "--remote_dir" ,
136+ self . _remote_dir ,
137+ "--distribution_strategy" ,
138+ MULTI_WORKER_MIRRORED_STRATEGY_NAME ,
139+ ] ,
140+ }
141+ self . assertEqual (
138142 body ["trainingInput" ],
143+ {
144+ ** body ["trainingInput" ],
145+ ** expected ,
146+ }
139147 )
140148 self .assertStartsWith (body ["job_id" ], "cloud_fit_" )
141149 self ._mock_get .execute .assert_called_with ()
@@ -212,8 +220,9 @@ def test_fit_kwargs(self, mock_submit_job):
212220 os .path .join (remote_dir , "training_assets" )
213221 )
214222 elements = training_assets_graph .fit_kwargs_fn ()
215- self .assertDictContainsSubset (tfds .as_numpy (
216- elements ), {"batch_size" : 1 , "epochs" : 2 , "verbose" : 3 })
223+ actual = {"batch_size" : 1 , "epochs" : 2 , "verbose" : 3 }
224+ expected = tfds .as_numpy (elements )
225+ self .assertEqual (actual , {** actual , ** expected })
217226
218227 @mock .patch .object (client , "_submit_job" , autospec = True )
219228 def test_custom_job_spec (self , mock_submit_job ):
@@ -245,17 +254,21 @@ def test_custom_job_spec(self, mock_submit_job):
245254
246255 kargs , _ = mock_submit_job .call_args
247256 body , _ = kargs
248- self . assertDictContainsSubset (
249- {
250- "masterConfig " : { "imageUri" : self . _image_uri ,},
251- "args" : [
252- "--remote_dir" ,
253- self . _remote_dir ,
254- "--distribution_strategy" ,
255- MULTI_WORKER_MIRRORED_STRATEGY_NAME ,
256- ],
257- },
257+ expected = {
258+ "masterConfig" : { "imageUri" : self . _image_uri ,},
259+ "args " : [
260+ "--remote_dir" ,
261+ self . _remote_dir ,
262+ "--distribution_strategy" ,
263+ MULTI_WORKER_MIRRORED_STRATEGY_NAME ,
264+ ] ,
265+ }
266+ self . assertEqual (
258267 body ["trainingInput" ],
268+ {
269+ ** body ["trainingInput" ],
270+ ** expected ,
271+ }
259272 )
260273
261274 @mock .patch .object (client , "_submit_job" , autospec = True )
@@ -275,16 +288,20 @@ def test_distribution_strategy(
275288
276289 kargs , _ = mock_submit_job .call_args
277290 body , _ = kargs
278- self . assertDictContainsSubset (
279- {
280- "args" : [
281- "--remote_dir" ,
282- self . _remote_dir ,
283- "--distribution_strategy" ,
284- MULTI_WORKER_MIRRORED_STRATEGY_NAME ,
285- ],
286- },
291+ expected = {
292+ "args" : [
293+ "--remote_dir" ,
294+ self . _remote_dir ,
295+ "--distribution_strategy" ,
296+ MULTI_WORKER_MIRRORED_STRATEGY_NAME ,
297+ ] ,
298+ }
299+ self . assertEqual (
287300 body ["trainingInput" ],
301+ {
302+ ** body ["trainingInput" ],
303+ ** expected ,
304+ }
288305 )
289306
290307 client .cloud_fit (
@@ -297,16 +314,20 @@ def test_distribution_strategy(
297314
298315 kargs , _ = mock_submit_job .call_args
299316 body , _ = kargs
300- self . assertDictContainsSubset (
301- {
302- "args" : [
303- "--remote_dir" ,
304- self . _remote_dir ,
305- "--distribution_strategy" ,
306- MIRRORED_STRATEGY_NAME ,
307- ],
308- },
317+ expected = {
318+ "args" : [
319+ "--remote_dir" ,
320+ self . _remote_dir ,
321+ "--distribution_strategy" ,
322+ MIRRORED_STRATEGY_NAME ,
323+ ] ,
324+ }
325+ self . assertEqual (
309326 body ["trainingInput" ],
327+ {
328+ ** body ["trainingInput" ],
329+ ** expected ,
330+ }
310331 )
311332
312333 with self .assertRaises (ValueError ):
@@ -351,7 +372,8 @@ def test_job_id(self, mock_serialize_assets, mock_submit_job):
351372
352373 kargs , _ = mock_submit_job .call_args
353374 body , _ = kargs
354- self .assertDictContainsSubset ({"job_id" : test_job_id ,}, body )
375+ expected = {"job_id" : test_job_id ,}
376+ self .assertEqual (body , {** body , ** expected })
355377
356378
357379if __name__ == "__main__" :
0 commit comments