Skip to content

Commit 900ac49

Browse files
authored
Refactor tests (qubvel-org#1011)
* Add parallel test deps * Update signature * Add encoders tests * Update gitignore * Update encoders for timm-universal * Add parallel tests run * Disable models tests * Add uv to CI * Add uv to minimum * Add show-install-packages * Increase to 3 workers * Fix show-packages * Change back for 2 workers * Add coverage * Basic model test * Fix * Move model archs * Add base params test * Fix timm test for minimum version * Remove deprecated utils from coverage * Fix * Fix * Exclude conversion script * Add save-load test, add aux head test * Remove custom encoder * Set encoder for models tests * Docs + flag for anyres * Fix loading from config * Bump min hf-hub to 0.25.0 * Fix minimal * Add test with hub checkpoint * Fixing minimum * Fix * Fix torch for minimum tests * Update torch version and run-slow * run-slow * Show skipped * [run-slow] Fixing minimum * [run-slow] Fixing minimum * Fix decorator * Raise error * [run-slow] Fixing run slow * [run-slow] Fixing run slow * Run slow tests in separate job * FIx * Fixes * Add device * Bum tolerance * Add device * Fixup
1 parent fbeeb0c commit 900ac49

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

42 files changed

+1046
-223
lines changed

.github/workflows/tests.yml

+43-6
Original file line numberDiff line numberDiff line change
@@ -36,24 +36,61 @@ jobs:
3636
runs-on: ${{ matrix.os }}
3737
steps:
3838
- uses: actions/checkout@v4
39+
3940
- name: Set up Python ${{ matrix.python-version }}
4041
uses: actions/setup-python@v5
4142
with:
4243
python-version: ${{ matrix.python-version }}
44+
4345
- name: Install dependencies
44-
run: python -m pip install -r requirements/required.txt -r requirements/test.txt
45-
- name: Test with pytest
46-
run: pytest
46+
run: |
47+
python -m pip install uv
48+
python -m uv pip install --system -r requirements/required.txt -r requirements/test.txt
49+
50+
- name: Show installed packages
51+
run: |
52+
python -m pip list
53+
54+
- name: Test with PyTest
55+
run: |
56+
pytest -v -rsx -n 2 --cov=segmentation_models_pytorch --cov-report=xml --cov-config=pyproject.toml -k "not logits_match"
57+
58+
- name: Upload coverage reports to Codecov
59+
uses: codecov/codecov-action@v5
60+
with:
61+
token: ${{ secrets.CODECOV_TOKEN }}
62+
slug: qubvel-org/segmentation_models.pytorch
63+
if: matrix.os == 'macos-latest' && matrix.python-version == '3.12'
64+
65+
test_logits_match:
66+
runs-on: ubuntu-latest
67+
steps:
68+
- uses: actions/checkout@v4
69+
- name: Set up Python
70+
uses: actions/setup-python@v5
71+
with:
72+
python-version: "3.10"
73+
- name: Install dependencies
74+
run: |
75+
python -m pip install uv
76+
python -m uv pip install --system -r requirements/required.txt -r requirements/test.txt
77+
- name: Test with PyTest
78+
run: RUN_SLOW=1 pytest -v -rsx -n 2 -k "logits_match"
4779

4880
minimum:
4981
runs-on: ubuntu-latest
5082
steps:
5183
- uses: actions/checkout@v4
52-
- name: Set up Python ${{ matrix.python-version }}
84+
- name: Set up Python
5385
uses: actions/setup-python@v5
5486
with:
5587
python-version: "3.9"
5688
- name: Install dependencies
57-
run: python -m pip install -r requirements/minimum.old -r requirements/test.txt
89+
run: |
90+
python -m pip install uv
91+
python -m uv pip install --system -r requirements/minimum.old -r requirements/test.txt
92+
- name: Show installed packages
93+
run: |
94+
python -m pip list
5895
- name: Test with pytest
59-
run: pytest
96+
run: pytest -v -rsx -n 2 -k "not logits_match"

.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ venv/
9393
ENV/
9494
env.bak/
9595
venv.bak/
96+
.vscode/
9697

9798
# Spyder project settings
9899
.spyderproject

Makefile

+4-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,10 @@ install_dev: .venv
77
.venv/bin/pip install -e ".[test]"
88

99
test: .venv
10-
.venv/bin/pytest -p no:cacheprovider tests/
10+
.venv/bin/pytest -v -rsx -n 2 tests/ -k "not logits_match"
11+
12+
test_all: .venv
13+
RUN_SLOW=1 .venv/bin/pytest -v -rsx -n 2 tests/
1114

1215
table:
1316
.venv/bin/python misc/generate_table.py

misc/generate_test_models.py

+41
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
import os
2+
import torch
3+
import tempfile
4+
import huggingface_hub
5+
import segmentation_models_pytorch as smp
6+
7+
HUB_REPO = "smp-test-models"
8+
ENCODER_NAME = "tu-resnet18"
9+
10+
api = huggingface_hub.HfApi(token=os.getenv("HF_TOKEN"))
11+
12+
for model_name, model_class in smp.MODEL_ARCHITECTURES_MAPPING.items():
13+
model = model_class(encoder_name=ENCODER_NAME)
14+
model = model.eval()
15+
16+
# generate test sample
17+
torch.manual_seed(423553)
18+
sample = torch.rand(1, 3, 256, 256)
19+
20+
with torch.no_grad():
21+
output = model(sample)
22+
23+
with tempfile.TemporaryDirectory() as tmpdir:
24+
# save model
25+
model.save_pretrained(f"{tmpdir}")
26+
27+
# save input and output
28+
torch.save(sample, f"{tmpdir}/input-tensor.pth")
29+
torch.save(output, f"{tmpdir}/output-tensor.pth")
30+
31+
# create repo
32+
repo_id = f"{HUB_REPO}/{model_name}-{ENCODER_NAME}"
33+
if not api.repo_exists(repo_id=repo_id):
34+
api.create_repo(repo_id=repo_id, repo_type="model")
35+
36+
# upload to hub
37+
api.upload_folder(
38+
folder_path=tmpdir,
39+
repo_id=f"{HUB_REPO}/{model_name}-{ENCODER_NAME}",
40+
repo_type="model",
41+
)

pyproject.toml

+24
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@ docs = [
4040
]
4141
test = [
4242
'pytest',
43+
'pytest-cov',
44+
'pytest-xdist',
4345
'ruff',
4446
]
4547

@@ -55,3 +57,25 @@ version = {attr = 'segmentation_models_pytorch.__version__.__version__'}
5557

5658
[tool.setuptools.packages.find]
5759
include = ['segmentation_models_pytorch*']
60+
61+
[tool.pytest.ini_options]
62+
markers = [
63+
"deeplabv3",
64+
"deeplabv3plus",
65+
"fpn",
66+
"linknet",
67+
"manet",
68+
"pan",
69+
"psp",
70+
"segformer",
71+
"unet",
72+
"unetplusplus",
73+
"upernet",
74+
"logits_match",
75+
]
76+
77+
[tool.coverage.run]
78+
omit = [
79+
"segmentation_models_pytorch/utils/*",
80+
"**/convert_*",
81+
]

requirements/minimum.old

+1
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,4 @@ timm==0.9.0
88
torch==1.9.0
99
torchvision==0.10.0
1010
tqdm==4.42.1
11+
Jinja2==3.0.0

requirements/test.txt

+3-1
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,4 @@
11
pytest==8.3.4
2-
ruff==0.8.4
2+
pytest-xdist==3.6.1
3+
pytest-cov==6.0.0
4+
ruff==0.8.4

segmentation_models_pytorch/__init__.py

+17-16
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,21 @@
3030
"ignore", message=r'"is" with \'str\' literal.*', category=SyntaxWarning
3131
) # for python >= 3.12
3232

