Skip to content

Commit 9ec6f6e

Browse files
saberkuntensorflower-gardener
authored andcommitted
Internal change
PiperOrigin-RevId: 320641255
1 parent 45ab8e7 commit 9ec6f6e

File tree

1 file changed

+37
-4
lines changed

1 file changed

+37
-4
lines changed

official/nlp/configs/encoders.py

+37-4
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@
1717
1818
Includes configurations and instantiation methods.
1919
"""
20-
2120
import dataclasses
21+
import gin
2222
import tensorflow as tf
2323

2424
from official.modeling import tf_utils
@@ -42,10 +42,43 @@ class TransformerEncoderConfig(base_config.Config):
4242
initializer_range: float = 0.02
4343

4444

45-
def instantiate_encoder_from_cfg(
46-
config: TransformerEncoderConfig) -> networks.TransformerEncoder:
45+
@gin.configurable
46+
def instantiate_encoder_from_cfg(config: TransformerEncoderConfig,
47+
encoder_cls=networks.TransformerEncoder):
4748
"""Instantiate a Transformer encoder network from TransformerEncoderConfig."""
48-
encoder_network = networks.TransformerEncoder(
49+
if encoder_cls.__name__ == "EncoderScaffold":
50+
embedding_cfg = dict(
51+
vocab_size=config.vocab_size,
52+
type_vocab_size=config.type_vocab_size,
53+
hidden_size=config.hidden_size,
54+
seq_length=None,
55+
max_seq_length=config.max_position_embeddings,
56+
initializer=tf.keras.initializers.TruncatedNormal(
57+
stddev=config.initializer_range),
58+
dropout_rate=config.dropout_rate,
59+
)
60+
hidden_cfg = dict(
61+
num_attention_heads=config.num_attention_heads,
62+
intermediate_size=config.intermediate_size,
63+
intermediate_activation=tf_utils.get_activation(
64+
config.hidden_activation),
65+
dropout_rate=config.dropout_rate,
66+
attention_dropout_rate=config.attention_dropout_rate,
67+
kernel_initializer=tf.keras.initializers.TruncatedNormal(
68+
stddev=config.initializer_range),
69+
)
70+
kwargs = dict(
71+
embedding_cfg=embedding_cfg,
72+
hidden_cfg=hidden_cfg,
73+
num_hidden_instances=config.num_layers,
74+
pooled_output_dim=config.hidden_size,
75+
pooler_layer_initializer=tf.keras.initializers.TruncatedNormal(
76+
stddev=config.initializer_range))
77+
return encoder_cls(**kwargs)
78+
79+
if encoder_cls.__name__ != "TransformerEncoder":
80+
raise ValueError("Unknown encoder network class. %s" % str(encoder_cls))
81+
encoder_network = encoder_cls(
4982
vocab_size=config.vocab_size,
5083
hidden_size=config.hidden_size,
5184
num_layers=config.num_layers,

0 commit comments

Comments
 (0)