@@ -164,6 +164,36 @@ def __init__(self,
164
164
attention_scores = all_attention_scores )
165
165
super ().__init__ (
166
166
inputs = self .inputs , outputs = outputs , ** kwargs )
167
+ self ._config = dict (
168
+ name = self .name ,
169
+ word_vocab_size = word_vocab_size ,
170
+ word_embed_size = word_embed_size ,
171
+ type_vocab_size = type_vocab_size ,
172
+ max_sequence_length = max_sequence_length ,
173
+ num_blocks = num_blocks ,
174
+ hidden_size = hidden_size ,
175
+ num_attention_heads = num_attention_heads ,
176
+ intermediate_size = intermediate_size ,
177
+ intermediate_act_fn = intermediate_act_fn ,
178
+ hidden_dropout_prob = hidden_dropout_prob ,
179
+ attention_probs_dropout_prob = attention_probs_dropout_prob ,
180
+ intra_bottleneck_size = intra_bottleneck_size ,
181
+ initializer_range = initializer_range ,
182
+ use_bottleneck_attention = use_bottleneck_attention ,
183
+ key_query_shared_bottleneck = key_query_shared_bottleneck ,
184
+ num_feedforward_networks = num_feedforward_networks ,
185
+ normalization_type = normalization_type ,
186
+ classifier_activation = classifier_activation ,
187
+ input_mask_dtype = input_mask_dtype ,
188
+ ** kwargs ,
189
+ )
190
+
191
+ def get_config (self ):
192
+ return dict (self ._config )
193
+
194
+ @classmethod
195
+ def from_config (cls , config ):
196
+ return cls (** config )
167
197
168
198
def get_embedding_table (self ):
169
199
return self .embedding_layer .word_embedding .embeddings
0 commit comments