Skip to content

Commit dc04681

Browse files
committed
tests: introduce test suite & ci
1 parent 3e451f9 commit dc04681

9 files changed

Lines changed: 200 additions & 0 deletions

File tree

.github/workflows/tests.yml

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
name: tests
2+
3+
on:
4+
push:
5+
pull_request:
6+
7+
jobs:
8+
pytest:
9+
runs-on: ubuntu-latest
10+
steps:
11+
- uses: actions/checkout@v4
12+
- uses: actions/setup-python@v5
13+
with:
14+
python-version: "3.11"
15+
- name: Install dependencies
16+
run: |
17+
python -m pip install --upgrade pip
18+
pip install tox torch numpy
19+
- name: Run tests
20+
run: tox -q

setup.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
from pathlib import Path
2+
3+
from setuptools import find_packages, setup
4+
5+
6+
def read_version() -> str:
7+
version_file = Path(__file__).parent / "raptor" / "__init__.py"
8+
for line in version_file.read_text().splitlines():
9+
if line.startswith("__version__"):
10+
return line.split("=", 1)[1].strip().strip('"').strip("'")
11+
raise RuntimeError("Unable to find __version__ in raptor/__init__.py")
12+
13+
14+
setup(
15+
name="raptor",
16+
version=read_version(),
17+
description="Block-Recurrent Dynamics in ViTs (Raptor)",
18+
packages=find_packages(),
19+
python_requires=">=3.9",
20+
)

tests/conftest.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
from pathlib import Path
2+
import sys
3+
4+
5+
ROOT = Path(__file__).resolve().parents[1]
6+
sys.path.insert(0, str(ROOT / "raptor"))
7+
sys.path.insert(0, str(ROOT))

tests/test_block.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
import torch
2+
3+
from block import Block
4+
5+
6+
def test_block_forward_with_t_scale():
7+
block = Block(dim=16, num_heads=4, t_scale=True, swiglu=False)
8+
x = torch.randn(2, 6, 16)
9+
t = torch.tensor([1.0, 2.0])
10+
11+
out = block(x, t, t_integer=1)
12+
13+
assert out.shape == x.shape
14+
15+
16+
def test_block_forward_without_t_scale():
17+
block = Block(dim=16, num_heads=4, t_scale=False, swiglu=True)
18+
x = torch.randn(2, 6, 16)
19+
20+
out = block(x)
21+
22+
assert out.shape == x.shape

tests/test_dataloader.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
import numpy as np
2+
import pytest
3+
import torch
4+
5+
zarr = pytest.importorskip("zarr")
6+
pytest.importorskip("torchvision")
7+
pytest.importorskip("PIL")
8+
9+
from dataloader import AsynchZarrLoader, imagenet_transform
10+
11+
12+
def test_asynch_zarr_loader_yields_batches(tmp_path):
13+
data_path = tmp_path / "data.zarr"
14+
data = zarr.open(str(data_path), mode="w", shape=(10, 3, 2, 4), dtype="f4")
15+
data[:] = np.random.rand(10, 3, 2, 4).astype("f4")
16+
17+
loader = AsynchZarrLoader(
18+
zarr_path=str(data_path),
19+
layer_start=0,
20+
layer_end=1,
21+
batch_size=4,
22+
num_workers=1,
23+
queue_size=2,
24+
device="cpu",
25+
)
26+
27+
iterator = iter(loader)
28+
batch = next(iterator)
29+
loader.close()
30+
iterator.close()
31+
32+
assert batch.shape == (4, 2, 2, 4)
33+
assert batch.dtype == torch.float32
34+
35+
36+
def test_imagenet_transform_output_shape():
37+
from PIL import Image
38+
39+
img = Image.fromarray((np.random.rand(64, 64, 3) * 255).astype("uint8"))
40+
tensor = imagenet_transform()(img)
41+
42+
assert tensor.shape == (3, 224, 224)
43+
assert tensor.dtype == torch.float32

tests/test_raptor.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
import torch
2+
import torch.nn as nn
3+
4+
from raptor import Raptor
5+
6+
7+
class DummyBlock(nn.Module):
8+
def __init__(self, value: float):
9+
super().__init__()
10+
self.value = value
11+
12+
def forward(self, x, t):
13+
return x + self.value
14+
15+
16+
def test_raptor_selects_expected_block():
17+
blocks = [DummyBlock(1.0), DummyBlock(2.0), DummyBlock(3.0)]
18+
model = Raptor(blocks, thresholds=[2, 5])
19+
x = torch.zeros(1, 1, 1)
20+
t = torch.tensor([1.0])
21+
22+
out1 = model(x, t, t_integer=1)
23+
out2 = model(x, t, t_integer=3)
24+
out3 = model(x, t, t_integer=9)
25+
26+
assert out1.item() == 1.0
27+
assert out2.item() == 2.0
28+
assert out3.item() == 3.0

tests/test_scheduler.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
import pytest
2+
import torch
3+
4+
from scheduler import CosineScheduler
5+
6+
7+
def test_cosine_scheduler_steps_match_schedule():
8+
model = torch.nn.Linear(2, 2)
9+
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
10+
scheduler = CosineScheduler(
11+
optimizer,
12+
base_value=0.1,
13+
final_value=0.01,
14+
total_iters=10,
15+
warmup_iters=2,
16+
)
17+
18+
assert len(scheduler.schedule) == 10
19+
for _ in range(3):
20+
scheduler.step()
21+
assert optimizer.param_groups[0]["lr"] == pytest.approx(scheduler[scheduler.iter])
22+
23+
assert scheduler[10] == pytest.approx(0.01)

tests/test_utils.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
import pytest
2+
3+
import utils
4+
5+
6+
def test_find_raptor_checkpoint_single_match(tmp_path):
7+
models_dir = tmp_path
8+
expected = models_dir / "final_raptor3_any_seed_42_step_123.pt"
9+
expected.write_text("ok")
10+
11+
result = utils.find_raptor_checkpoint("raptor3", 42, str(models_dir))
12+
13+
assert result == str(expected)
14+
15+
16+
def test_find_raptor_checkpoint_multiple_matches(tmp_path):
17+
models_dir = tmp_path
18+
(models_dir / "final_raptor2_a_seed_7_step_1.pt").write_text("ok")
19+
(models_dir / "final_raptor2_b_seed_7_step_2.pt").write_text("ok")
20+
21+
with pytest.raises(ValueError):
22+
utils.find_raptor_checkpoint("raptor2", 7, str(models_dir))
23+
24+
25+
def test_find_raptor_checkpoint_missing(tmp_path):
26+
with pytest.raises(FileNotFoundError):
27+
utils.find_raptor_checkpoint("raptor4", 999, str(tmp_path))

tox.ini

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
[tox]
2+
envlist = py
3+
skipsdist = true
4+
5+
[testenv]
6+
deps =
7+
pytest
8+
commands =
9+
pytest -q
10+
sitepackages = true

0 commit comments

Comments
 (0)