This repository was archived by the owner on Jan 21, 2025. It is now read-only.

Commit 0e7ce98

William Fedus authored and Mesh TensorFlow Team committed
Activation gradient clipping
PiperOrigin-RevId: 355449591
1 parent 7fb0df6 commit 0e7ce98

File tree

2 files changed: +39 −0 lines changed

mesh_tensorflow/layers.py

Lines changed: 32 additions & 0 deletions
@@ -2210,3 +2210,35 @@ def reversible_half_residual_and_swap(x1,

        [x1, x1_backwards, x2, x2_backwards])
  else:
    return _half_residual_and_swap(x1, x1_backwards, x2, x2_backwards, f)


@gin.configurable
def clip_activation_gradient(x, clip_rms_norm=None):
  """Clip activation gradients by rms-norm."""
  tf.logging.info("clip_activation_gradient.clip_rms_norm: {}".format(
      clip_rms_norm))

  def _reduce_rms(t):
    return mtf.sqrt(mtf.reduce_mean(mtf.square(t)))

  def forward_fn(x):
    """Identity forward pass."""
    return mtf.identity(x)

  def grad_fn(explicit_inputs, all_inputs, forward_operations, outputs,
              output_grads):
    del explicit_inputs, all_inputs, outputs, forward_operations

    grad_ys = output_grads
    if clip_rms_norm:
      clipped_grad_ys = []
      for g in grad_ys:
        rms_norm = _reduce_rms(g)
        clipping_denom = mtf.maximum(1.0, rms_norm / clip_rms_norm)
        clipped_grad_ys.append(g / clipping_denom)
      return clipped_grad_ys
    return grad_ys

  explicit_inputs = [x]

  return mtf.custom_gradient(forward_fn, grad_fn, explicit_inputs)
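
For readers less familiar with mtf.custom_gradient, the sketch below shows the same technique in plain TensorFlow: the forward pass is an identity, and the backward pass rescales the incoming gradient so its RMS norm never exceeds the threshold. This is an illustrative analogue only, not part of the commit; the function name and the use of tf.custom_gradient in place of mtf.custom_gradient are assumptions.

import tensorflow as tf

def clip_activation_gradient_tf(x, clip_rms_norm=1.0):
  """Plain-TF analogue (illustrative): identity forward, RMS-clipped backward."""

  @tf.custom_gradient
  def _identity_with_clipped_grad(x):
    def grad(dy):
      # Same rule as grad_fn above: divide by max(1, rms / threshold), so
      # gradients whose RMS norm is already below the threshold pass through
      # unchanged.
      rms_norm = tf.sqrt(tf.reduce_mean(tf.square(dy)))
      clipping_denom = tf.maximum(1.0, rms_norm / clip_rms_norm)
      return dy / clipping_denom
    return tf.identity(x), grad

  return _identity_with_clipped_grad(x)

For example, a gradient of [3, -4, 0, 0] has RMS norm 2.5; with a threshold of 1.0 the backward pass divides it by 2.5, bringing its RMS norm down to exactly 1.0, while a gradient already below the threshold is left untouched.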

mesh_tensorflow/transformer/transformer.py

Lines changed: 7 additions & 0 deletions
@@ -565,6 +565,13 @@ def sublayer_dropout(x, layer_stack, context, dropout_rate=0.0):

  return x


@gin.configurable
def sublayer_clip_activation_gradient(x, layer_stack, context, rms_norm=1.0):
  """Clip activation gradient by RMS-norm."""
  del layer_stack, context
  return mtf.layers.clip_activation_gradient(x, rms_norm)


@gin.configurable
def sublayer_legacy_dropout(x, layer_stack, context):
  return sublayer_dropout(x, layer_stack, context,
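
Because the new sublayer is @gin.configurable, its threshold can be bound through gin. The snippet below is a hypothetical sketch: the rms_norm binding follows the decorator above, but how the sublayer is inserted into a LayerStack's sublayer list is not shown in this commit, so that wiring is deliberately omitted.

import gin

# Hypothetical binding sketch; only the rms_norm parameter appears in this
# diff. Adding sublayer_clip_activation_gradient to a LayerStack's sublayer
# list depends on LayerStack parameters not included in this commit.
gin.parse_config("""
sublayer_clip_activation_gradient.rms_norm = 1.0
""")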
