Skip to content

Commit 3e41559

Browse files
authored
New models for Anomaly Detection (#470)
* new models for anomaly detection * style checks * notebook testing removed * added hub option * update anomaly detection models * addressed review comments * torchvision --> self._hub * changed hardcoded hub * fix integration * linter * added test for 3 new models
1 parent b6713f1 commit 3e41559

14 files changed

+663
-39
lines changed

Models.md

+66
Original file line numberDiff line numberDiff line change
@@ -170,8 +170,74 @@
170170

171171
| Model name | Framework | Model Hub |
172172
|------------|-----------|-----------|
173+
| alexnet | PyTorch | Torchvision |
174+
| convnext_base | PyTorch | Torchvision |
175+
| convnext_large | PyTorch | Torchvision |
176+
| convnext_small | PyTorch | Torchvision |
177+
| convnext_tiny | PyTorch | Torchvision |
178+
| densenet121 | PyTorch | Torchvision |
179+
| densenet161 | PyTorch | Torchvision |
180+
| densenet169 | PyTorch | Torchvision |
181+
| densenet201 | PyTorch | Torchvision |
182+
| dinov2_vitb14 | PyTorch | PyTorch Hub |
183+
| dinov2_vitb14_reg | PyTorch | PyTorch Hub |
184+
| dinov2_vitg14 | PyTorch | PyTorch Hub |
185+
| dinov2_vitg14_reg | PyTorch | PyTorch Hub |
186+
| dinov2_vitl14 | PyTorch | PyTorch Hub |
187+
| dinov2_vitl14_reg | PyTorch | PyTorch Hub |
188+
| dinov2_vits14 | PyTorch | PyTorch Hub |
189+
| dinov2_vits14_reg | PyTorch | PyTorch Hub |
190+
| efficientnet_b0 | PyTorch | Torchvision |
191+
| efficientnet_b1 | PyTorch | Torchvision |
192+
| efficientnet_b2 | PyTorch | Torchvision |
193+
| efficientnet_b3 | PyTorch | Torchvision |
194+
| efficientnet_b4 | PyTorch | Torchvision |
195+
| efficientnet_b5 | PyTorch | Torchvision |
196+
| efficientnet_b6 | PyTorch | Torchvision |
197+
| efficientnet_b7 | PyTorch | Torchvision |
198+
| googlenet | PyTorch | Torchvision |
199+
| mnasnet0_5 | PyTorch | Torchvision |
200+
| mnasnet1_0 | PyTorch | Torchvision |
201+
| mobilenet_v2 | PyTorch | Torchvision |
202+
| mobilenet_v3_large | PyTorch | Torchvision |
203+
| mobilenet_v3_small | PyTorch | Torchvision |
204+
| regnet_x_16gf | PyTorch | Torchvision |
205+
| regnet_x_1_6gf | PyTorch | Torchvision |
206+
| regnet_x_32gf | PyTorch | Torchvision |
207+
| regnet_x_3_2gf | PyTorch | Torchvision |
208+
| regnet_x_400mf | PyTorch | Torchvision |
209+
| regnet_x_800mf | PyTorch | Torchvision |
210+
| regnet_x_8gf | PyTorch | Torchvision |
211+
| regnet_y_16gf | PyTorch | Torchvision |
212+
| regnet_y_1_6gf | PyTorch | Torchvision |
213+
| regnet_y_32gf | PyTorch | Torchvision |
214+
| regnet_y_3_2gf | PyTorch | Torchvision |
215+
| regnet_y_400mf | PyTorch | Torchvision |
216+
| regnet_y_800mf | PyTorch | Torchvision |
217+
| regnet_y_8gf | PyTorch | Torchvision |
218+
| resnet101 | PyTorch | Torchvision |
219+
| resnet152 | PyTorch | Torchvision |
173220
| resnet18 | PyTorch | Torchvision |
221+
| resnet34 | PyTorch | Torchvision |
174222
| resnet50 | PyTorch | Torchvision |
223+
| resnext101_32x8d | PyTorch | Torchvision |
224+
| resnext50_32x4d | PyTorch | Torchvision |
225+
| shufflenet_v2_x0_5 | PyTorch | Torchvision |
226+
| shufflenet_v2_x1_0 | PyTorch | Torchvision |
227+
| vgg11 | PyTorch | Torchvision |
228+
| vgg11_bn | PyTorch | Torchvision |
229+
| vgg13 | PyTorch | Torchvision |
230+
| vgg13_bn | PyTorch | Torchvision |
231+
| vgg16 | PyTorch | Torchvision |
232+
| vgg16_bn | PyTorch | Torchvision |
233+
| vgg19 | PyTorch | Torchvision |
234+
| vgg19_bn | PyTorch | Torchvision |
235+
| vit_b_16 | PyTorch | Torchvision |
236+
| vit_b_32 | PyTorch | Torchvision |
237+
| vit_l_16 | PyTorch | Torchvision |
238+
| vit_l_32 | PyTorch | Torchvision |
239+
| wide_resnet101_2 | PyTorch | Torchvision |
240+
| wide_resnet50_2 | PyTorch | Torchvision |
175241

176242
## Text Generation
177243

downloader/models.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ def __init__(self, model_name, hub, model_dir=None, hf_model_class=None, **kwarg
5656
self._model_name = model_name
5757
self._model_dir = model_dir
5858
self._hf_model_class = hf_model_class
59+
self._use_case = kwargs.get("use_case", None)
5960
self._type = ModelType.from_str(hub)
6061
self._args = kwargs
6162

@@ -88,9 +89,12 @@ def download(self):
8889

8990
if self._model_dir is not None:
9091
os.environ['TORCH_HOME'] = self._model_dir
91-
92-
config_file = os.path.join(TLT_BASE_DIR, "models/configs/pytorch_hub_image_classification_models.json")
93-
pytorch_hub_model_map = read_json_file(config_file)
92+
if self._use_case is not None:
93+
if self._use_case.lower().strip() in ["anomaly detection", "ad", "anomaly_detection"]:
94+
config_f = os.path.join(TLT_BASE_DIR, "models/configs/pytorch_hub_image_anomaly_detection_models.json") # noqa: E501
95+
else:
96+
config_f = os.path.join(TLT_BASE_DIR, "models/configs/pytorch_hub_image_classification_models.json")
97+
pytorch_hub_model_map = read_json_file(config_f)
9498
self._repo = pytorch_hub_model_map[self._model_name]["repo"]
9599

96100
# Some models already default to pretrained=True and error out if it is passed explicitly to load()

notebooks/image_anomaly_detection/tlt_api_pyt_anomaly_detection/Anomaly_Detection.ipynb

+27-16
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,17 @@
7575
"Here we are getting the pretrained `resnet50` model from Torchvision:"
7676
]
7777
},
78+
{
79+
"cell_type": "code",
80+
"execution_count": null,
81+
"id": "bb238c2a-8eea-4c4f-8586-fa7c0305f85a",
82+
"metadata": {},
83+
"outputs": [],
84+
"source": [
85+
"model_factory.print_supported_models(framework=\"pytorch\", use_case=\"anomaly_detection\", verbose = False,\n",
86+
" markdown=False)"
87+
]
88+
},
7889
{
7990
"cell_type": "code",
8091
"execution_count": null,
@@ -338,18 +349,6 @@
338349
"Training a SimSiam model requires applying the TwoCropTransform augmentation to the dataset used for training. You can preview this augmentation on a sample batch after preprocessing by calling `get_batch(simsiam=True)`, and then enable it for training by also passing `simsiam=True` to `model.train()`."
339350
]
340351
},
341-
{
342-
"cell_type": "code",
343-
"execution_count": null,
344-
"id": "8cd9420d",
345-
"metadata": {},
346-
"outputs": [],
347-
"source": [
348-
"# Examine the model's layers and decide which to use for feature extraction\n",
349-
"model.list_layers(verbose=False)\n",
350-
"layer = 'layer3'"
351-
]
352-
},
353352
{
354353
"cell_type": "code",
355354
"execution_count": null,
@@ -439,7 +438,7 @@
439438
"id": "759bc3ea",
440439
"metadata": {},
441440
"source": [
442-
"### Train Arguments\n",
441+
"## Train Arguments\n",
443442
"\n",
444443
"#### Required\n",
445444
"- **dataset** (ImageAnomalyDetectionDataset, required): Dataset to use when training the model\n",
@@ -467,15 +466,27 @@
467466
"Note: refer to release documentation for an up-to-date list of train arguments and their current descriptions"
468467
]
469468
},
469+
{
470+
"cell_type": "code",
471+
"execution_count": null,
472+
"id": "8cd9420d",
473+
"metadata": {},
474+
"outputs": [],
475+
"source": [
476+
"# Examine the model's layers and decide which to use for feature extraction\n",
477+
"model.list_layers(verbose=False)\n",
478+
"layer = 'layer3'"
479+
]
480+
},
470481
{
471482
"cell_type": "code",
472483
"execution_count": null,
473484
"id": "a2b601fc",
474485
"metadata": {},
475486
"outputs": [],
476487
"source": [
477-
"pca_components, trained_model = model.train(dataset, output_dir, layer_name=layer, epochs=2,\n",
478-
" seed=None, pooling='avg', kernel_size=2, pca_threshold=0.99)"
488+
"pca_components, trained_model = model.train(dataset, output_dir, epochs=2, layer_name=layer,\n",
489+
" seed=None, pooling='avg', kernel_size=2, pca_threshold=0.99)"
479490
]
480491
},
481492
{
@@ -635,7 +646,7 @@
635646
"name": "python",
636647
"nbconvert_exporter": "python",
637648
"pygments_lexer": "ipython3",
638-
"version": "3.8.10"
649+
"version": "3.9.17"
639650
}
640651
},
641652
"nbformat": 4,

