-
Notifications
You must be signed in to change notification settings - Fork 132
/
Copy pathtest_basic.py
38 lines (27 loc) · 1.09 KB
/
test_basic.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import torch
import pytest
from train import compute_accuracy
def test_arange_elems():
    """Check that ``torch.arange(0, 10)`` ends at 9, one short of its stop value."""
    values = torch.arange(0, 10, dtype=torch.float)
    expected_last = torch.tensor([9]).float()
    # arange covers the half-open interval [start, end), so 10 is excluded.
    assert torch.allclose(values[-1], expected_last)
def test_div_zero():
    """Dividing a tensor by zero yields non-finite values instead of raising.

    ``/`` on integer tensors is true division in PyTorch, so the ``long``
    operands are promoted to a floating dtype and ``1 / 0`` evaluates to
    ``inf``. Contrast this with plain Python, where ``1 / 0`` raises
    ``ZeroDivisionError``.
    """
    zeros = torch.zeros(1, dtype=torch.long)
    ones = torch.ones(1, dtype=torch.long)
    quotient = ones / zeros
    # Reduce explicitly. Using a tensor directly in a boolean context only
    # works for single-element tensors and raises for anything larger.
    assert not torch.isfinite(quotient).any()
def test_div_zero_python():
    """Plain Python division by zero must raise ZeroDivisionError."""
    numerator, denominator = 1, 0
    with pytest.raises(ZeroDivisionError):
        numerator / denominator
def test_accuracy():
    """Check compute_accuracy on a perfect match and on a half-correct batch."""
    # Draw from a dedicated, seeded generator so the "random" predictions are
    # identical on every run; an unseeded draw makes failures irreproducible.
    generator = torch.Generator().manual_seed(0)
    preds = torch.randint(0, 2, size=(100,), generator=generator)
    targets = preds.clone()
    # Identical predictions and targets must give perfect accuracy.
    # float() accepts either a Python number or a one-element tensor.
    assert float(compute_accuracy(preds, targets)) == pytest.approx(1.0)

    # The first three of six predictions match their targets, so 3/6 = 0.5.
    preds = torch.tensor([1, 2, 3, 0, 0, 0])
    targets = torch.tensor([1, 2, 3, 4, 5, 6])
    # NOTE(review): packing several unrelated scenarios into one test means a
    # failure in the first assert hides the outcome of this one;
    # test_accuracy_parametrized runs one case per test instead.
    assert float(compute_accuracy(preds, targets)) == pytest.approx(0.5)
@pytest.mark.parametrize(
    "preds,targets,result",
    [
        (torch.tensor([1, 2, 3]), torch.tensor([1, 2, 3]), 1.0),
        (torch.tensor([1, 2, 3]), torch.tensor([0, 0, 0]), 0.0),
        (torch.tensor([1, 2, 3]), torch.tensor([1, 2, 0]), 2 / 3),
    ],
)
def test_accuracy_parametrized(preds, targets, result):
    """Compare compute_accuracy against the known accuracy of each case.

    The comparison uses a purely absolute tolerance (rtol=0, atol=1e-5), so
    non-terminating values such as 2/3 compare equal despite rounding.
    """
    actual = compute_accuracy(preds, targets)
    expected = torch.tensor([result])
    assert torch.allclose(actual, expected, rtol=0, atol=1e-5)