Implement layernorm eps and acc_aux_loss #10

Open · wants to merge 3 commits into base: main

6 changes: 3 additions & 3 deletions moduleformer/configuration_moduleformer.py
@@ -39,8 +39,6 @@ class ModuleFormerConfig(PretrainedConfig):
             Number of hidden layers in the Transformer encoder.
         n_head (`int`, *optional*, defaults to 16):
             Number of attention heads for each attention layer in the Transformer encoder.
-        rotary_dim (`int`, *optional*, defaults to 64):
-            Number of dimensions in the embedding that Rotary Position Embedding is applied to.
         n_inner (`int`, *optional*, defaults to None):
             Dimensionality of the inner feed-forward layers. `None` will set it to 4 times n_embd
         activation_function (`str`, *optional*, defaults to `"gelu_new"`):
@@ -109,6 +107,7 @@ def __init__(
         tie_word_embeddings=False,
         aux_loss_type = 'mi',
         aux_loss_weight=0,
+        acc_aux_loss=False,
         gate_type = "mlp",
         **kwargs,
     ):
@@ -125,7 +124,7 @@ def __init__(
         self.embd_pdrop = embd_pdrop
         self.attn_pdrop = attn_pdrop
         self.moe_pdrop = moe_pdrop
-        self.layer_norm_epsilon = layer_norm_epsilon
+        self.layer_norm_epsilon = float(layer_norm_epsilon)
         self.initializer_range = initializer_range
         self.use_cache = use_cache
         self.sample_topk = sample_topk
@@ -136,6 +135,7 @@ def __init__(
         self.k_mlp = k_mlp
         self.aux_loss_type = aux_loss_type
         self.aux_loss_weight = aux_loss_weight
+        self.acc_aux_loss = acc_aux_loss
         self.gate_type = gate_type
         self.n_ctx = history_length * n_layer

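For reviewers, here is a minimal sketch of how the two configuration changes surface to users, assuming the package exports `ModuleFormerConfig` at the top level (the values below are illustrative, not defaults from this PR): the `float()` cast means an epsilon that arrives as a string, e.g. after a JSON round trip, no longer propagates as `str`, and `acc_aux_loss` is now stored on the config for the modeling code to read.

```python
# Illustrative sketch only; assumes `from moduleformer import ModuleFormerConfig`
# works (the top-level export is not shown in this diff).
from moduleformer import ModuleFormerConfig

config = ModuleFormerConfig(
    layer_norm_epsilon="1e-6",  # hypothetical string input, e.g. from a JSON config
    acc_aux_loss=True,          # new flag introduced by this PR
)

# float() in __init__ normalizes the epsilon, so downstream nn.LayerNorm
# construction receives a float rather than a string.
assert isinstance(config.layer_norm_epsilon, float)
assert config.acc_aux_loss is True
```
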
10 changes: 5 additions & 5 deletions moduleformer/modeling_moduleformer.py
@@ -82,7 +82,7 @@ def __init__(self, config):
             head_size=config.att_hidden,
             num_experts=config.n_att_experts,
             top_k=config.k_att,
-            acc_aux_loss=False,
+            acc_aux_loss=config.acc_aux_loss,
             bias=False,
             gating_dropout=config.moe_pdrop,
             sample_topk=config.sample_topk,
@@ -207,17 +207,17 @@ def __init__(self, config):
             config: Configuration object with model hyperparameters.
         """
         super().__init__()
-        self.ln_1 = nn.LayerNorm(config.n_embd)
+        self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
         self.attn = ModuleFormerAttention(config)
-        self.ln_2 = nn.LayerNorm(config.n_embd)
+        self.ln_2 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
         self.mlpf = MoE(
             input_size=config.n_embd,
             head_size=config.ffd_hidden,
             num_experts=config.n_mlp_experts,
             top_k=config.k_mlp,
             bias=False,
             activation=get_activation(config.activation_function),
-            acc_aux_loss=False,
+            acc_aux_loss=config.acc_aux_loss,
             gating_dropout=config.moe_pdrop,
             sample_topk=config.sample_topk,
             gating_size=config.gating_size,
@@ -425,7 +425,7 @@ def __init__(self, config):
         self.wte = nn.Embedding(config.vocab_size, config.n_embd)
         self.drop = nn.Dropout(config.embd_pdrop)
         self.h = nn.ModuleList([ModuleFormerBlock(config) for _ in range(config.n_layer)])
-        self.ln_f = nn.LayerNorm(config.n_embd)
+        self.ln_f = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)

         # Initialize weights and apply final processing
         self.post_init()
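A quick way to check the LayerNorm wiring after this change is to walk the model's modules and compare each `eps` against the config. This is only a sketch: it assumes `ModuleFormerConfig` and `ModuleFormerModel` are exported by the package, with the model class name inferred from the file being patched rather than shown in this diff.

```python
# Sketch of a sanity check; ModuleFormerModel is an assumed export,
# inferred from modeling_moduleformer.py rather than shown in this diff.
import torch.nn as nn
from moduleformer import ModuleFormerConfig, ModuleFormerModel

config = ModuleFormerConfig(layer_norm_epsilon=1e-6)
model = ModuleFormerModel(config)

# Before this PR, nn.LayerNorm(config.n_embd) silently fell back to
# PyTorch's default eps of 1e-5, ignoring layer_norm_epsilon.
for name, module in model.named_modules():
    if isinstance(module, nn.LayerNorm):
        assert module.eps == config.layer_norm_epsilon, name
```
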
2 changes: 1 addition & 1 deletion moduleformer/utils/moe.py
@@ -26,7 +26,7 @@ class MoE(nn.Module):
         gating_dropout: a float - dropout rate for gating network
         sample_topk: an integer - how many experts to sample during training
         gating_size: an integer - size of the gating network
-        aux_loss: a string - type of auxiliary loss ('mi' or 'sparse')
+        aux_loss: a string - type of auxiliary loss ('mi' or 'switch')
         gate_type: a string - type of gating mechanism ('mlp' or 'topk')
     """

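The exact semantics of `acc_aux_loss` live inside the MoE implementation and are not part of this diff. The sketch below only illustrates the accumulate-then-drain pattern such a flag typically toggles; the names (`ToyGate`, `get_aux_loss_and_clear`) and the stand-in penalty are hypothetical and are not taken from `moduleformer/utils/moe.py`.

```python
# Hypothetical illustration of an accumulate-then-drain auxiliary loss,
# the pattern a flag like acc_aux_loss usually selects. Names and the
# stand-in penalty are illustrative, not copied from moe.py.
import torch


class ToyGate(torch.nn.Module):
    def __init__(self, acc_aux_loss: bool = False):
        super().__init__()
        self.acc_aux_loss = acc_aux_loss
        self.acc_loss = 0.0

    def forward(self, gate_probs: torch.Tensor) -> torch.Tensor:
        # Stand-in balancing penalty: variance of the mean load per expert.
        aux = gate_probs.mean(dim=0).var()
        if self.acc_aux_loss:
            # Accumulate across layers / micro-batches; the trainer drains
            # one total later instead of adding a term at every call.
            self.acc_loss = self.acc_loss + aux
            return torch.zeros((), device=gate_probs.device)
        return aux  # per-call mode: report the penalty immediately

    def get_aux_loss_and_clear(self) -> torch.Tensor:
        total = torch.as_tensor(self.acc_loss)
        self.acc_loss = 0.0
        return total
```

In per-call mode a trainer adds the returned term to the loss at every MoE layer; in accumulated mode it adds a single drained total after the forward pass, which is why this PR forwards the config-level flag to both the attention and MLP MoE modules.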