Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve backbone weights loading mechanism #2454

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).

### Fixed

- Improve backbone weights loading mechanism by @mzweilin in https://github.com/openvinotoolkit/anomalib/pull/2454

### New Contributors

## [v1.2.0]
Expand Down
8 changes: 3 additions & 5 deletions src/anomalib/models/components/feature_extractors/timm.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,10 @@ def __init__(
) -> None:
super().__init__()

# Extract backbone-name and weight-URI from the backbone string.
# Extract the backbone name and the weight file location from the backbone string.
if "__AT__" in backbone:
backbone, uri = backbone.split("__AT__")
pretrained_cfg = timm.models.registry.get_pretrained_cfg(backbone)
# Override pretrained_cfg["url"] to use different pretrained weights.
pretrained_cfg["url"] = uri
backbone, location = backbone.split("__AT__")
pretrained_cfg = {"url": location} if location.startswith(("http://", "https://")) else {"file": location}
else:
pretrained_cfg = None

Expand Down
4 changes: 3 additions & 1 deletion src/anomalib/models/image/padim/torch_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,9 @@ def __init__(
pre_trained=pre_trained,
).eval()
self.n_features_original = sum(self.feature_extractor.out_dims)
self.n_features = n_features or _N_FEATURES_DEFAULTS.get(self.backbone)
# In case the backbone has the weight file information.
backbone_name = self.backbone.split("__AT__")[0]
self.n_features = n_features or _N_FEATURES_DEFAULTS.get(backbone_name)
if self.n_features is None:
msg = (
f"n_features must be specified for backbone {self.backbone}. "
Expand Down
37 changes: 37 additions & 0 deletions tests/unit/models/components/base/test_feature_extractor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
"""Unit tests for TimmFeatureExtractor."""

# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import tempfile
from pathlib import Path

import torch
from timm.models import create_model

from anomalib.models.components.feature_extractors.timm import TimmFeatureExtractor


def test_backbone_weight_file() -> None:
    """Test loading backbone weights from a local file via the ``__AT__`` separator.

    Creates a ``resnet18`` with random (non-downloaded) weights, zeroes ``conv1``,
    saves the state dict to a temporary file, and verifies that
    ``TimmFeatureExtractor`` restores those weights when the backbone string is
    ``"<name>__AT__<path>"``.
    """
    # Use the simplest model.
    backbone = "resnet18"
    # Only examine conv1 before layers in the feature extractor.
    layers = []

    # Get random model weights without downloading.
    model = create_model(backbone, pretrained=False)
    state_dict = model.state_dict()

    # Set the conv1 weight to zero so a successful file load is detectable.
    state_dict["conv1.weight"].zero_()

    # TemporaryDirectory guarantees cleanup even if construction below raises,
    # and avoids the open file-descriptor leak of tempfile.mkstemp().
    with tempfile.TemporaryDirectory() as tmp_dir:
        state_dict_fpath = Path(tmp_dir) / "state_dict.pth"
        torch.save(state_dict, state_dict_fpath)

        # Load weights from the temp file via the __AT__ backbone syntax.
        backbone_with_path = f"{backbone}__AT__{state_dict_fpath}"
        fe_restored = TimmFeatureExtractor(backbone_with_path, layers, pre_trained=True)

    # The weights should be zero if the file loading mechanism works.
    assert torch.all(fe_restored.feature_extractor.conv1.weight == 0)
Loading