Skip to content

Commit c01dedb

Browse files
committed
Update documents and test components.
1 parent 01e0651 commit c01dedb

File tree

3 files changed

+3
-5
lines changed

3 files changed

+3
-5
lines changed

docs/examples/plot_KimCNN_quickstart.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,6 @@
5656
model_name=model_name,
5757
network_config=network_config,
5858
classes=classes,
59-
word_dict=word_dict,
6059
embed_vecs=embed_vecs,
6160
learning_rate=learning_rate,
6261
monitor_metrics=["Micro-F1", "Macro-F1", "P@1", "P@3", "P@5"],
@@ -66,7 +65,7 @@
6665
# * ``model_name`` leads ``init_model`` function to find a network model.
6766
# * ``network_config`` contains the configurations of a network model.
6867
# * ``classes`` is the label set of the data.
69-
# * ``init_weight``, ``word_dict`` and ``embed_vecs`` are not used on a bert-base model, so we can ignore them.
68+
# * ``embed_vecs`` is the the pre-trained word vectors.
7069
# * ``moniter_metrics`` includes metrics you would like to track.
7170
#
7271
#

docs/examples/plot_bert_quickstart.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,6 @@
7070
# * ``model_name`` leads ``init_model`` function to find a network model.
7171
# * ``network_config`` contains the configurations of a network model.
7272
# * ``classes`` is the label set of the data.
73-
# * ``init_weight``, ``word_dict`` and ``embed_vecs`` are not used on a bert-base model, so we can ignore them.
7473
# * ``moniter_metrics`` includes metrics you would like to track.
7574
#
7675
#

tests/nn/components.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ def get_name(self):
2020
return "token_to_id"
2121

2222
def get_from_trainer(self, trainer):
23-
return trainer.model.word_dict
23+
return trainer.word_dict
2424

2525
def compare(self, a, b):
2626
return a == b
@@ -34,7 +34,7 @@ def get_name(self):
3434
return "embed_vecs"
3535

3636
def get_from_trainer(self, trainer):
37-
return trainer.model.embed_vecs
37+
return trainer.embed_vecs
3838

3939
def compare(self, a, b):
4040
return (a == b).all()

0 commit comments

Comments
 (0)