
Commit 2e280a5

BanzaiTokyo and vfdev-5 authored
adds available device to nlp tests #3335 (#3385)
* adds available device to nlp tests #3335
* avoiding float64
* converts candidates and references to float32 on MPS
* more conversions to float32 on MPS
* rolls back some unnecessary conversions to float32
* trying to make tests pass
* rollback to previously passing tests
* rollback _n_gram_counter parameter type change
* in bleu.py do not use torch.double
* clean up
* sets dtype in bleu.py
* adds return type to Bleu.compute
* removes unnecessary conversion
* typing
* typing
* transfer tensors in tests to cpu

---------

Co-authored-by: vfdev <[email protected]>
1 parent 1ee848f commit 2e280a5

File tree

3 files changed: +63 -43 lines changed

* ignite/metrics/nlp/bleu.py
* tests/ignite/metrics/nlp/test_bleu.py
* tests/ignite/metrics/nlp/test_rouge.py

ignite/metrics/nlp/bleu.py (+6 -4)
```diff
@@ -2,6 +2,7 @@
 from typing import Any, Callable, Sequence, Tuple, Union
 
 import torch
+from torch import Tensor
 
 from ignite.exceptions import NotComputableError
 from ignite.metrics.metric import Metric, reinit__is_reduced, sync_all_reduce
@@ -236,12 +237,12 @@ def _corpus_bleu(self, references: Sequence[Sequence[Sequence[Any]]], candidates
     @reinit__is_reduced
     def reset(self) -> None:
         if self.average == "macro":
-            self._sum_of_bleu = torch.tensor(0.0, dtype=torch.double, device=self._device)
+            self._sum_of_bleu = torch.tensor(0.0, dtype=self._double_dtype, device=self._device)
             self._num_sentences = 0
 
         if self.average == "micro":
-            self.p_numerators = torch.zeros(self.ngrams_order + 1)
-            self.p_denominators = torch.zeros(self.ngrams_order + 1)
+            self.p_numerators = torch.zeros(self.ngrams_order + 1, dtype=self._double_dtype)
+            self.p_denominators = torch.zeros(self.ngrams_order + 1, dtype=self._double_dtype)
             self.hyp_length_sum = 0
             self.ref_length_sum = 0
 
@@ -278,8 +279,9 @@ def _compute_micro(self) -> float:
         )
         return bleu_score
 
-    def compute(self) -> None:
+    def compute(self) -> Union[None, Tensor, float]:
         if self.average == "macro":
             return self._compute_macro()
         elif self.average == "micro":
             return self._compute_micro()
+        return None
```
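The new code path depends on a `_double_dtype` attribute that this diff does not define, so it presumably comes from the `Metric` base class. A minimal sketch of the assumed selection logic (the helper name `_resolve_double_dtype` is hypothetical):

```python
import torch

def _resolve_double_dtype(device: torch.device) -> torch.dtype:
    # Assumption: Apple's MPS backend cannot allocate float64 tensors,
    # so a metric created on MPS accumulates in float32; every other
    # device keeps full double precision.
    return torch.float32 if device.type == "mps" else torch.float64

# Example: _resolve_double_dtype(torch.device("mps")) -> torch.float32
```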

tests/ignite/metrics/nlp/test_bleu.py (+52 -35)
```diff
@@ -44,10 +44,10 @@ def test_wrong_inputs():
     )
 
 
-def _test(candidates, references, average, smooth="no_smooth", smooth_nltk_fn=None, ngram_range=8):
+def _test(candidates, references, average, smooth="no_smooth", smooth_nltk_fn=None, ngram_range=8, device="cpu"):
     for i in range(1, ngram_range):
         weights = tuple([1 / i] * i)
-        bleu = Bleu(ngram=i, average=average, smooth=smooth)
+        bleu = Bleu(ngram=i, average=average, smooth=smooth, device=device)
 
         if average == "macro":
             with warnings.catch_warnings():
@@ -64,51 +64,56 @@ def _test(candidates, references, average, smooth="no_smooth", smooth_nltk_fn=No
         assert pytest.approx(reference) == bleu._corpus_bleu(references, candidates)
 
         bleu.update((candidates, references))
-        assert pytest.approx(reference) == bleu.compute()
+        computed = bleu.compute()
+        if isinstance(computed, torch.Tensor):
+            computed = computed.cpu().item()
+        assert pytest.approx(reference) == computed
 
 
 @pytest.mark.parametrize(*parametrize_args)
-def test_macro_bleu(candidates, references):
-    _test(candidates, references, "macro")
+def test_macro_bleu(candidates, references, available_device):
+    _test(candidates, references, "macro", device=available_device)
 
 
 @pytest.mark.parametrize(*parametrize_args)
-def test_micro_bleu(candidates, references):
-    _test(candidates, references, "micro")
+def test_micro_bleu(candidates, references, available_device):
+    _test(candidates, references, "micro", device=available_device)
 
 
 @pytest.mark.parametrize(*parametrize_args)
-def test_macro_bleu_smooth1(candidates, references):
-    _test(candidates, references, "macro", "smooth1", SmoothingFunction().method1)
+def test_macro_bleu_smooth1(candidates, references, available_device):
+    _test(candidates, references, "macro", "smooth1", SmoothingFunction().method1, device=available_device)
 
 
 @pytest.mark.parametrize(*parametrize_args)
-def test_micro_bleu_smooth1(candidates, references):
-    _test(candidates, references, "micro", "smooth1", SmoothingFunction().method1)
+def test_micro_bleu_smooth1(candidates, references, available_device):
+    _test(candidates, references, "micro", "smooth1", SmoothingFunction().method1, device=available_device)
 
 
 @pytest.mark.parametrize(*parametrize_args)
-def test_macro_bleu_nltk_smooth2(candidates, references):
-    _test(candidates, references, "macro", "nltk_smooth2", SmoothingFunction().method2)
+def test_macro_bleu_nltk_smooth2(candidates, references, available_device):
+    _test(candidates, references, "macro", "nltk_smooth2", SmoothingFunction().method2, device=available_device)
 
 
 @pytest.mark.parametrize(*parametrize_args)
-def test_micro_bleu_nltk_smooth2(candidates, references):
-    _test(candidates, references, "micro", "nltk_smooth2", SmoothingFunction().method2)
+def test_micro_bleu_nltk_smooth2(candidates, references, available_device):
+    _test(candidates, references, "micro", "nltk_smooth2", SmoothingFunction().method2, device=available_device)
 
 
 @pytest.mark.parametrize(*parametrize_args)
-def test_macro_bleu_smooth2(candidates, references):
-    _test(candidates, references, "macro", "smooth2", SmoothingFunction().method2, 3)
+def test_macro_bleu_smooth2(candidates, references, available_device):
+    _test(candidates, references, "macro", "smooth2", SmoothingFunction().method2, 3, available_device)
 
 
 @pytest.mark.parametrize(*parametrize_args)
-def test_micro_bleu_smooth2(candidates, references):
-    _test(candidates, references, "micro", "smooth2", SmoothingFunction().method2, 3)
+def test_micro_bleu_smooth2(candidates, references, available_device):
+    _test(candidates, references, "micro", "smooth2", SmoothingFunction().method2, 3, device=available_device)
 
 
-def test_accumulation_macro_bleu():
-    bleu = Bleu(ngram=4, smooth="smooth2")
+def test_accumulation_macro_bleu(available_device):
+    bleu = Bleu(ngram=4, smooth="smooth2", device=available_device)
+    assert bleu._device == torch.device(available_device)
+
     bleu.update(([corpus.cand_1], [corpus.references_1]))
     bleu.update(([corpus.cand_2a], [corpus.references_2]))
     bleu.update(([corpus.cand_2b], [corpus.references_2]))
@@ -120,8 +125,10 @@ def test_accumulation_macro_bleu():
     assert bleu.compute() == value / 4
 
 
-def test_accumulation_micro_bleu():
-    bleu = Bleu(ngram=4, smooth="smooth2", average="micro")
+def test_accumulation_micro_bleu(available_device):
+    bleu = Bleu(ngram=4, smooth="smooth2", average="micro", device=available_device)
+    assert bleu._device == torch.device(available_device)
+
     bleu.update(([corpus.cand_1], [corpus.references_1]))
     bleu.update(([corpus.cand_2a], [corpus.references_2]))
     bleu.update(([corpus.cand_2b], [corpus.references_2]))
@@ -133,8 +140,9 @@ def test_accumulation_micro_bleu():
     assert bleu.compute() == value
 
 
-def test_bleu_batch_macro():
-    bleu = Bleu(ngram=4)
+def test_bleu_batch_macro(available_device):
+    bleu = Bleu(ngram=4, device=available_device)
+    assert bleu._device == torch.device(available_device)
 
     # Batch size 3
     hypotheses = [corpus.cand_1, corpus.cand_2a, corpus.cand_2b]
@@ -148,22 +156,29 @@ def test_bleu_batch_macro():
         + sentence_bleu(refs[1], hypotheses[1])
         + sentence_bleu(refs[2], hypotheses[2])
     ) / 3
-    assert pytest.approx(bleu.compute()) == reference_bleu_score
+    computed = bleu.compute()
+    if isinstance(computed, torch.Tensor):
+        computed = computed.cpu().item()
+
+    assert pytest.approx(computed) == reference_bleu_score
 
     value = 0
     for _hypotheses, _refs in zip(hypotheses, refs):
         value += bleu._sentence_bleu(_refs, _hypotheses)
         bleu.update(([_hypotheses], [_refs]))
 
     ref_1 = value / len(refs)
-    ref_2 = bleu.compute()
+    computed = bleu.compute()
+    if isinstance(computed, torch.Tensor):
+        computed = computed.cpu().item()
 
     assert pytest.approx(ref_1) == reference_bleu_score
-    assert pytest.approx(ref_2) == reference_bleu_score
+    assert pytest.approx(computed) == reference_bleu_score
 
 
-def test_bleu_batch_micro():
-    bleu = Bleu(ngram=4, average="micro")
+def test_bleu_batch_micro(available_device):
+    bleu = Bleu(ngram=4, average="micro", device=available_device)
+    assert bleu._device == torch.device(available_device)
 
     # Batch size 3
     hypotheses = [corpus.cand_1, corpus.cand_2a, corpus.cand_2b]
@@ -187,8 +202,10 @@ def test_bleu_batch_micro():
         (corpus.cand_1, corpus.references_1),
     ],
 )
-def test_n_gram_counter(candidates, references):
-    bleu = Bleu(ngram=4)
+def test_n_gram_counter(candidates, references, available_device):
+    bleu = Bleu(ngram=4, device=available_device)
+    assert bleu._device == torch.device(available_device)
+
     hyp_length, ref_length = bleu._n_gram_counter([references], [candidates], Counter(), Counter())
     assert hyp_length == len(candidates)
 
@@ -212,9 +229,9 @@ def _test_macro_distrib_integration(device):
     def update(_, i):
         return data[i + size * rank]
 
-    def _test(metric_device):
+    def _test(device):
         engine = Engine(update)
-        m = Bleu(ngram=4, smooth="smooth2")
+        m = Bleu(ngram=4, smooth="smooth2", device=device)
         m.attach(engine, "bleu")
 
         engine.run(data=list(range(size)), max_epochs=1)
@@ -256,7 +273,7 @@ def update(_, i):
 
     def _test(metric_device):
         engine = Engine(update)
-        m = Bleu(ngram=4, smooth="smooth2", average="micro")
+        m = Bleu(ngram=4, smooth="smooth2", average="micro", device=metric_device)
         m.attach(engine, "bleu")
 
         engine.run(data=list(range(size)), max_epochs=1)
```
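The updated tests take an `available_device` fixture that is not part of this diff. A minimal conftest.py sketch of what such a fixture might look like (the name matches the tests; the selection order is an assumption):

```python
import pytest
import torch

@pytest.fixture
def available_device():
    # Assumed fixture: pick the best accelerator the machine offers,
    # falling back to CPU, so the same suite exercises CUDA and MPS
    # code paths when they exist.
    if torch.cuda.is_available():
        return "cuda"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"
```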

tests/ignite/metrics/nlp/test_rouge.py (+5 -4)
```diff
@@ -84,9 +84,9 @@ def test_wrong_inputs():
         (2, "abcdef", "zbdfz", (0, 0)),
     ],
 )
-def test_rouge_n_alpha(ngram, candidate, reference, expected):
+def test_rouge_n_alpha(ngram, candidate, reference, expected, available_device):
     for alpha in [0, 1, 0.3, 0.5, 0.8]:
-        rouge = RougeN(ngram=ngram, alpha=alpha)
+        rouge = RougeN(ngram=ngram, alpha=alpha, device=available_device)
         rouge.update(([candidate], [[reference]]))
         results = rouge.compute()
         assert results[f"Rouge-{ngram}-P"] == expected[0]
@@ -101,7 +101,7 @@ def test_rouge_n_alpha(ngram, candidate, reference, expected):
 @pytest.mark.parametrize(
     "candidates, references", [corpus.sample_1, corpus.sample_2, corpus.sample_3, corpus.sample_4, corpus.sample_5]
 )
-def test_rouge_metrics(candidates, references):
+def test_rouge_metrics(candidates, references, available_device):
     for multiref in ["average", "best"]:
         # PERL 1.5.5 reference
         apply_avg = multiref == "average"
@@ -123,7 +123,8 @@ def test_rouge_metrics(candidates, references):
 
         lower_split_candidates = [candidate.lower().split() for candidate in candidates]
 
-        m = Rouge(variants=[1, 2, 4, "L"], multiref=multiref, alpha=0.5)
+        m = Rouge(variants=[1, 2, 4, "L"], multiref=multiref, alpha=0.5, device=available_device)
+        assert m._device == torch.device(available_device)
         m.update((lower_split_candidates, lower_split_references))
         results = m.compute()
 
```
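A pattern repeated across these tests is normalizing the value returned by compute() before comparing it with pytest.approx, since on a non-CPU device it may be a tensor rather than a float. A sketch of a helper capturing that pattern (the name `as_float` is hypothetical; the tests inline the logic instead):

```python
import torch

def as_float(value):
    # Mirrors the inline pattern from the diff: move tensor results to
    # CPU and unwrap them so they compare cleanly with Python floats.
    if isinstance(value, torch.Tensor):
        return value.cpu().item()
    return value
```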
