Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow passing metrics objects directly to create_metrics_collection #2212

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).

### Changed

- πŸ”¨ Allow passing metrics objects directly to `create_metrics_collection` by @ashwinvaidya17 in https://github.com/openvinotoolkit/anomalib/pull/2212

### Deprecated

### Fixed
Expand Down
21 changes: 14 additions & 7 deletions src/anomalib/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import torchmetrics
from omegaconf import DictConfig, ListConfig
from torchmetrics import Metric

from .anomaly_score_distribution import AnomalyScoreDistribution
from .aupr import AUPR
Expand Down Expand Up @@ -162,7 +163,7 @@


def create_metric_collection(
metrics: list[str] | dict[str, dict[str, Any]],
metrics: list[str] | dict[str, dict[str, Any]] | Metric | list[Metric],
prefix: str | None = None,
) -> AnomalibMetricCollection:
"""Create a metric collection from a list of metric names or dictionaries.
Expand All @@ -171,25 +172,31 @@

- if list[str] (names of metrics): see `metric_collection_from_names`
- if dict[str, dict[str, Any]] (path and init args of a class): see `metric_collection_from_dicts`
- if list[Metric] (metric objects): A collection is returned with those metrics.

The function will first try to retrieve the metric from the metrics defined in Anomalib metrics module,
then in TorchMetrics package.

Args:
metrics (list[str] | dict[str, dict[str, Any]]): List of metrics or dictionaries to create metric collection.
metrics (list[str] | dict[str, dict[str, Any]] | Metric | list[Metric]): List of metrics or dictionaries to
create metric collection.
prefix (str | None): Prefix to assign to the metrics in the collection.

Returns:
AnomalibMetricCollection: Collection of metrics.
"""
# fallback is using the names

if isinstance(metrics, ListConfig | list):
if not all(isinstance(metric, str) for metric in metrics):
msg = f"All metrics must be strings, found {metrics}"
if not (
all(isinstance(metric, str) for metric in metrics) or all(isinstance(metric, Metric) for metric in metrics)
):
msg = f"All metrics must be either string or Metric objects, found {metrics}"

Check warning on line 192 in src/anomalib/metrics/__init__.py

View check run for this annotation

Codecov / codecov/patch

src/anomalib/metrics/__init__.py#L192

Added line #L192 was not covered by tests
Comment on lines +189 to +192
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would this mean that a user cannot pass the following:

from torchmetrics.classification import Accuracy, Precision, Recall

from anomalib.data import MVTec
from anomalib.engine import Engine
from anomalib.models import Padim

if __name__ == "__main__":
    model = Padim()
    data = MVTec()
    engine = Engine(image_metrics=["F1Score", Accuracy(task="binary")])
    engine.train(model, datamodule=data)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, would it be an idea to have an additional check like;

from torchmetrics.classification import Accuracy, Precision, Recall
import types

def instantiate_if_needed(metric, task="binary"):
    if isinstance(metric, types.FunctionType) or isinstance(metric, type):
        # If metric is a function or a class (not instantiated)
        return metric(task=task)
    else:
        # If metric is already instantiated
        return metric

Or do you think if this is overkill?

raise TypeError(msg)
if all(isinstance(metric, str) for metric in metrics):
return metric_collection_from_names(metrics, prefix)
return AnomalibMetricCollection(metrics, prefix)

return metric_collection_from_names(metrics, prefix)
if isinstance(metrics, Metric):
return AnomalibMetricCollection([metrics], prefix)

if isinstance(metrics, DictConfig | dict):
_validate_metrics_dict(metrics)
Expand Down
45 changes: 45 additions & 0 deletions tests/unit/metrics/test_create_metrics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
"""Test metrics collection creation."""

from torchmetrics.classification import Accuracy

from anomalib.metrics import AUPRO, create_metric_collection


def test_string_initialization() -> None:
"""Pass metrics as a list of string."""
metrics_list = ["AUROC", "AUPR"]
collection = create_metric_collection(metrics_list, prefix=None)
assert len(collection) == 2
assert "AUROC" in collection
assert "AUPR" in collection


def test_dict_initialization() -> None:
"""Pass metrics as a dictionary."""
metrics_dict = {
"PixelWiseAUROC": {
"class_path": "anomalib.metrics.AUROC",
"init_args": {},
},
"Precision": {
"class_path": "torchmetrics.Precision",
"init_args": {"task": "binary"},
},
}
collection = create_metric_collection(metrics_dict, prefix=None)
assert len(collection) == 2
assert "PixelWiseAUROC" in collection
assert "Precision" in collection


def test_metric_object_initialization() -> None:
"""Pass metrics as a list of metric objects."""
metrics_list = [AUPRO(), Accuracy(task="binary")]
collection = create_metric_collection(metrics_list, prefix=None)
assert len(collection) == 2
assert "AUPRO" in collection
assert "BinaryAccuracy" in collection

collection = create_metric_collection(AUPRO(), prefix=None)
assert len(collection) == 1
assert "AUPRO" in collection
Loading