33+
_MODEL_ARCHITECTURES = [
34+
Unet,
35+
UnetPlusPlus,
36+
MAnet,
37+
Linknet,
38+
FPN,
39+
PSPNet,
40+
DeepLabV3,
41+
DeepLabV3Plus,
42+
PAN,
43+
UPerNet,
44+
Segformer,
45+
]
46+
MODEL_ARCHITECTURES_MAPPING = {a.__name__.lower(): a for a in _MODEL_ARCHITECTURES}
47+
3348

3449
def create_model(
3550
arch: str,
@@ -43,26 +58,12 @@ def create_model(
4358
parameters, without using its class
4459
"""
4560

46-
archs = [
47-
Unet,
48-
UnetPlusPlus,
49-
MAnet,
50-
Linknet,
51-
FPN,
52-
PSPNet,
53-
DeepLabV3,
54-
DeepLabV3Plus,
55-
PAN,
56-
UPerNet,
57-
Segformer,
58-
]
59-
archs_dict = {a.__name__.lower(): a for a in archs}
6061
try:
61-
model_class = archs_dict[arch.lower()]
62+
model_class = MODEL_ARCHITECTURES_MAPPING[arch.lower()]
6263
except KeyError:
6364
raise KeyError(
6465
"Wrong architecture type `{}`. Available options are: {}".format(
65-
arch, list(archs_dict.keys())
66+
arch, list(MODEL_ARCHITECTURES_MAPPING.keys())
6667
)
6768
)
6869
return model_class(

segmentation_models_pytorch/base/hub_mixin.py

+11
Original file line numberDiff line numberDiff line change
@@ -136,3 +136,14 @@ def from_pretrained(pretrained_model_name_or_path: str, *args, **kwargs):
136136

137137
model_class = getattr(smp, model_class_name)
138138
return model_class.from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
139+
140+
141+
def supports_config_loading(func):
142+
"""Decorator to filter special config kwargs"""
143+
144+
@wraps(func)
145+
def wrapper(self, *args, **kwargs):
146+
kwargs = {k: v for k, v in kwargs.items() if not k.startswith("_")}
147+
return func(self, *args, **kwargs)
148+
149+
return wrapper

segmentation_models_pytorch/base/model.py

+10-1
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,22 @@
55

66

77
class SegmentationModel(torch.nn.Module, SMPHubMixin):
8+
"""Base class for all segmentation models."""
9+
10+
# if model supports shape not divisible by 2 ^ n
11+
# set to False
12+
requires_divisible_input_shape = True
13+
814
def initialize(self):
915
init.initialize_decoder(self.decoder)
1016
init.initialize_head(self.segmentation_head)
1117
if self.classification_head is not None:
1218
init.initialize_head(self.classification_head)
1319

1420
def check_input_shape(self, x):
21+
"""Check if the input shape is divisible by the output stride.
22+
If not, raise a RuntimeError.
23+
"""
1524
h, w = x.shape[-2:]
1625
output_stride = self.encoder.output_stride
1726
if h % output_stride != 0 or w % output_stride != 0:
@@ -33,7 +42,7 @@ def check_input_shape(self, x):
3342
def forward(self, x):
3443
"""Sequentially pass `x` trough model`s encoder, decoder and heads"""
3544

36-
if not torch.jit.is_tracing():
45+
if not torch.jit.is_tracing() or self.requires_divisible_input_shape:
3746
self.check_input_shape(x)
3847

3948
features = self.encoder(x)

segmentation_models_pytorch/decoders/deeplabv3/model.py

+3
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
SegmentationModel,
99
)
1010
from segmentation_models_pytorch.encoders import get_encoder
11+
from segmentation_models_pytorch.base.hub_mixin import supports_config_loading
1112

1213
from .decoder import DeepLabV3Decoder, DeepLabV3PlusDecoder
1314

@@ -54,6 +55,7 @@ class DeepLabV3(SegmentationModel):
5455
5556
"""
5657

58+
@supports_config_loading
5759
def __init__(
5860
self,
5961
encoder_name: str = "resnet34",
@@ -163,6 +165,7 @@ class DeepLabV3Plus(SegmentationModel):
163165
164166
"""
165167

168+
@supports_config_loading
166169
def __init__(
167170
self,
168171
encoder_name: str = "resnet34",

segmentation_models_pytorch/decoders/fpn/model.py

+2
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
SegmentationModel,
77
)
88
from segmentation_models_pytorch.encoders import get_encoder
9+
from segmentation_models_pytorch.base.hub_mixin import supports_config_loading
910

1011
from .decoder import FPNDecoder
1112

@@ -51,6 +52,7 @@ class FPN(SegmentationModel):
5152
5253
"""
5354

55+
@supports_config_loading
5456
def __init__(
5557
self,
5658
encoder_name: str = "resnet34",

segmentation_models_pytorch/decoders/linknet/model.py

+2
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
SegmentationModel,
77
)
88
from segmentation_models_pytorch.encoders import get_encoder
9+
from segmentation_models_pytorch.base.hub_mixin import supports_config_loading
910

1011
from .decoder import LinknetDecoder
1112

@@ -53,6 +54,7 @@ class Linknet(SegmentationModel):
5354
https://arxiv.org/abs/1707.03718
5455
"""
5556

57+
@supports_config_loading
5658
def __init__(
5759
self,
5860
encoder_name: str = "resnet34",

segmentation_models_pytorch/decoders/manet/model.py

+2
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
SegmentationModel,
77
)
88
from segmentation_models_pytorch.encoders import get_encoder
9+
from segmentation_models_pytorch.base.hub_mixin import supports_config_loading
910

1011
from .decoder import MAnetDecoder
1112

@@ -56,6 +57,7 @@ class MAnet(SegmentationModel):
5657
5758
"""
5859

60+
@supports_config_loading
5961
def __init__(
6062
self,
6163
encoder_name: str = "resnet34",

segmentation_models_pytorch/decoders/pan/model.py

+2
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
SegmentationModel,
77
)
88
from segmentation_models_pytorch.encoders import get_encoder
9+
from segmentation_models_pytorch.base.hub_mixin import supports_config_loading
910

1011
from .decoder import PANDecoder
1112

@@ -53,6 +54,7 @@ class PAN(SegmentationModel):
5354
5455
"""
5556

57+
@supports_config_loading
5658
def __init__(
5759
self,
5860
encoder_name: str = "resnet34",

segmentation_models_pytorch/decoders/pspnet/model.py

+2
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
SegmentationModel,
77
)
88
from segmentation_models_pytorch.encoders import get_encoder
9+
from segmentation_models_pytorch.base.hub_mixin import supports_config_loading
910

1011
from .decoder import PSPDecoder
1112

@@ -54,6 +55,7 @@ class PSPNet(SegmentationModel):
5455
https://arxiv.org/abs/1612.01105
5556
"""
5657

58+
@supports_config_loading
5759
def __init__(
5860
self,
5961
encoder_name: str = "resnet34",

segmentation_models_pytorch/decoders/segformer/model.py

+2
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
SegmentationModel,
77
)
88
from segmentation_models_pytorch.encoders import get_encoder
9+
from segmentation_models_pytorch.base.hub_mixin import supports_config_loading
910

1011
from .decoder import SegformerDecoder
1112

@@ -46,6 +47,7 @@ class Segformer(SegmentationModel):
4647
4748
"""
4849

50+
@supports_config_loading
4951
def __init__(
5052
self,
5153
encoder_name: str = "resnet34",

0 commit comments

Comments
 (0)