@@ -43,23 +43,23 @@ def test_invalid_args(self):
43
43
44
44
def test_one_device_strategy_cpu (self ):
45
45
ds = distribute_utils .get_distribution_strategy ('one_device' , num_gpus = 0 )
46
- self .assertEquals (ds .num_replicas_in_sync , 1 )
47
- self .assertEquals (len (ds .extended .worker_devices ), 1 )
46
+ self .assertEqual (ds .num_replicas_in_sync , 1 )
47
+ self .assertEqual (len (ds .extended .worker_devices ), 1 )
48
48
self .assertIn ('CPU' , ds .extended .worker_devices [0 ])
49
49
50
50
def test_one_device_strategy_gpu (self ):
51
51
ds = distribute_utils .get_distribution_strategy ('one_device' , num_gpus = 1 )
52
- self .assertEquals (ds .num_replicas_in_sync , 1 )
53
- self .assertEquals (len (ds .extended .worker_devices ), 1 )
52
+ self .assertEqual (ds .num_replicas_in_sync , 1 )
53
+ self .assertEqual (len (ds .extended .worker_devices ), 1 )
54
54
self .assertIn ('GPU' , ds .extended .worker_devices [0 ])
55
55
56
56
def test_mirrored_strategy (self ):
57
57
# CPU only.
58
58
_ = distribute_utils .get_distribution_strategy (num_gpus = 0 )
59
59
# 5 GPUs.
60
60
ds = distribute_utils .get_distribution_strategy (num_gpus = 5 )
61
- self .assertEquals (ds .num_replicas_in_sync , 5 )
62
- self .assertEquals (len (ds .extended .worker_devices ), 5 )
61
+ self .assertEqual (ds .num_replicas_in_sync , 5 )
62
+ self .assertEqual (len (ds .extended .worker_devices ), 5 )
63
63
for device in ds .extended .worker_devices :
64
64
self .assertIn ('GPU' , device )
65
65
@@ -105,12 +105,13 @@ def test_tpu_strategy(self):
105
105
ds , tf .distribute .TPUStrategy )
106
106
107
107
def test_invalid_strategy (self ):
108
- with self .assertRaisesRegexp (
109
- ValueError ,
110
- 'distribution_strategy must be a string but got: False. If' ):
108
+ with self .assertRaisesRegex (
109
+ ValueError , 'distribution_strategy must be a string but got: False. If'
110
+ ):
111
111
distribute_utils .get_distribution_strategy (False )
112
- with self .assertRaisesRegexp (
113
- ValueError , 'distribution_strategy must be a string but got: 1' ):
112
+ with self .assertRaisesRegex (
113
+ ValueError , 'distribution_strategy must be a string but got: 1'
114
+ ):
114
115
distribute_utils .get_distribution_strategy (1 )
115
116
116
117
def test_get_strategy_scope (self ):
0 commit comments