@@ -44,10 +44,10 @@ def test_wrong_inputs():
 )
 
 
-def _test(candidates, references, average, smooth="no_smooth", smooth_nltk_fn=None, ngram_range=8):
+def _test(candidates, references, average, smooth="no_smooth", smooth_nltk_fn=None, ngram_range=8, device="cpu"):
     for i in range(1, ngram_range):
         weights = tuple([1 / i] * i)
-        bleu = Bleu(ngram=i, average=average, smooth=smooth)
+        bleu = Bleu(ngram=i, average=average, smooth=smooth, device=device)
 
         if average == "macro":
             with warnings.catch_warnings():
@@ -64,51 +64,56 @@ def _test(candidates, references, average, smooth="no_smooth", smooth_nltk_fn=No
         assert pytest.approx(reference) == bleu._corpus_bleu(references, candidates)
 
         bleu.update((candidates, references))
-        assert pytest.approx(reference) == bleu.compute()
+        computed = bleu.compute()
+        if isinstance(computed, torch.Tensor):
+            computed = computed.cpu().item()
+        assert pytest.approx(reference) == computed
 
 
 @pytest.mark.parametrize(*parametrize_args)
-def test_macro_bleu(candidates, references):
-    _test(candidates, references, "macro")
+def test_macro_bleu(candidates, references, available_device):
+    _test(candidates, references, "macro", device=available_device)
 
 
 @pytest.mark.parametrize(*parametrize_args)
-def test_micro_bleu(candidates, references):
-    _test(candidates, references, "micro")
+def test_micro_bleu(candidates, references, available_device):
+    _test(candidates, references, "micro", device=available_device)
 
 
 @pytest.mark.parametrize(*parametrize_args)
-def test_macro_bleu_smooth1(candidates, references):
-    _test(candidates, references, "macro", "smooth1", SmoothingFunction().method1)
+def test_macro_bleu_smooth1(candidates, references, available_device):
+    _test(candidates, references, "macro", "smooth1", SmoothingFunction().method1, device=available_device)
 
 
 @pytest.mark.parametrize(*parametrize_args)
-def test_micro_bleu_smooth1(candidates, references):
-    _test(candidates, references, "micro", "smooth1", SmoothingFunction().method1)
+def test_micro_bleu_smooth1(candidates, references, available_device):
+    _test(candidates, references, "micro", "smooth1", SmoothingFunction().method1, device=available_device)
 
 
 @pytest.mark.parametrize(*parametrize_args)
-def test_macro_bleu_nltk_smooth2(candidates, references):
-    _test(candidates, references, "macro", "nltk_smooth2", SmoothingFunction().method2)
+def test_macro_bleu_nltk_smooth2(candidates, references, available_device):
+    _test(candidates, references, "macro", "nltk_smooth2", SmoothingFunction().method2, device=available_device)
 
 
 @pytest.mark.parametrize(*parametrize_args)
-def test_micro_bleu_nltk_smooth2(candidates, references):
-    _test(candidates, references, "micro", "nltk_smooth2", SmoothingFunction().method2)
+def test_micro_bleu_nltk_smooth2(candidates, references, available_device):
+    _test(candidates, references, "micro", "nltk_smooth2", SmoothingFunction().method2, device=available_device)
 
 
 @pytest.mark.parametrize(*parametrize_args)
-def test_macro_bleu_smooth2(candidates, references):
-    _test(candidates, references, "macro", "smooth2", SmoothingFunction().method2, 3)
+def test_macro_bleu_smooth2(candidates, references, available_device):
+    _test(candidates, references, "macro", "smooth2", SmoothingFunction().method2, 3, available_device)
 
 
 @pytest.mark.parametrize(*parametrize_args)
-def test_micro_bleu_smooth2(candidates, references):
-    _test(candidates, references, "micro", "smooth2", SmoothingFunction().method2, 3)
+def test_micro_bleu_smooth2(candidates, references, available_device):
+    _test(candidates, references, "micro", "smooth2", SmoothingFunction().method2, 3, device=available_device)
 
 
-def test_accumulation_macro_bleu():
-    bleu = Bleu(ngram=4, smooth="smooth2")
+def test_accumulation_macro_bleu(available_device):
+    bleu = Bleu(ngram=4, smooth="smooth2", device=available_device)
+    assert bleu._device == torch.device(available_device)
+
     bleu.update(([corpus.cand_1], [corpus.references_1]))
     bleu.update(([corpus.cand_2a], [corpus.references_2]))
     bleu.update(([corpus.cand_2b], [corpus.references_2]))
@@ -120,8 +125,10 @@ def test_accumulation_macro_bleu():
     assert bleu.compute() == value / 4
 
 
-def test_accumulation_micro_bleu():
-    bleu = Bleu(ngram=4, smooth="smooth2", average="micro")
+def test_accumulation_micro_bleu(available_device):
+    bleu = Bleu(ngram=4, smooth="smooth2", average="micro", device=available_device)
+    assert bleu._device == torch.device(available_device)
+
     bleu.update(([corpus.cand_1], [corpus.references_1]))
     bleu.update(([corpus.cand_2a], [corpus.references_2]))
     bleu.update(([corpus.cand_2b], [corpus.references_2]))
@@ -133,8 +140,9 @@ def test_accumulation_micro_bleu():
     assert bleu.compute() == value
 
 
-def test_bleu_batch_macro():
-    bleu = Bleu(ngram=4)
+def test_bleu_batch_macro(available_device):
+    bleu = Bleu(ngram=4, device=available_device)
+    assert bleu._device == torch.device(available_device)
 
     # Batch size 3
     hypotheses = [corpus.cand_1, corpus.cand_2a, corpus.cand_2b]
@@ -148,22 +156,29 @@ def test_bleu_batch_macro():
         + sentence_bleu(refs[1], hypotheses[1])
         + sentence_bleu(refs[2], hypotheses[2])
     ) / 3
-    assert pytest.approx(bleu.compute()) == reference_bleu_score
+    computed = bleu.compute()
+    if isinstance(computed, torch.Tensor):
+        computed = computed.cpu().item()
+
+    assert pytest.approx(computed) == reference_bleu_score
 
     value = 0
     for _hypotheses, _refs in zip(hypotheses, refs):
         value += bleu._sentence_bleu(_refs, _hypotheses)
         bleu.update(([_hypotheses], [_refs]))
 
     ref_1 = value / len(refs)
-    ref_2 = bleu.compute()
+    computed = bleu.compute()
+    if isinstance(computed, torch.Tensor):
+        computed = computed.cpu().item()
 
     assert pytest.approx(ref_1) == reference_bleu_score
-    assert pytest.approx(ref_2) == reference_bleu_score
+    assert pytest.approx(computed) == reference_bleu_score
 
 
-def test_bleu_batch_micro():
-    bleu = Bleu(ngram=4, average="micro")
+def test_bleu_batch_micro(available_device):
+    bleu = Bleu(ngram=4, average="micro", device=available_device)
+    assert bleu._device == torch.device(available_device)
 
     # Batch size 3
     hypotheses = [corpus.cand_1, corpus.cand_2a, corpus.cand_2b]
@@ -187,8 +202,10 @@ def test_bleu_batch_micro():
         (corpus.cand_1, corpus.references_1),
     ],
 )
-def test_n_gram_counter(candidates, references):
-    bleu = Bleu(ngram=4)
+def test_n_gram_counter(candidates, references, available_device):
+    bleu = Bleu(ngram=4, device=available_device)
+    assert bleu._device == torch.device(available_device)
+
     hyp_length, ref_length = bleu._n_gram_counter([references], [candidates], Counter(), Counter())
     assert hyp_length == len(candidates)
 
@@ -212,9 +229,9 @@ def _test_macro_distrib_integration(device):
     def update(_, i):
         return data[i + size * rank]
 
-    def _test(metric_device):
+    def _test(device):
         engine = Engine(update)
-        m = Bleu(ngram=4, smooth="smooth2")
+        m = Bleu(ngram=4, smooth="smooth2", device=device)
         m.attach(engine, "bleu")
 
         engine.run(data=list(range(size)), max_epochs=1)
@@ -256,7 +273,7 @@ def update(_, i):
 
     def _test(metric_device):
         engine = Engine(update)
-        m = Bleu(ngram=4, smooth="smooth2", average="micro")
+        m = Bleu(ngram=4, smooth="smooth2", average="micro", device=metric_device)
         m.attach(engine, "bleu")
 
         engine.run(data=list(range(size)), max_epochs=1)
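
Note: these tests rely on an `available_device` pytest fixture that is not shown in this diff (it would normally live in a shared conftest.py). A minimal sketch of such a fixture, under that assumption, is below; the project's actual fixture may cover additional device types.

# Not part of this diff: a hypothetical minimal `available_device` fixture.
import pytest
import torch


@pytest.fixture(params=["cpu"] + (["cuda"] if torch.cuda.is_available() else []))
def available_device(request):
    # Parametrizes each test over "cpu" and, when a GPU is visible, "cuda".
    return request.param

The `isinstance(computed, torch.Tensor)` guards in the diff then normalize whatever `compute()` returns (plain float or tensor on the metric's device) to a Python float, so the `pytest.approx` comparisons stay device-agnostic.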