diff --git a/mesh_tensorflow/transformer/moe.py b/mesh_tensorflow/transformer/moe.py
index 910adf04..d6bc069e 100644
--- a/mesh_tensorflow/transformer/moe.py
+++ b/mesh_tensorflow/transformer/moe.py
@@ -25,6 +25,7 @@ from __future__ import division
 from __future__ import print_function
 
+import math
 import gin
 
 import mesh_tensorflow as mtf
 import tensorflow.compat.v1 as tf
@@ -65,7 +66,10 @@ def __init__(self,
                word_embed_mode=None,
                use_second_place_expert_prob=None,
                use_second_place_expert_prob_temp=None,
-               top_n_num_experts_per_token=3):
+               top_n_num_experts_per_token=3,
+               rloo=False,
+               loss_type="load_balance",
+               p_dot_e=True):
     self._hparams = HParams(
         moe_gating=moe_gating,
         moe_num_experts=num_experts,
@@ -95,7 +99,10 @@ def __init__(self,
             use_second_place_expert_prob),
         moe_use_second_place_expert_prob_temp=(
            use_second_place_expert_prob_temp),
-        moe_top_n_num_experts_per_token=top_n_num_experts_per_token)
+        moe_top_n_num_experts_per_token=top_n_num_experts_per_token,
+        moe_rloo=rloo,
+        loss_type=loss_type,
+        p_dot_e=p_dot_e)
     self._activation = activation
 
   def call(self, context, x, losses=None):
@@ -127,7 +134,8 @@ def call(self, context, x, losses=None):
         nonpadding=context.nonpadding,
         activation=self._activation,
         num_microbatches=context.num_microbatches,
-        token_embeddings=context.input_embeddings)
+        token_embeddings=context.input_embeddings,
+        context=context)
     if context.losses is not None:
       context.losses.append(loss)
     if not has_length_dim:
@@ -202,7 +210,7 @@ def call(self, context, x, losses=None):
 def transformer_moe_layer_v1(
     inputs, output_dim, hparams, train, variable_dtype, layout=None,
     mesh_shape=None, nonpadding=None, activation=mtf.relu,
-    num_microbatches=None, token_embeddings=None):
+    num_microbatches=None, token_embeddings=None, context=None):
   """Local mixture of experts that works well on TPU.
 
   Adapted from the paper https://arxiv.org/abs/1701.06538
@@ -281,6 +289,8 @@ def transformer_moe_layer_v1(
       [batch_dim(s), length_dim, input_dim]. These are the word embeddings for
       that correspond to the inputs. These can optionally be used to make
       routing decisions.
+    context: a Context object containing extra information that layers need
+      at call time, as defined in transformer.py.
 
   Returns:
     outputs: a Tensor with shape [batch_dim(s), length_dim, output_dim]
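The three new options ride along as constructor arguments on the gin-configurable MoE layer. As a quick reference, here is a minimal sketch of constructing the layer with them; `moe.MoE1D` and `num_experts` come from the surrounding file, while the values are placeholders:

```python
from mesh_tensorflow.transformer import moe

# Sketch: exercise the new routing options together.
# - rloo=True asks for REINFORCE gradients through the router, using the
#   twin half of the duplicated batch (see call_simple below) as baseline.
# - loss_type="kl" selects KL(uniform || mean router probs) as the
#   auxiliary loss instead of the default load-balance loss.
# - p_dot_e=False builds the combine tensor from the routing mask alone,
#   so expert outputs are not scaled by the gate probability.
moe_layer = moe.MoE1D(
    num_experts=32,
    rloo=True,
    loss_type="kl",
    p_dot_e=False)
```

Because the layer is gin-configurable, the same choices can equally be made from a gin file (e.g. `MoE1D.rloo = True`).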
@@ -436,7 +446,8 @@ def transformer_moe_layer_v1(
         variable_dtype=variable_dtype,
         importance=nonpadding,
         num_microbatches=num_microbatches,
-        token_embeddings=token_embeddings)
+        token_embeddings=token_embeddings,
+        context=context)
   elif hparams.moe_gating == "ntlb":
     dispatch_tensor, combine_tensor, loss = _ntlb_gating(
         inputs=inputs,
@@ -1303,7 +1314,8 @@ def _expert_selection_gating(
 def _switch_gating(
     inputs, outer_expert_dims, experts_dim, expert_capacity_dim, hparams,
     train, variable_dtype, importance=None, name="switch_gating",
-    num_microbatches=None, token_embeddings=None):
+    num_microbatches=None, token_embeddings=None,
+    context=None):
   """Compute Switch gating."""
   # SELECT EXPERT
   if train:
@@ -1351,6 +1363,11 @@
     expert_gate = mtf.gather(raw_gates, expert_index, dim=experts_dim)
   else:
     raise ValueError("Unknown Switch gating policy %s" % policy)
+  full_expert_gate_log_probs = gate_logits / hparams.moe_switch_temperature
+  full_expert_gate_log_probs -= mtf.reduce_logsumexp(
+      full_expert_gate_log_probs, reduced_dim=experts_dim)
+  expert_gate_log_probs = mtf.gather(
+      full_expert_gate_log_probs, expert_index, dim=experts_dim)
 
   expert_mask = mtf.one_hot(expert_index, experts_dim, dtype=raw_gates.dtype)
@@ -1363,9 +1380,25 @@
     expert_gate *= mtf.cast(mtf.equal(importance, 1.0), dtype=raw_gates.dtype)
     density_1_proxy *= mtf.cast(
         mtf.equal(importance, 1.0), dtype=raw_gates.dtype)
-  loss = (
+  load_balance_loss = (
       mtf.reduce_mean(density_1_proxy * density_1) *
       float(experts_dim.size * experts_dim.size))
+
+  kl_with_uniform = (
+      - math.log(float(experts_dim.size))
+      - mtf.reduce_logsumexp(full_expert_gate_log_probs,
+                             reduced_dim=group_size_dim)
+      + math.log(float(group_size_dim.size)))
+  if importance:
+    kl_with_uniform *= mtf.cast(mtf.equal(importance, 1.0),
+                                dtype=raw_gates.dtype)
+  kl_with_uniform = mtf.reduce_mean(kl_with_uniform)
+
+  if hparams.loss_type.lower() == "kl":
+    loss = kl_with_uniform
+  else:
+    loss = load_balance_loss
+
   if num_microbatches and num_microbatches > 1:
     tf.logging.info("Dividing load-balance loss by num_microbatches={}".format(
         num_microbatches))
@@ -1373,11 +1406,14 @@
 
   # Logging
   if train:
-    entropy = mtf.reduce_sum(-raw_gates * mtf.log(raw_gates + 1e-9),
-                             reduced_dim=experts_dim)
+    entropy = mtf.reduce_sum(
+        -mtf.exp(full_expert_gate_log_probs) * full_expert_gate_log_probs,
+        reduced_dim=experts_dim)
     batch_entropy = mtf.reduce_mean(entropy)
     mtf.scalar_summary(name + "/entropy", batch_entropy)
     mtf.scalar_summary("expert_gate", mtf.reduce_mean(expert_gate))
+    mtf.scalar_summary("tempered_expert_gate",
+                       mtf.reduce_mean(mtf.exp(expert_gate_log_probs)))
 
     mask_count_experts = mtf.reduce_sum(expert_mask, output_shape=[experts_dim])
     total_routed = mtf.reduce_sum(mask_count_experts)
@@ -1389,7 +1425,25 @@
     for fraction in split_fractions:
       mtf.scalar_summary("experts/" + fraction.name.replace(":", "/"),
                          mtf.reduce_mean(fraction))
-    mtf.scalar_summary("aux_loss", mtf.reduce_mean(loss))
+    dead_expert_fraction = mtf.reduce_mean(
+        mtf.cast(mtf.equal(mask_count_experts, 0.),
+                 dtype=raw_gates.dtype))
+    mtf.scalar_summary("dead_expert_fraction",
+                       dead_expert_fraction)
+    mtf.scalar_summary("load_balancing_loss",
+                       mtf.reduce_mean(load_balance_loss))
+    mtf.scalar_summary("kl_with_uniform",
+                       mtf.reduce_mean(kl_with_uniform))
+
+    split_expert_index = mtf.rename_dimension(
+        expert_index, "batch", "batch_split")
+    first_expert_index, second_expert_index = mtf.split(
+        split_expert_index,
+        split_expert_index.shape.get_dim_by_name("batch_split"), 2)
+    duplicate_sample = mtf.reduce_mean(
+        mtf.cast(mtf.equal(first_expert_index, second_expert_index),
+                 dtype=raw_gates.dtype))
+    mtf.scalar_summary("duplicate_sample_fraction", duplicate_sample)
 
   # Add in the z_loss for router.
   if train and hparams.moe_z_loss is not None:
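The streamlined `kl_with_uniform` expression above is worth a sanity check: for a group of G tokens, logsumexp over the group turns log p(e|x) into log(G * p_bar(e)), so the per-expert term reduces to -log E - log p_bar(e), whose mean over experts is exactly KL(uniform || p_bar). A NumPy restatement (shapes and names are illustrative, not from the patch):

```python
import numpy as np
from scipy.special import logsumexp

rng = np.random.default_rng(0)
logits = rng.normal(size=(8, 4))  # [group_size, num_experts]
log_p = logits - logsumexp(logits, axis=-1, keepdims=True)
group_size, num_experts = log_p.shape

# Streamlined form used in the patch: reduce over the group first.
per_expert = (-np.log(num_experts)
              - logsumexp(log_p, axis=0)
              + np.log(group_size))
streamlined = per_expert.mean()

# Direct form: KL(uniform || mean router distribution over the group).
marginal = np.exp(log_p).mean(axis=0)
direct = np.sum((1.0 / num_experts) *
                (np.log(1.0 / num_experts) - np.log(marginal)))

assert np.isclose(streamlined, direct)
```

Unlike the load-balance loss, which mixes the hard assignment density with the probability proxy, this term penalizes a non-uniform router marginal directly in probability space.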
@@ -1421,9 +1475,16 @@
   # Mask out the experts that have overflowed expert capacity. Sparsify the
   # expert_gate.
   expert_gate *= expert_mask_flat
+  if hparams.moe_rloo:
+    expert_gate_log_probs *= expert_mask_flat
+    context.expert_gate_log_probs.append(expert_gate_log_probs)
 
-  combine_tensor = (
-      expert_gate * expert_mask_flat *
+  if hparams.p_dot_e:
+    combine_tensor = expert_gate
+  else:
+    combine_tensor = expert_mask_flat
+
+  combine_tensor *= (
       mtf.one_hot(expert_index, experts_dim, dtype=raw_gates.dtype) *
       mtf.one_hot(
           mtf.to_int32(position_in_expert),
diff --git a/mesh_tensorflow/transformer/transformer.py b/mesh_tensorflow/transformer/transformer.py
index df715b59..1f18eb39 100644
--- a/mesh_tensorflow/transformer/transformer.py
+++ b/mesh_tensorflow/transformer/transformer.py
@@ -144,7 +144,8 @@ def __init__(self,
                read_priority=None,
                inputs=None,
                encoder_inputs=None,
-               num_microbatches=1):
+               num_microbatches=1,
+               expert_gate_log_probs=None):
     """Create a context.
 
     Args:
@@ -201,6 +202,8 @@ def __init__(self,
         decoder.
       num_microbatches: integer - greater than one if the step has been
         serialized into multiple microbatches to save memory.
+      expert_gate_log_probs: an optional list of Tensors of expert gate log
+        probs. These are used to compute REINFORCE gradients.
     """
     self.model = model
     self.mesh = mesh
@@ -235,6 +238,7 @@ def __init__(self,
     self.encoder_inputs = encoder_inputs
     self.num_microbatches = num_microbatches
     self.input_embeddings = None
+    self.expert_gate_log_probs = expert_gate_log_probs
 
   @property
   def train(self):
@@ -848,6 +852,19 @@ def _compute_loss(self, context, logits, targets, output_vocab_dim):
     if self.loss_on_targets_only:
       weights *= mtf.cast(mtf.logical_not(delimited_lm_inputs_mask(targets)),
                           dtype=context.activation_dtype)
+
+    # Compute the REINFORCE surrogate loss for expert routing.
+    if context.expert_gate_log_probs:
+      log_probs = mtf.reshape(
+          mtf.add_n(context.expert_gate_log_probs), loss.shape)
+      split_loss = mtf.rename_dimension(loss, "batch", "batch_unsplit")
+      first_loss, second_loss = mtf.split(
+          split_loss, split_loss.shape.get_dim_by_name("batch_unsplit"), 2)
+      baseline = mtf.concat([second_loss, first_loss], "batch_unsplit")
+      baseline = mtf.rename_dimension(baseline, "batch_unsplit", "batch")
+      loss += mtf.stop_gradient(loss - baseline) * mtf.cast(
+          log_probs, loss.dtype)
+
     return (mtf.reduce_sum(loss * weights) /
             self.loss_denominator(targets, context.num_microbatches))
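The term added in `_compute_loss` is the standard score-function (REINFORCE) estimator with a paired baseline: because the batch is duplicated, example i in the first half and its twin in the second half differ only in their routing samples, so each one's loss can baseline the other. A toy NumPy restatement of the baseline plumbing (all names and values are illustrative):

```python
import numpy as np

# Per-example losses for a duplicated batch of 6: rows 0..2 and 3..5 are
# twin copies of the same examples, routed independently.
loss = np.array([1.0, 2.0, 3.0, 1.4, 2.1, 3.3])
log_probs = np.array([-0.7, -1.2, -0.3, -0.9, -1.0, -0.5])  # summed router log-probs

# Swap the halves so each example is baselined by its twin's loss.
first, second = np.split(loss, 2)
baseline = np.concatenate([second, first])

# Mirrors `loss += stop_gradient(loss - baseline) * log_probs`: the
# advantage is treated as a constant, so the only gradient this term
# contributes is (loss - baseline) * d(log_probs)/d(theta), the REINFORCE
# gradient for the router parameters.
advantage = loss - baseline
surrogate = loss + advantage * log_probs
```

Since the twins share data, the baseline is correlated with the loss and cancels most of its variance, in the spirit of a leave-one-out (RLOO) baseline.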
@@ -1007,6 +1024,27 @@ def call_simple(self,
       logits: a Tensor with shape [<batch_dims>, output_vocab_dim]
       loss: an optional Scalar (if compute_loss=True)
     """
+    if mode == tf.estimator.ModeKeys.TRAIN:
+
+      def duplicate_batch(t, batch_dim_name="batch"):
+        if t:
+          # Assumes that the batch size is divisible by 2.
+          half_batch_size = t.shape.get_dim_by_name(batch_dim_name).size // 2
+          t = mtf.rename_dimension(t, batch_dim_name, batch_dim_name + "_slice")
+          half_batch = mtf.slice(t, 0, half_batch_size,
+                                 batch_dim_name + "_slice")
+          t = mtf.concat([half_batch, half_batch], batch_dim_name + "_slice")
+          return mtf.rename_dimension(t, batch_dim_name + "_slice",
+                                      batch_dim_name)
+        else:
+          return t
+
+      inputs = duplicate_batch(inputs)
+      targets = duplicate_batch(targets)
+      sequence_id = duplicate_batch(sequence_id)
+      position = duplicate_batch(position)
+      encoder_sequence_id = duplicate_batch(encoder_sequence_id)
+
     batch_dims = inputs.shape.dims[:-1]
     length_dim = inputs.shape.dims[-1]
     length_range = mtf.range(inputs.mesh, length_dim, dtype=tf.int32)
@@ -1061,7 +1099,8 @@ def call_simple(self,
         read_priority=read_priority,
         inputs=inputs,
         encoder_inputs=encoder_inputs,
-        num_microbatches=num_microbatches)
+        num_microbatches=num_microbatches,
+        expert_gate_log_probs=[])
     with tf.variable_scope(self.name):
       logits = self._call_internal(context, inputs, targets)
     if compute_loss:
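For intuition, `duplicate_batch` keeps only the first half of the incoming batch and repeats it, so rows i and i + B/2 of every duplicated tensor are identical twins; the `duplicate_sample_fraction` summary in `_switch_gating` then measures how often the twins pick the same expert. A plain-NumPy restatement (illustrative only, not the mtf API):

```python
import numpy as np

def duplicate_batch(t):
    # Mirror of the mtf version: keep the first half of the batch and
    # repeat it, so rows i and i + batch/2 are twins. Assumes the batch
    # size is divisible by 2.
    half = t[: t.shape[0] // 2]
    return np.concatenate([half, half], axis=0)

x = np.arange(8).reshape(4, 2)
print(duplicate_batch(x))
# [[0 1]
#  [2 3]
#  [0 1]
#  [2 3]]
```

Note that the second half of the original batch is discarded, so each training step sees half as many distinct examples; that is the price paid for the paired RLOO-style baseline used in `_compute_loss`.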