@@ -1131,6 +1131,7 @@ def my_fn(x):
 
 def dense_relu_dense(x,
                      hidden_channels,
+                     is_training,
                      dropout=0.0,
                      dropout_broadcast_dims=None,
                      master_dtype=tf.float32,
@@ -1142,6 +1143,7 @@ def dense_relu_dense(x,
   Args:
     x: a mtf.Tensor
     hidden_channels: a mtf.Dimension - channels in the hidden layer
+    is_training: a boolean, set to True while training
     dropout: an optional float
     dropout_broadcast_dims: an optional list of mtf.Dimension
     master_dtype: a tf.dtype
@@ -1156,9 +1158,8 @@ def dense_relu_dense(x,
   h = dense(x, hidden_channels,
             use_bias=False, activation=mtf.relu,
             master_dtype=master_dtype, slice_dtype=slice_dtype, name="wi")
-  if dropout != 0.0:
-    h = mtf.dropout(h, 1.0 - dropout,
-                    noise_shape=h.shape - dropout_broadcast_dims)
+  h = mtf.dropout(h, is_training, 1.0 - dropout,
+                  noise_shape=h.shape - dropout_broadcast_dims)
   return dense(h, io_channels, use_bias=False, activation=None,
                master_dtype=master_dtype, slice_dtype=slice_dtype,
                name="wo")
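
With this change the dropout guard moves inside mtf.dropout: callers always call it and pass the training flag, and the op is expected to be a no-op when is_training is False. A minimal caller-side sketch of the updated dense_relu_dense (the tensor x, the d_ff dimension, and the mode flag are illustrative, not part of this diff):

    # Hypothetical caller: dropout is active only when is_training is True.
    d_ff = mtf.Dimension("d_ff", 2048)
    y = dense_relu_dense(x,
                         hidden_channels=d_ff,
                         is_training=(mode == tf.estimator.ModeKeys.TRAIN),
                         dropout=0.1)
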
@@ -1187,6 +1188,7 @@ def local_self_attention_spatial_blocks(
     query_antecedent,
     kv_channels,
     heads,
+    is_training,
     memory_w_dim=None,
     mask_right=False,
     master_dtype=tf.float32,
@@ -1205,6 +1207,7 @@ def local_self_attention_spatial_blocks(
       must have the same size as query_length, but a different name.
     kv_channels: a mtf.Dimension (the size of the key and value vectors)
     heads: a mtf.Dimension (the number of heads)
+    is_training: a bool, True if training, else False.
     memory_w_dim: mtf Dimension, for the memory width block.
     mask_right: bool, flag specifying whether we mask out attention to the right
       for the decoder.
@@ -1255,7 +1258,7 @@ def local_self_attention_spatial_blocks(
   mask = attention_bias_local_block(
       query_antecedent.mesh, w_dim, memory_w_dim)
 
-  output = dot_product_attention(q, k, v, mask=mask)
+  output = dot_product_attention(q, k, v, mask=mask, is_training=is_training)
 
   return mtf.einsum(
       [output, wo], mtf.Shape([batch, num_w_blocks, w_dim, io_channels]))
@@ -1264,6 +1267,7 @@ def local_self_attention_spatial_blocks(
 def masked_local_attention_1d(x,
                               kv_channels,
                               heads,
+                              is_training,
                               window_size=128,
                               master_dtype=tf.float32,
                               slice_dtype=tf.float32,
@@ -1280,6 +1284,7 @@ def masked_local_attention_1d(x,
     x: a mtf.Tensor with shape batch_dims + [length, io_channels]
     kv_channels: a mtf.Dimension (the size of the key and value vectors)
     heads: a mtf.Dimension (the number of heads)
+    is_training: a bool, True if training, else False.
     window_size: an integer
     master_dtype: a tf.dtype (deprecated - use params arg)
     slice_dtype: a tf.dtype (deprecated - use params arg)
@@ -1351,7 +1356,7 @@ def masked_local_attention_1d(x,
   # Note: The first window_size-1 positions can see back into pre-time
   # where all the keys and values are zero. We could mask this out, but we
   # don't.
-  o = dot_product_attention(q, k, v, mask=mask)
+  o = dot_product_attention(q, k, v, mask=mask, is_training=is_training)
   o = mtf.reshape(o, batch_dims + [heads, length, kv_channels])
   return mtf.einsum([o, wo], mtf.Shape(batch_dims + [length, io_channels]))
 
@@ -1408,7 +1413,7 @@ def masked_local_attention_1d_incremental(x,
       mtf.mod(step_num, window_length.size))
   k = mtf.where(current_position, k, prev_k, output_shape=prev_k.shape)
   v = mtf.where(current_position, v, prev_v, output_shape=prev_v.shape)
-  o = dot_product_attention(q, k, v, mask=None)
+  o = dot_product_attention(q, k, v, mask=None, is_training=False)
   y = mtf.einsum([o, wo], x.shape)
   return y, k, v
 
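
The two hunks above show the intended split: the regular forward pass threads the caller's is_training flag into dot_product_attention, while the incremental decoding path hard-codes is_training=False because dropout is never wanted during autoregressive inference. A caller-side sketch of the contrast (the tensors q, k, v and mask are illustrative):

    # Training graph: attention dropout is gated on the flag.
    o_train = dot_product_attention(q, k, v, mask,
                                    is_training=True, dropout=0.1)
    # Decode step: matches the hard-coded flag in the diff above.
    o_decode = dot_product_attention(q, k, v, mask=None, is_training=False)
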
@@ -1441,6 +1446,7 @@ def local_2d_halo_exchange(k, v, num_h_blocks, h_dim,
 def local_2d_self_attention_spatial_blocks(query_antecedent,
                                            kv_channels,
                                            heads,
+                                           is_training,
                                            memory_h_dim=None,
                                            memory_w_dim=None,
                                            mask_right=False,
@@ -1460,6 +1466,7 @@ def local_2d_self_attention_spatial_blocks(query_antecedent,
       query_length, but a different name.
     kv_channels: a mtf.Dimension (the size of the key and value vectors)
     heads: a mtf.Dimension (the number of heads)
+    is_training: a bool, True while training, else False.
     memory_h_dim: mtf Dimension, for the memory height block.
     memory_w_dim: mtf Dimension, for the memory width block.
     mask_right: bool, flag specifying whether we mask out attention to the right
@@ -1515,7 +1522,7 @@ def local_2d_self_attention_spatial_blocks(query_antecedent,
   mask = attention_bias_local_2d_block(query_antecedent.mesh, h_dim, w_dim,
                                        memory_h_dim, memory_w_dim)
 
-  output = dot_product_attention(q, k, v, mask=mask)
+  output = dot_product_attention(q, k, v, mask=mask, is_training=is_training)
 
   return mtf.einsum(
       [output, wo],
@@ -1592,6 +1599,7 @@ def dot_product_attention(q,
                           k,
                           v,
                           mask,
+                          is_training,
                           dropout=0.0,
                           dropout_broadcast_dims=None,
                           extra_logit=None):
@@ -1605,6 +1613,7 @@ def dot_product_attention(q,
     v: Tensor with shape [..., length_kv, depth_v] Leading dimensions must
       match with q.
     mask: mask Tensor (see attention_mask())
+    is_training: a boolean, set to True while training
     dropout: a float.
     dropout_broadcast_dims: an optional list of mtf.Dimension
     extra_logit: an optional scalar or tensor
@@ -1618,10 +1627,9 @@ def dot_product_attention(q,
   if mask is not None:
     logits += mask
   weights = mtf.softmax(logits, length_kv, extra_logit=extra_logit)
-  if dropout != 0.0:
-    weights = mtf.dropout(
-        weights, 1.0 - dropout,
-        noise_shape=weights.shape - dropout_broadcast_dims)
+  weights = mtf.dropout(
+      weights, is_training, 1.0 - dropout,
+      noise_shape=weights.shape - dropout_broadcast_dims)
   depth_v = v.shape.dims[-1]
   outputs_shape = mtf.Shape(q.shape.dims[:-1] + [depth_v])
   outputs = mtf.einsum([weights, v], outputs_shape)
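
Note that is_training is inserted before dropout in dot_product_attention's positional order, so existing positional callers must be updated (as the multihead_attention hunk below does); keyword callers are unaffected. A hedged sketch of the adjustment a downstream caller would make:

    # Before this change (old positional order):
    #   o = dot_product_attention(q, k, v, mask, dropout, dropout_broadcast_dims)
    # After this change:
    o = dot_product_attention(q, k, v, mask, is_training,
                              dropout, dropout_broadcast_dims)
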
@@ -1633,6 +1641,7 @@ def multihead_attention(query_antecedent,
                         mask,
                         kv_channels,
                         heads,
+                        is_training,
                         dropout=0.0,
                         dropout_broadcast_dims=None,
                         master_dtype=tf.float32,
@@ -1653,6 +1662,7 @@ def multihead_attention(query_antecedent,
     mask: mask Tensor (see attention_mask())
     kv_channels: a mtf.Dimension (the size of the key and value vectors)
     heads: a mtf.Dimension (the number of heads)
+    is_training: a bool, True while training, False otherwise.
     dropout: a floating point value
     dropout_broadcast_dims: an optional list of mtf.Dimension
     master_dtype: a tf.dtype
@@ -1692,7 +1702,7 @@ def multihead_attention(query_antecedent,
       [memory_antecedent, wv],
       mtf.Shape(batch_dims + [heads, memory_length, kv_channels]))
   o = dot_product_attention(
-      q, k, v, mask, dropout, dropout_broadcast_dims)
+      q, k, v, mask, is_training, dropout, dropout_broadcast_dims)
   return mtf.einsum(
       [o, wo], mtf.Shape(batch_dims + [query_length, io_channels]))
 
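
At the model level, the flag is typically derived from the run mode and passed straight through. A sketch of wiring multihead_attention with the new argument, assuming the usual (query_antecedent, memory_antecedent, mask, ...) ordering and an estimator-style mode check, neither of which is part of this diff:

    # Illustrative wiring only; `mode` comes from the surrounding model_fn.
    is_training = (mode == tf.estimator.ModeKeys.TRAIN)
    y = multihead_attention(query_antecedent, memory_antecedent, mask,
                            kv_channels, heads, is_training,
                            dropout=0.1)
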
@@ -1756,7 +1766,7 @@ def multihead_self_attention_incremental(query_antecedent,
       mtf.greater(mtf.range(
           query_antecedent.mesh, memory_length, dtype=tf.int32), step_num),
       q.dtype) * -1e9
-  o = dot_product_attention(q, k, v, mask)
+  o = dot_product_attention(q, k, v, mask, is_training=False)
   y = mtf.einsum([o, wo], query_antecedent.shape)
   return y, k, v
 
@@ -1792,7 +1802,7 @@ def multihead_encdec_attention_incremental(query_antecedent,
   q = mtf.einsum(
       [query_antecedent, wq],
       mtf.Shape(query_dims + [heads, kv_channels]))
-  o = dot_product_attention(q, k, v, mask)
+  o = dot_product_attention(q, k, v, mask, is_training=False)
   return mtf.einsum([o, wo], query_antecedent.shape)
 
 
@@ -1931,6 +1941,7 @@ def multihead_self_attention_memory_compressed(x,
                                                compression_factor,
                                                kv_channels,
                                                heads,
+                                               is_training,
                                                dropout=0.0,
                                                dropout_broadcast_dims=None,
                                                master_dtype=tf.float32,
@@ -1948,6 +1959,7 @@ def multihead_self_attention_memory_compressed(x,
     compression_factor: an integer
     kv_channels: a mtf.Dimension (the size of the key and value vectors)
     heads: a mtf.Dimension (the number of heads)
+    is_training: a boolean, set to True while training
     dropout: a floating point value
     dropout_broadcast_dims: an optional list of mtf.Dimension
     master_dtype: a tf.dtype
@@ -1989,7 +2001,8 @@ def multihead_self_attention_memory_compressed(x,
   else:
     mask = None
   o = dot_product_attention(
-      q, k, v, mask, dropout, dropout_broadcast_dims, extra_logit=0.0)
+      q, k, v, mask, is_training, dropout, dropout_broadcast_dims,
+      extra_logit=0.0)
   return mtf.einsum(
       [o, wo], mtf.Shape(batch_dims + [length, io_channels]))
 