Skip to content

Commit 5e539a3

Browse files
Test that training=True activates dropout in a reusable SavedModel for BERT.
PiperOrigin-RevId: 307633374
1 parent 1f685c5 commit 5e539a3

File tree

1 file changed: +10 additions, -0 deletions

official/nlp/bert/export_tfhub_test.py

Lines changed: 10 additions & 0 deletions

Original file line number | Diff line number | Diff line change
@@ -84,6 +84,16 @@ def test_export_tfhub(self):
8484
self.assertAllClose(source_output.numpy(), hub_output.numpy())
8585
self.assertAllClose(source_output.numpy(), encoder_output.numpy())
8686

87+
# Test that training=True makes a difference (activates dropout).
88+
def _dropout_mean_stddev(training, num_runs=20):
89+
input_ids = np.array([[14, 12, 42, 95, 99]], np.int32)
90+
inputs = [input_ids, np.ones_like(input_ids), np.zeros_like(input_ids)]
91+
outputs = np.concatenate(
92+
[hub_layer(inputs, training=training)[0] for _ in range(num_runs)])
93+
return np.mean(np.std(outputs, axis=0))
94+
self.assertLess(_dropout_mean_stddev(training=False), 1e-6)
95+
self.assertGreater(_dropout_mean_stddev(training=True), 1e-3)
96+
8797
# Test propagation of seq_length in shape inference.
8898
input_word_ids = tf.keras.layers.Input(shape=(seq_length,), dtype=tf.int32)
8999
input_mask = tf.keras.layers.Input(shape=(seq_length,), dtype=tf.int32)

0 commit comments

Comments (0)