Explicitly check for OPT activation function (#3278)
Co-authored-by: Michael Wyatt <[email protected]>
Co-authored-by: Jeff Rasley <[email protected]>
3 people authored Apr 21, 2023
1 parent 145c3a7 commit 793c23e
Showing 1 changed file with 13 additions and 5 deletions.

deepspeed/module_inject/containers/opt.py
@@ -72,18 +72,26 @@ class HFOPTLayerPolicy(TransformerPolicy):
     _orig_layer_class = None
 
     def __init__(self, client_module, inference=True, use_load_prefix=True):
-        super().__init__(inference,
-                         linear_layer=True,
-                         mlp_act_func_type=ActivationFuncType.ReLU,
-                         pre_attn_norm=True,
-                         use_load_prefix=use_load_prefix)
+        super().__init__(inference, linear_layer=True, pre_attn_norm=True, use_load_prefix=use_load_prefix)
         self.client_module = client_module
         try:
             import transformers
             HFOPTLayerPolicy._orig_layer_class = transformers.models.opt.modeling_opt.OPTDecoderLayer
         except:
             HFOPTLayerPolicy._orig_layer_class = None
 
+        if hasattr(TransformerPolicy, "hf_model_config") and hasattr(TransformerPolicy.hf_model_config,
+                                                                     "activation_function"):
+            if TransformerPolicy.hf_model_config.activation_function == "relu":
+                self.mlp_act_func_type = ActivationFuncType.ReLU
+            elif TransformerPolicy.hf_model_config.activation_function in ["gelu", "gelu_new"]:
+                self.mlp_act_func_type = ActivationFuncType.GELU
+            else:
+                raise ValueError("Unsupported activation function: {}".format(
+                    TransformerPolicy.hf_model_config.activation_function))
+        else:
+            self.mlp_act_func_type = ActivationFuncType.ReLU  # default
+
     def get_hidden_heads(self):
         return self.client_module.self_attn.embed_dim, \
                self.client_module.self_attn.num_heads, \
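For reference, the selection logic added above can be exercised on its own. The sketch below is not part of the commit: ActivationFuncType is a local stand-in for DeepSpeed's enum of the same name, the helper name is hypothetical, and it reads activation_function from a Hugging Face OPTConfig (assumes transformers is installed) so the mapping can be checked without loading a model.

# Standalone sketch (not part of the commit): mirrors the activation check above.
# ActivationFuncType here is a local stand-in for DeepSpeed's enum; the helper
# name is hypothetical. Assumes `transformers` is installed.
from enum import IntEnum

from transformers import OPTConfig


class ActivationFuncType(IntEnum):
    GELU = 1
    ReLU = 2


def mlp_act_from_config(hf_model_config, default=ActivationFuncType.ReLU):
    # Fall back to the default when the config lacks `activation_function`.
    act = getattr(hf_model_config, "activation_function", None)
    if act is None:
        return default
    if act == "relu":
        return ActivationFuncType.ReLU
    if act in ["gelu", "gelu_new"]:
        return ActivationFuncType.GELU
    raise ValueError("Unsupported activation function: {}".format(act))


print(mlp_act_from_config(OPTConfig(activation_function="relu")))  # ActivationFuncType.ReLU
print(mlp_act_from_config(OPTConfig(activation_function="gelu")))  # ActivationFuncType.GELU

The point of the change is visible in the else branches: OPT checkpoints default to ReLU, but OPT-architecture variants configured with a GELU activation now take the GELU path because the policy reads the value from hf_model_config instead of hard-coding ReLU.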