notebooks/image_classification/tlt_api_pyt_image_classification/TLT_PyTorch_Image_Classification_Transfer_Learning.ipynb

+1-1
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@
101101
"outputs": [],
102102
"source": [
103103
"# Set device=\"hpu\" to use Gaudi. If no HPU hardware or installs are detected, device will default to \"cpu\"\n",
104-
"model = model_factory.get_model(model_name='efficientnet_b0', framework='pytorch', device='cpu')\n",
104+
"model = model_factory.get_model(model_name='efficientnet_b1', framework='pytorch', device='cpu')\n",
105105
"\n",
106106
"print(\"Model name:\", model.model_name)\n",
107107
"print(\"Framework:\", model.framework)\n",

tests/pytorch_tests/test_image_anomaly_detection.py

+44
Original file line numberDiff line numberDiff line change
@@ -330,3 +330,47 @@ def test_cutpaste_workflow_benchmark(self, model_name):
330330

331331
# Benchmark
332332
model.benchmark(dataset=dataset)
333+
334+
@pytest.mark.parametrize('model_name',
335+
['vgg11',
336+
'efficientnet_b1',
337+
'alexnet'])
338+
def test_no_simsiam_or_cutpaste(self, model_name):
339+
"""
340+
Tests the workflow for PYT image anomaly detection using a custom dataset
341+
with neither the simsiam nor the cutpaste feature extractor enabled
342+
"""
343+
framework = 'pytorch'
344+
use_case = 'image_anomaly_detection'
345+
346+
# Get the dataset
347+
dataset = dataset_factory.load_dataset(self._dataset_dir, use_case=use_case, framework=framework,
348+
shuffle_files=False)
349+
assert ['tulips'] == dataset.defect_names
350+
assert ['bad', 'good'] == dataset.class_names
351+
352+
# Get the model
353+
model = model_factory.get_model(model_name, framework, use_case)
354+
355+
# Preprocess the dataset and split to get small subsets for training and validation
356+
dataset.preprocess(model.image_size, 32)
357+
dataset.shuffle_split(train_pct=0.5, val_pct=0.25, test_pct=0.25, seed=10)
358+
359+
# Train for 1 epoch
360+
pca_components, trained_model = model.train(dataset, self._output_dir, epochs=1, kernel_size=2,
361+
layer_name='features', pooling='avg', pca_threshold=0.99,
362+
seed=10)
363+
364+
# Evaluate
365+
threshold, auroc = model.evaluate(dataset, use_test_set=True)
366+
assert isinstance(auroc, float)
367+
368+
# Predict with a batch
369+
images, labels = dataset.get_batch(subset='test')
370+
predictions = model.predict(images, pca_mats=pca_components)
371+
assert len(predictions) == 32
372+
373+
# Export the saved model
374+
saved_model_dir = model.export(self._output_dir)
375+
assert os.path.isdir(saved_model_dir)
376+
assert os.path.isfile(os.path.join(saved_model_dir, "model.pt"))
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
{
2+
"dinov2_vitg14": {
3+
"model_hub": "pytorch_hub",
4+
"repo": "facebookresearch/dinov2",
5+
"pretrained_default": "False",
6+
"classification_layer": ["head"],
7+
"image_size": 224,
8+
"original_dataset": "IMAGENET1K_V1"
9+
},
10+
"dinov2_vits14": {
11+
"model_hub": "pytorch_hub",
12+
"repo": "facebookresearch/dinov2",
13+
"pretrained_default": "True",
14+
"classification_layer": ["head"],
15+
"image_size": 224,
16+
"original_dataset": "IMAGENET1K_V1"
17+
},
18+
"dinov2_vitb14": {
19+
"model_hub": "pytorch_hub",
20+
"repo": "facebookresearch/dinov2",
21+
"pretrained_default": "True",
22+
"classification_layer": ["head"],
23+
"image_size": 224,
24+
"original_dataset": "IMAGENET1K_V1"
25+
},
26+
"dinov2_vitl14": {
27+
"model_hub": "pytorch_hub",
28+
"repo": "facebookresearch/dinov2",
29+
"pretrained_default": "True",
30+
"classification_layer": ["head"],
31+
"image_size": 224,
32+
"original_dataset": "IMAGENET1K_V1"
33+
},
34+
"dinov2_vits14_reg": {
35+
"model_hub": "pytorch_hub",
36+
"repo": "facebookresearch/dinov2",
37+
"pretrained_default": "True",
38+
"classification_layer": ["head"],
39+
"image_size": 224,
40+
"original_dataset": "IMAGENET1K_V1"
41+
},
42+
"dinov2_vitg14_reg": {
43+
"model_hub": "pytorch_hub",
44+
"repo": "facebookresearch/dinov2",
45+
"pretrained_default": "True",
46+
"classification_layer": ["head"],
47+
"image_size": 224,
48+
"original_dataset": "IMAGENET1K_V1"
49+
},
50+
"dinov2_vitl14_reg": {
51+
"model_hub": "pytorch_hub",
52+
"repo": "facebookresearch/dinov2",
53+
"pretrained_default": "True",
54+
"classification_layer": ["head"],
55+
"image_size": 224,
56+
"original_dataset": "IMAGENET1K_V1"
57+
},
58+
"dinov2_vitb14_reg": {
59+
"model_hub": "pytorch_hub",
60+
"repo": "facebookresearch/dinov2",
61+
"pretrained_default": "True",
62+
"classification_layer": ["head"],
63+
"image_size": 224,
64+
"original_dataset": "IMAGENET1K_V1"
65+
}
66+
}

0 commit comments

Comments
 (0)