17
17
18
18
Includes configurations and instantiation methods.
19
19
"""
20
-
21
20
import dataclasses
21
+ import gin
22
22
import tensorflow as tf
23
23
24
24
from official .modeling import tf_utils
@@ -42,10 +42,43 @@ class TransformerEncoderConfig(base_config.Config):
42
42
initializer_range : float = 0.02
43
43
44
44
45
- def instantiate_encoder_from_cfg (
46
- config : TransformerEncoderConfig ) -> networks .TransformerEncoder :
45
+ @gin .configurable
46
+ def instantiate_encoder_from_cfg (config : TransformerEncoderConfig ,
47
+ encoder_cls = networks .TransformerEncoder ):
47
48
"""Instantiate a Transformer encoder network from TransformerEncoderConfig."""
48
- encoder_network = networks .TransformerEncoder (
49
+ if encoder_cls .__name__ == "EncoderScaffold" :
50
+ embedding_cfg = dict (
51
+ vocab_size = config .vocab_size ,
52
+ type_vocab_size = config .type_vocab_size ,
53
+ hidden_size = config .hidden_size ,
54
+ seq_length = None ,
55
+ max_seq_length = config .max_position_embeddings ,
56
+ initializer = tf .keras .initializers .TruncatedNormal (
57
+ stddev = config .initializer_range ),
58
+ dropout_rate = config .dropout_rate ,
59
+ )
60
+ hidden_cfg = dict (
61
+ num_attention_heads = config .num_attention_heads ,
62
+ intermediate_size = config .intermediate_size ,
63
+ intermediate_activation = tf_utils .get_activation (
64
+ config .hidden_activation ),
65
+ dropout_rate = config .dropout_rate ,
66
+ attention_dropout_rate = config .attention_dropout_rate ,
67
+ kernel_initializer = tf .keras .initializers .TruncatedNormal (
68
+ stddev = config .initializer_range ),
69
+ )
70
+ kwargs = dict (
71
+ embedding_cfg = embedding_cfg ,
72
+ hidden_cfg = hidden_cfg ,
73
+ num_hidden_instances = config .num_layers ,
74
+ pooled_output_dim = config .hidden_size ,
75
+ pooler_layer_initializer = tf .keras .initializers .TruncatedNormal (
76
+ stddev = config .initializer_range ))
77
+ return encoder_cls (** kwargs )
78
+
79
+ if encoder_cls .__name__ != "TransformerEncoder" :
80
+ raise ValueError ("Unknown encoder network class. %s" % str (encoder_cls ))
81
+ encoder_network = encoder_cls (
49
82
vocab_size = config .vocab_size ,
50
83
hidden_size = config .hidden_size ,
51
84
num_layers = config .num_layers ,
0 commit comments