generated from kyegomez/Python-Package-Template
-
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtests.py
85 lines (63 loc) · 2.12 KB
/
tests.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import torch
import torch.nn as nn
import pytest
from alr_transformer.alr_block import FeedForward, ALRBlock
# Create fixtures
@pytest.fixture
def sample_input():
    """Return a fresh random tensor of shape (1, 1024, 512) for forward tests."""
    batch, seq_len, width = 1, 1024, 512
    return torch.randn(batch, seq_len, width)
@pytest.fixture
def alrblock_model():
    """Provide a freshly constructed ALRBlock for each test that requests it.

    The positional arguments follow the ``dim, hidden_dim, dropout`` naming
    used by ``test_alrblock_parameterized``; confirm against ALRBlock's
    actual signature.
    """
    dim, hidden_dim, dropout = 512, 2048, 0.1
    return ALRBlock(dim, hidden_dim, dropout)
@pytest.fixture
def feedforward_model():
    """Provide a freshly constructed FeedForward for each test that requests it.

    The positional arguments follow the ``input_dim, hidden_dim, dropout``
    naming used by ``test_feedforward_parameterized``; confirm against
    FeedForward's actual signature.
    """
    input_dim, hidden_dim, dropout = 512, 2048, 0.1
    return FeedForward(input_dim, hidden_dim, dropout)
# Tests for FeedForward class
def test_feedforward_creation():
    """FeedForward(512, 2048, 0.1) must construct and be a torch ``nn.Module``."""
    assert isinstance(FeedForward(512, 2048, 0.1), nn.Module)
def test_feedforward_forward(sample_input, feedforward_model):
output = feedforward_model(sample_input)
assert output.shape == sample_input.shape
# Tests for ALRBlock class
def test_alrblock_creation(alrblock_model):
assert isinstance(alrblock_model, nn.Module)
def test_alrblock_forward(sample_input, alrblock_model):
output = alrblock_model(sample_input)
assert output.shape == sample_input.shape
# Parameterized testing for various input dimensions and dropout rates
@pytest.mark.parametrize(
    ("input_dim", "hidden_dim", "dropout"),
    [
        (256, 1024, 0.2),
        (512, 2048, 0.0),
        (128, 512, 0.3),
    ],
)
def test_feedforward_parameterized(input_dim, hidden_dim, dropout):
    """FeedForward must preserve input shape across several width/dropout configs.

    Inputs are (1, 1024, input_dim) tensors, so the last dimension matches the
    width the model is built with.
    """
    ff = FeedForward(input_dim, hidden_dim, dropout)
    x = torch.randn(1, 1024, input_dim)
    y = ff(x)
    assert x.shape == y.shape
@pytest.mark.parametrize(
    ("dim", "hidden_dim", "dropout"),
    [
        (256, 1024, 0.2),
        (512, 2048, 0.0),
        (128, 512, 0.3),
    ],
)
def test_alrblock_parameterized(dim, hidden_dim, dropout):
    """ALRBlock must preserve input shape across several width/dropout configs.

    Inputs are (1, 1024, dim) tensors, so the last dimension matches the width
    the block is built with.
    """
    block = ALRBlock(dim, hidden_dim, dropout)
    x = torch.randn(1, 1024, dim)
    y = block(x)
    assert x.shape == y.shape
# Exception testing
def test_feedforward_invalid_input():
    """Feeding a tensor whose feature (last) dimension mismatches must raise.

    The model is built with width 512 (as in the ``feedforward_model``
    fixture, which pairs with a ``(..., 512)`` input). Passing a
    ``(1, 1024, 256)`` tensor is expected to fail with ``RuntimeError``.

    NOTE(review): the previous version passed ``torch.randn(2, 1024, 512)``
    as an "invalid batch size". A feed-forward block presumably applies
    per-position layers such as ``nn.Linear`` (TODO confirm in
    alr_transformer.alr_block). Such layers accept any leading dimensions, so
    batch 2 would likely run fine and that test would fail. A mismatched
    feature width is a genuinely invalid input; ``nn.Linear`` rejects it with
    ``RuntimeError``.
    """
    model = FeedForward(512, 2048, 0.1)
    # Keep batch/seq valid so the width mismatch is the only cause of failure.
    wrong_width = torch.randn(1, 1024, 256)
    with pytest.raises(RuntimeError):
        model(wrong_width)
def test_alrblock_invalid_input():
    """Feeding a tensor whose feature (last) dimension mismatches must raise.

    The block is built with width 512 (as in the ``alrblock_model`` fixture,
    which pairs with a ``(..., 512)`` input). Passing a ``(1, 1024, 256)``
    tensor is expected to fail with ``RuntimeError``.

    NOTE(review): the previous version passed ``torch.randn(2, 1024, 512)``
    as an "invalid batch size". ALRBlock presumably uses per-position
    layers (``nn.Linear`` / ``nn.LayerNorm``; TODO confirm in
    alr_transformer.alr_block). Such layers accept any leading dimensions, so
    batch 2 would likely run fine and that test would fail. Those layers
    reject a mismatched feature width with ``RuntimeError``.
    """
    model = ALRBlock(512, 2048, 0.1)
    # Keep batch/seq valid so the width mismatch is the only cause of failure.
    wrong_width = torch.randn(1, 1024, 256)
    with pytest.raises(RuntimeError):
        model(wrong_width)