test_issue3611.py
# coding: utf8
from __future__ import unicode_literals

import spacy
from spacy.util import minibatch, compounding


def test_issue3611():
    """Test that adding n-grams in the textcat works even when n > token length of some docs."""
    unique_classes = ["offensive", "inoffensive"]
    x_train = [
        "This is an offensive text",
        "This is the second offensive text",
        "inoff",  # a single-token doc, shorter than the bigram size used below
    ]
    y_train = ["offensive", "offensive", "inoffensive"]

    # prepare the data: one boolean dict of category labels per training text
    pos_cats = list()
    for train_instance in y_train:
        pos_cats.append({label: label == train_instance for label in unique_classes})
    train_data = list(zip(x_train, [{"cats": cats} for cats in pos_cats]))

    # set up the spaCy model with a text categorizer component
    nlp = spacy.blank("en")
    textcat = nlp.create_pipe(
        "textcat",
        config={"exclusive_classes": True, "architecture": "bow", "ngram_size": 2},
    )
    for label in unique_classes:
        textcat.add_label(label)
    nlp.add_pipe(textcat, last=True)

    # train the network
    with nlp.disable_pipes([p for p in nlp.pipe_names if p != "textcat"]):
        optimizer = nlp.begin_training()
        for i in range(3):
            losses = {}
            batches = minibatch(train_data, size=compounding(4.0, 32.0, 1.001))
            for batch in batches:
                texts, annotations = zip(*batch)
                nlp.update(
                    docs=texts,
                    golds=annotations,
                    sgd=optimizer,
                    drop=0.1,
                    losses=losses,
                )
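
For reference, a minimal sketch (not part of the original test, and assuming spaCy v2.x) of how predictions from the trained pipeline could be inspected, continuing from the `nlp` object built above; the textcat component writes one score per label into `doc.cats`:

    # A minimal sketch, not part of the original test; assumes spaCy v2.x and
    # continues from the trained `nlp` object above. Calling the pipeline runs
    # the textcat component, which stores class scores in doc.cats.
    doc = nlp("This is a third offensive text")
    print(doc.cats)  # e.g. {"offensive": 0.9, "inoffensive": 0.1}; exact scores vary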