Skip to content

Commit d8f81e7

Browse files
joezougtensorflower-gardener
authored andcommitted
No public description
PiperOrigin-RevId: 672821554
1 parent 2333570 commit d8f81e7

File tree

1 file changed

+30
-0
lines changed

1 file changed

+30
-0
lines changed

official/nlp/modeling/networks/mobile_bert_encoder.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,36 @@ def __init__(self,
164164
attention_scores=all_attention_scores)
165165
super().__init__(
166166
inputs=self.inputs, outputs=outputs, **kwargs)
167+
self._config = dict(
168+
name=self.name,
169+
word_vocab_size=word_vocab_size,
170+
word_embed_size=word_embed_size,
171+
type_vocab_size=type_vocab_size,
172+
max_sequence_length=max_sequence_length,
173+
num_blocks=num_blocks,
174+
hidden_size=hidden_size,
175+
num_attention_heads=num_attention_heads,
176+
intermediate_size=intermediate_size,
177+
intermediate_act_fn=intermediate_act_fn,
178+
hidden_dropout_prob=hidden_dropout_prob,
179+
attention_probs_dropout_prob=attention_probs_dropout_prob,
180+
intra_bottleneck_size=intra_bottleneck_size,
181+
initializer_range=initializer_range,
182+
use_bottleneck_attention=use_bottleneck_attention,
183+
key_query_shared_bottleneck=key_query_shared_bottleneck,
184+
num_feedforward_networks=num_feedforward_networks,
185+
normalization_type=normalization_type,
186+
classifier_activation=classifier_activation,
187+
input_mask_dtype=input_mask_dtype,
188+
**kwargs,
189+
)
190+
191+
def get_config(self):
192+
return dict(self._config)
193+
194+
@classmethod
195+
def from_config(cls, config):
196+
return cls(**config)
167197

168198
def get_embedding_table(self):
169199
return self.embedding_layer.word_embedding.embeddings

0 commit comments

Comments
 (0)