From 23df070f11fc08ff01ab75e285461c906392c397 Mon Sep 17 00:00:00 2001
From: Dick Ameln
Date: Thu, 12 Dec 2024 16:48:21 +0100
Subject: [PATCH 01/22] apply transforms in pre-processor

---
 src/anomalib/data/dataclasses/generic.py      |  17 +++
 src/anomalib/pre_processing/pre_processing.py | 110 +++++-----
 .../pre_processing/utils/transform.py         | 104 -----------
 .../pre_processing/test_pre_processing.py     |   2 +
 4 files changed, 62 insertions(+), 171 deletions(-)

diff --git a/src/anomalib/data/dataclasses/generic.py b/src/anomalib/data/dataclasses/generic.py
index 5f9dca9dc9..d5d8a51e69 100644
--- a/src/anomalib/data/dataclasses/generic.py
+++ b/src/anomalib/data/dataclasses/generic.py
@@ -17,7 +17,9 @@
 
 import numpy as np
 import torch
+from torch import tensor
 from torch.utils.data import default_collate
+from torchvision.transforms.v2.functional import resize
 from torchvision.tv_tensors import Image, Mask, Video
 
 ImageT = TypeVar("ImageT", Image, Video, np.ndarray)
@@ -656,5 +658,20 @@ def batch_size(self) -> int:
     def collate(cls: type["BatchIterateMixin"], items: list[ItemT]) -> "BatchIterateMixin":
         """Convert a list of DatasetItem objects to a Batch object."""
         keys = [key for key, value in asdict(items[0]).items() if value is not None]
+
+        # Check if all images have the same shape. If not, resize before collating
+        im_shapes = torch.vstack([tensor(item.image.shape) for item in items if item.image is not None])[..., 1:]
+        if torch.unique(im_shapes, dim=0).size(0) != 1:  # check if batch has heterogeneous shapes
+            target_shape = im_shapes[
+                torch.unravel_index(im_shapes.argmax(), im_shapes.shape)[0],
+                :,
+            ]  # shape of image with largest H or W
+            for item in items:
+                for key in keys:
+                    value = getattr(item, key)
+                    if isinstance(value, Image | Mask):
+                        setattr(item, key, resize(value, target_shape))
+
+        # collate the batch
         out_dict = {key: default_collate([getattr(item, key) for item in items]) for key in keys}
         return cls(**out_dict)
diff --git a/src/anomalib/pre_processing/pre_processing.py b/src/anomalib/pre_processing/pre_processing.py
index 27cffc7605..812540860b 100644
--- a/src/anomalib/pre_processing/pre_processing.py
+++ b/src/anomalib/pre_processing/pre_processing.py
@@ -3,27 +3,17 @@
 # Copyright (C) 2024 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0
 
-from typing import TYPE_CHECKING
-
 import torch
 from lightning import Callback, LightningModule, Trainer
-from lightning.pytorch.trainer.states import TrainerFn
 from torch import nn
-from torch.utils.data import DataLoader
 from torchvision.transforms.v2 import Transform
 
+from anomalib.data import Batch
+
 from .utils.transform import (
-    get_dataloaders_transforms,
     get_exportable_transform,
-    set_dataloaders_transforms,
-    set_datamodule_stage_transform,
 )
 
-if TYPE_CHECKING:
-    from lightning.pytorch.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS
-
-    from anomalib.data import AnomalibDataModule
-
 
 class PreProcessor(nn.Module, Callback):
     """Anomalib pre-processor.
@@ -109,63 +99,49 @@ def __init__(
         self.predict_transform = self.test_transform
         self.export_transform = get_exportable_transform(self.test_transform)
 
-    def setup_datamodule_transforms(self, datamodule: "AnomalibDataModule") -> None:
-        """Set up datamodule transforms."""
-        # If PreProcessor has transforms, propagate them to datamodule
-        if any([self.train_transform, self.val_transform, self.test_transform]):
-            transforms = {
-                "fit": self.train_transform,
-                "val": self.val_transform,
-                "test": self.test_transform,
-                "predict": self.predict_transform,
-            }
-
-            for stage, transform in transforms.items():
-                if transform is not None:
-                    set_datamodule_stage_transform(datamodule, transform, stage)
-
-    def setup_dataloader_transforms(self, dataloaders: "EVAL_DATALOADERS | TRAIN_DATALOADERS") -> None:
-        """Set up dataloader transforms."""
-        if isinstance(dataloaders, DataLoader):
-            dataloaders = [dataloaders]
-
-        # If PreProcessor has transforms, propagate them to dataloaders
-        if any([self.train_transform, self.val_transform, self.test_transform]):
-            transforms = {
-                "train": self.train_transform,
-                "val": self.val_transform,
-                "test": self.test_transform,
-            }
-            set_dataloaders_transforms(dataloaders, transforms)
-            return
-
-        # Try to get transforms from dataloaders
-        if dataloaders:
-            dataloaders_transforms = get_dataloaders_transforms(dataloaders)
-            if dataloaders_transforms:
-                self.train_transform = dataloaders_transforms.get("train")
-                self.val_transform = dataloaders_transforms.get("val")
-                self.test_transform = dataloaders_transforms.get("test")
-                self.predict_transform = self.test_transform
-                self.export_transform = get_exportable_transform(self.test_transform)
-
-    def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None:
-        """Configure transforms at the start of each stage.
-
-        Args:
-            trainer: The Lightning trainer.
-            pl_module: The Lightning module.
-            stage: The stage (e.g., 'fit', 'validate', 'test', 'predict').
- """ - stage = TrainerFn(stage).value # Ensure stage is str + def on_train_batch_start(self, trainer: Trainer, pl_module: LightningModule, batch: Batch, batch_idx: int) -> None: + """Apply transforms to the batch of tensors during training.""" + del trainer, pl_module, batch_idx # Unused + if self.train_transform: + batch.image, batch.gt_mask = self.train_transform(batch.image, batch.gt_mask) + + def on_validation_batch_start( + self, + trainer: Trainer, + pl_module: LightningModule, + batch: Batch, + batch_idx: int, + ) -> None: + """Apply transforms to the batch of tensors during validation.""" + del trainer, pl_module, batch_idx # Unused + if self.val_transform: + batch.image, batch.gt_mask = self.val_transform(batch.image, batch.gt_mask) - if hasattr(trainer, "datamodule"): - self.setup_datamodule_transforms(datamodule=trainer.datamodule) - elif hasattr(trainer, f"{stage}_dataloaders"): - dataloaders = getattr(trainer, f"{stage}_dataloaders") - self.setup_dataloader_transforms(dataloaders=dataloaders) + def on_test_batch_start( + self, + trainer: Trainer, + pl_module: LightningModule, + batch: Batch, + batch_idx: int, + dataloader_idx: int = 0, + ) -> None: + """Apply transforms to the batch of tensors during testing.""" + del trainer, pl_module, batch_idx, dataloader_idx # Unused + if self.test_transform: + batch.image, batch.gt_mask = self.test_transform(batch.image, batch.gt_mask) - super().setup(trainer, pl_module, stage) + def on_predict_batch_start( + self, + trainer: Trainer, + pl_module: LightningModule, + batch: Batch, + batch_idx: int, + dataloader_idx: int = 0, + ) -> None: + """Apply transforms to the batch of tensors during prediction.""" + del trainer, pl_module, batch_idx, dataloader_idx # Unused + if self.predict_transform: + batch.image, batch.gt_mask = self.predict_transform(batch.image, batch.gt_mask) def forward(self, batch: torch.Tensor) -> torch.Tensor: """Apply transforms to the batch of tensors for inference. diff --git a/src/anomalib/pre_processing/utils/transform.py b/src/anomalib/pre_processing/utils/transform.py index 37eb1e9dd1..ce8de9ba28 100644 --- a/src/anomalib/pre_processing/utils/transform.py +++ b/src/anomalib/pre_processing/utils/transform.py @@ -3,115 +3,11 @@ # Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 -from collections.abc import Sequence - -from torch.utils.data import DataLoader from torchvision.transforms.v2 import CenterCrop, Compose, Resize, Transform -from anomalib.data import AnomalibDataModule from anomalib.data.transforms import ExportableCenterCrop -def get_dataloaders_transforms(dataloaders: Sequence[DataLoader]) -> dict[str, Transform]: - """Get transforms from dataloaders. - - Args: - dataloaders: The dataloaders to get transforms from. - - Returns: - Dictionary mapping stages to their transforms. - """ - transforms: dict[str, Transform] = {} - stage_lookup = { - "fit": "train", - "validate": "val", - "test": "test", - "predict": "test", - } - - for dataloader in dataloaders: - if not hasattr(dataloader, "dataset") or not hasattr(dataloader.dataset, "transform"): - continue - - for stage in stage_lookup: - if hasattr(dataloader, f"{stage}_dataloader"): - transforms[stage_lookup[stage]] = dataloader.dataset.transform - - return transforms - - -def set_dataloaders_transforms(dataloaders: Sequence[DataLoader], transforms: dict[str, Transform | None]) -> None: - """Set transforms to dataloaders. - - Args: - dataloaders: The dataloaders to propagate transforms to. 
-        transforms: Dictionary mapping stages to their transforms.
-    """
-    stage_mapping = {
-        "fit": "train",
-        "validate": "val",
-        "test": "test",
-        "predict": "test",  # predict uses test transform
-    }
-
-    for loader in dataloaders:
-        if not hasattr(loader, "dataset"):
-            continue
-
-        for stage in stage_mapping:
-            if hasattr(loader, f"{stage}_dataloader"):
-                transform = transforms.get(stage_mapping[stage])
-                if transform is not None:
-                    set_dataloader_transform([loader], transform)
-
-
-def set_dataloader_transform(dataloader: DataLoader | Sequence[DataLoader], transform: Transform) -> None:
-    """Set a transform for a dataloader or list of dataloaders.
-
-    Args:
-        dataloader: The dataloader(s) to set the transform for.
-        transform: The transform to set.
-    """
-    if isinstance(dataloader, DataLoader):
-        if hasattr(dataloader.dataset, "transform"):
-            dataloader.dataset.transform = transform
-    elif isinstance(dataloader, Sequence):
-        for dl in dataloader:
-            set_dataloader_transform(dl, transform)
-    else:
-        msg = f"Unsupported dataloader type: {type(dataloader)}"
-        raise TypeError(msg)
-
-
-def set_datamodule_stage_transform(datamodule: AnomalibDataModule, transform: Transform, stage: str) -> None:
-    """Set a transform for a specific stage in a AnomalibDataModule.
-
-    Args:
-        datamodule: The AnomalibDataModule to set the transform for.
-        transform: The transform to set.
-        stage: The stage to set the transform for.
-
-    Note:
-        The stage parameter maps to dataset attributes as follows:
-        - 'fit' -> 'train_data'
-        - 'validate' -> 'val_data'
-        - 'test' -> 'test_data'
-        - 'predict' -> 'test_data'
-    """
-    stage_datasets = {
-        "fit": "train_data",
-        "validate": "val_data",
-        "test": "test_data",
-        "predict": "test_data",
-    }
-
-    dataset_attr = stage_datasets.get(stage)
-    if dataset_attr and hasattr(datamodule, dataset_attr):
-        dataset = getattr(datamodule, dataset_attr)
-        if hasattr(dataset, "transform"):
-            dataset.transform = transform
-
-
 def get_exportable_transform(transform: Transform | None) -> Transform | None:
     """Get exportable transform.
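The net effect of this first commit is that transforms move out of the setup-time dataloader plumbing (all of the removed helpers above) and into Lightning batch hooks: the PreProcessor now applies its stage-specific transform to each already-collated batch. A minimal, self-contained sketch of this callback pattern follows; the class name, constructor argument, and dict-style batch below are illustrative assumptions, not part of the patch:

from lightning import Callback, LightningModule, Trainer
from torchvision.transforms.v2 import Resize


class BatchTransformCallback(Callback):
    """Toy callback mirroring the PreProcessor hooks above (hypothetical name)."""

    def __init__(self, train_transform) -> None:
        super().__init__()
        self.train_transform = train_transform

    def on_train_batch_start(self, trainer: Trainer, pl_module: LightningModule, batch: dict, batch_idx: int) -> None:
        del trainer, pl_module, batch_idx  # Unused, as in the hooks above
        # The batch is already collated here, so the transform runs once per
        # batch on a (B, C, H, W) tensor instead of once per sample.
        batch["image"] = self.train_transform(batch["image"])


# Usage sketch: attach to a trainer like any other Lightning callback.
callback = BatchTransformCallback(train_transform=Resize((256, 256)))

Because the transform now runs after collation, a batch may initially contain images of different sizes; that is why the collate() change in generic.py above resizes heterogeneous items to a common shape before stacking.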
diff --git a/tests/unit/pre_processing/test_pre_processing.py b/tests/unit/pre_processing/test_pre_processing.py
index 36394d54a3..fb10ba4b3b 100644
--- a/tests/unit/pre_processing/test_pre_processing.py
+++ b/tests/unit/pre_processing/test_pre_processing.py
@@ -91,6 +91,7 @@ def test_different_stage_transforms() -> None:
         assert isinstance(processed_batch, torch.Tensor)
         assert processed_batch.shape == (1, 3, 288, 288)
 
+    @pytest.mark.skip
     def test_setup_transforms_from_dataloaders(self) -> None:
         """Test setup method when transforms are obtained from dataloaders."""
         # Mock dataloader with dataset having a transform
@@ -104,6 +105,7 @@ def test_setup_transforms_from_dataloaders(self) -> None:
         assert pre_processor.val_transform == self.common_transform
         assert pre_processor.test_transform == self.common_transform
 
+    @pytest.mark.skip
    def test_setup_transforms_priority(self) -> None:
         """Test setup method prioritizes PreProcessor transforms over datamodule/dataloaders."""
         # Mock datamodule

From 5a1262297394b72a7a61111b0b21756b1e365174 Mon Sep 17 00:00:00 2001
From: Dick Ameln
Date: Mon, 16 Dec 2024 19:16:51 +0100
Subject: [PATCH 02/22] add augmentation arguments to datamodules

---
 src/anomalib/data/datamodules/base/image.py   | 35 +++++++++++++---
 .../data/datamodules/depth/folder_3d.py       | 18 ++++++++
 .../data/datamodules/depth/mvtec_3d.py        | 18 ++++++++
 src/anomalib/data/datamodules/image/btech.py  | 17 +++++++
 .../data/datamodules/image/datumaro.py        | 20 ++++++---
 src/anomalib/data/datamodules/image/folder.py | 18 ++++++++
 .../data/datamodules/image/kolektor.py        | 18 ++++++++
 src/anomalib/data/datamodules/image/mvtec.py  | 18 ++++++++
 src/anomalib/data/datamodules/image/visa.py   | 17 +++++++
 src/anomalib/data/datamodules/video/avenue.py | 17 +++++++
 .../data/datamodules/video/shanghaitech.py    | 18 ++++++++
 .../data/datamodules/video/ucsd_ped.py        | 18 ++++++++
 src/anomalib/data/datasets/base/depth.py      | 12 +++---
 src/anomalib/data/datasets/base/image.py      | 10 ++---
 src/anomalib/data/datasets/base/video.py      | 16 ++++----
 src/anomalib/data/datasets/depth/folder_3d.py |  6 +--
 src/anomalib/data/datasets/depth/mvtec_3d.py  |  6 +--
 src/anomalib/data/datasets/image/btech.py     |  4 +-
 src/anomalib/data/datasets/image/datumaro.py  |  4 +-
 src/anomalib/data/datasets/image/folder.py    |  4 +-
 src/anomalib/data/datasets/image/kolektor.py  |  4 +-
 src/anomalib/data/datasets/image/mvtec.py     |  6 +--
 src/anomalib/data/datasets/image/visa.py      |  4 +-
 src/anomalib/data/datasets/video/avenue.py    |  6 +--
 .../data/datasets/video/shanghaitech.py       |  6 +--
 src/anomalib/data/datasets/video/ucsd_ped.py  |  6 +--
 src/anomalib/data/utils/synthetic.py          | 10 ++---
 .../data/datamodule/depth/test_folder_3d.py   |  2 +
 .../data/datamodule/depth/test_mvtec_3d.py    |  2 +
 .../unit/data/datamodule/image/test_btech.py  |  2 +
 .../data/datamodule/image/test_datumaro.py    |  2 +
 .../unit/data/datamodule/image/test_folder.py |  2 +
 .../data/datamodule/image/test_kolektor.py    |  2 +
 .../unit/data/datamodule/image/test_mvtec.py  |  2 +
 tests/unit/data/datamodule/image/test_visa.py |  2 +
 .../unit/data/datamodule/video/test_avenue.py |  2 +
 .../datamodule/video/test_shanghaitech.py     |  2 +
 .../data/datamodule/video/test_ucsdped.py     |  2 +
 tests/unit/data/utils/test_synthetic.py       |  4 +-
 .../pre_processing/test_pre_processing.py     | 40 -------------------
 .../pre_processing/utils/test_transform.py    | 30 --------------
 41 files changed, 300 insertions(+), 132 deletions(-)

diff --git a/src/anomalib/data/datamodules/base/image.py b/src/anomalib/data/datamodules/base/image.py
index 5c28cd4557..b45160322c 100644
--- a/src/anomalib/data/datamodules/base/image.py
+++ b/src/anomalib/data/datamodules/base/image.py
@@ -3,6 +3,7 @@
 # Copyright (C) 2022-2024 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0
 
+import copy
 import logging
 from abc import ABC, abstractmethod
 from pathlib import Path
@@ -12,6 +13,7 @@
 from lightning.pytorch.trainer.states import TrainerFn
 from lightning.pytorch.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS
 from torch.utils.data.dataloader import DataLoader
+from torchvision.transforms.v2 import Transform
 
 from anomalib import TaskType
 from anomalib.data.utils import TestSplitMode, ValSplitMode, random_split, split_by_label
@@ -32,6 +34,14 @@ class AnomalibDataModule(LightningDataModule, ABC):
         train_batch_size (int): Batch size used by the train dataloader.
         eval_batch_size (int): Batch size used by the val and test dataloaders.
         num_workers (int): Number of workers used by the train, val and test dataloaders.
+        train_augmentations (Transform | None): Augmentations to apply to the training images.
+            Defaults to ``None``.
+        val_augmentations (Transform | None): Augmentations to apply to the validation images.
+            Defaults to ``None``.
+        test_augmentations (Transform | None): Augmentations to apply to the test images.
+            Defaults to ``None``.
+        augmentations (Transform | None): General augmentations to apply if stage-specific
+            augmentations are not provided.
         val_split_mode (ValSplitMode): Determines how the validation split is obtained.
             Options: [none, same_as_test, from_test, synthetic]
         val_split_ratio (float): Fraction of the train or test images held out for validation.
@@ -49,8 +59,12 @@ def __init__(
         train_batch_size: int,
         eval_batch_size: int,
         num_workers: int,
-        val_split_mode: ValSplitMode | str,
-        val_split_ratio: float,
+        train_augmentations: Transform | None = None,
+        val_augmentations: Transform | None = None,
+        test_augmentations: Transform | None = None,
+        augmentations: Transform | None = None,
+        val_split_mode: ValSplitMode | str | None = None,
+        val_split_ratio: float | None = None,
         test_split_mode: TestSplitMode | str | None = None,
         test_split_ratio: float | None = None,
         seed: int | None = None,
@@ -60,11 +74,15 @@ def __init__(
         self.eval_batch_size = eval_batch_size
         self.num_workers = num_workers
         self.test_split_mode = TestSplitMode(test_split_mode) if test_split_mode else TestSplitMode.NONE
-        self.test_split_ratio = test_split_ratio
+        self.test_split_ratio = test_split_ratio or 0.5
         self.val_split_mode = ValSplitMode(val_split_mode)
-        self.val_split_ratio = val_split_ratio
+        self.val_split_ratio = val_split_ratio or 0.5
         self.seed = seed
 
+        self.train_augmentations = train_augmentations or augmentations
+        self.val_augmentations = val_augmentations or augmentations
+        self.test_augmentations = test_augmentations or augmentations
+
         self.train_data: AnomalibDataset
         self.val_data: AnomalibDataset
         self.test_data: AnomalibDataset
@@ -95,6 +113,13 @@ def setup(self, stage: str | None = None) -> None:
             # only set the flag if the stage is a TrainerFn, which means the setup has been called from a trainer
             self._is_setup = True
 
+        if hasattr(self, "train_data"):
+            self.train_data.augmentations = self.train_augmentations
+        if hasattr(self, "val_data"):
+            self.val_data.augmentations = self.val_augmentations
+        if hasattr(self, "test_data"):
+            self.test_data.augmentations = self.test_augmentations
+
     @abstractmethod
     def _setup(self, _stage: str | None = None) -> None:
         """Set up the datasets and perform dynamic subset splitting.
@@ -175,7 +200,7 @@ def _create_val_split(self) -> None:
             )
         elif self.val_split_mode == ValSplitMode.SAME_AS_TEST:
             # equal to test set
-            self.val_data = self.test_data
+            self.val_data = copy.deepcopy(self.test_data)
         elif self.val_split_mode == ValSplitMode.SYNTHETIC:
             # converted from random training sample
             self.train_data, normal_val_data = random_split(self.train_data, self.val_split_ratio, seed=self.seed)
diff --git a/src/anomalib/data/datamodules/depth/folder_3d.py b/src/anomalib/data/datamodules/depth/folder_3d.py
index f475c26bd8..566f5921a9 100644
--- a/src/anomalib/data/datamodules/depth/folder_3d.py
+++ b/src/anomalib/data/datamodules/depth/folder_3d.py
@@ -8,6 +8,8 @@
 
 from pathlib import Path
 
+from torchvision.transforms.v2 import Transform
+
 from anomalib.data.datamodules.base.image import AnomalibDataModule
 from anomalib.data.datasets.depth.folder_3d import Folder3DDataset
 from anomalib.data.utils import Split, TestSplitMode, ValSplitMode
@@ -46,6 +48,14 @@ class Folder3D(AnomalibDataModule):
             Defaults to ``32``.
         num_workers (int, optional): Number of workers.
             Defaults to ``8``.
+        train_augmentations (Transform | None): Augmentations to apply to the training images.
+            Defaults to ``None``.
+        val_augmentations (Transform | None): Augmentations to apply to the validation images.
+            Defaults to ``None``.
+        test_augmentations (Transform | None): Augmentations to apply to the test images.
+            Defaults to ``None``.
+        augmentations (Transform | None): General augmentations to apply if stage-specific
+            augmentations are not provided.
         test_split_mode (TestSplitMode): Setting that determines how the testing subset is obtained.
             Defaults to ``TestSplitMode.FROM_DIR``.
         test_split_ratio (float): Fraction of images from the train set that will be reserved for testing.
@@ -73,6 +83,10 @@ def __init__(
         train_batch_size: int = 32,
         eval_batch_size: int = 32,
         num_workers: int = 8,
+        train_augmentations: Transform | None = None,
+        val_augmentations: Transform | None = None,
+        test_augmentations: Transform | None = None,
+        augmentations: Transform | None = None,
         test_split_mode: TestSplitMode | str = TestSplitMode.FROM_DIR,
         test_split_ratio: float = 0.2,
         val_split_mode: ValSplitMode | str = ValSplitMode.FROM_TEST,
@@ -83,6 +97,10 @@ def __init__(
             train_batch_size=train_batch_size,
             eval_batch_size=eval_batch_size,
             num_workers=num_workers,
+            train_augmentations=train_augmentations,
+            val_augmentations=val_augmentations,
+            test_augmentations=test_augmentations,
+            augmentations=augmentations,
             test_split_mode=test_split_mode,
             test_split_ratio=test_split_ratio,
             val_split_mode=val_split_mode,
diff --git a/src/anomalib/data/datamodules/depth/mvtec_3d.py b/src/anomalib/data/datamodules/depth/mvtec_3d.py
index 400b1d3139..b6db30dba9 100644
--- a/src/anomalib/data/datamodules/depth/mvtec_3d.py
+++ b/src/anomalib/data/datamodules/depth/mvtec_3d.py
@@ -22,6 +22,8 @@
 import logging
 from pathlib import Path
 
+from torchvision.transforms.v2 import Transform
+
 from anomalib.data.datamodules.base.image import AnomalibDataModule
 from anomalib.data.datasets.depth.mvtec_3d import MVTec3DDataset
 from anomalib.data.utils import DownloadInfo, Split, TestSplitMode, ValSplitMode, download_and_extract
@@ -51,6 +53,14 @@ class MVTec3D(AnomalibDataModule):
             Defaults to ``32``.
         num_workers (int, optional): Number of workers.
             Defaults to ``8``.
+        train_augmentations (Transform | None): Augmentations to apply to the training images.
+            Defaults to ``None``.
+        val_augmentations (Transform | None): Augmentations to apply to the validation images.
+            Defaults to ``None``.
+        test_augmentations (Transform | None): Augmentations to apply to the test images.
+            Defaults to ``None``.
+        augmentations (Transform | None): General augmentations to apply if stage-specific
+            augmentations are not provided.
         test_split_mode (TestSplitMode): Setting that determines how the testing subset is obtained.
             Defaults to ``TestSplitMode.FROM_DIR``.
         test_split_ratio (float): Fraction of images from the train set that will be reserved for testing.
@@ -70,6 +80,10 @@ def __init__(
         train_batch_size: int = 32,
         eval_batch_size: int = 32,
         num_workers: int = 8,
+        train_augmentations: Transform | None = None,
+        val_augmentations: Transform | None = None,
+        test_augmentations: Transform | None = None,
+        augmentations: Transform | None = None,
         test_split_mode: TestSplitMode | str = TestSplitMode.FROM_DIR,
         test_split_ratio: float = 0.2,
         val_split_mode: ValSplitMode | str = ValSplitMode.SAME_AS_TEST,
@@ -80,6 +94,10 @@ def __init__(
             train_batch_size=train_batch_size,
             eval_batch_size=eval_batch_size,
             num_workers=num_workers,
+            train_augmentations=train_augmentations,
+            val_augmentations=val_augmentations,
+            test_augmentations=test_augmentations,
+            augmentations=augmentations,
             test_split_mode=test_split_mode,
             test_split_ratio=test_split_ratio,
             val_split_mode=val_split_mode,
diff --git a/src/anomalib/data/datamodules/image/btech.py b/src/anomalib/data/datamodules/image/btech.py
index 4ec0527f16..f0b1fee82c 100644
--- a/src/anomalib/data/datamodules/image/btech.py
+++ b/src/anomalib/data/datamodules/image/btech.py
@@ -14,6 +14,7 @@
 from pathlib import Path
 
 import cv2
+from torchvision.transforms.v2 import Transform
 from tqdm import tqdm
 
 from anomalib.data.datamodules.base.image import AnomalibDataModule
@@ -43,6 +44,14 @@ class BTech(AnomalibDataModule):
             Defaults to ``32``.
         num_workers (int, optional): Number of workers.
             Defaults to ``8``.
+        train_augmentations (Transform | None): Augmentations to apply to the training images.
+            Defaults to ``None``.
+        val_augmentations (Transform | None): Augmentations to apply to the validation images.
+            Defaults to ``None``.
+        test_augmentations (Transform | None): Augmentations to apply to the test images.
+            Defaults to ``None``.
+        augmentations (Transform | None): General augmentations to apply if stage-specific
+            augmentations are not provided.
         test_split_mode (TestSplitMode, optional): Setting that determines how the testing subset is obtained.
             Defaults to ``TestSplitMode.FROM_DIR``.
         test_split_ratio (float, optional): Fraction of images from the train set that will be reserved for testing.
@@ -99,6 +108,10 @@ def __init__(
         train_batch_size: int = 32,
         eval_batch_size: int = 32,
         num_workers: int = 8,
+        train_augmentations: Transform | None = None,
+        val_augmentations: Transform | None = None,
+        test_augmentations: Transform | None = None,
+        augmentations: Transform | None = None,
         test_split_mode: TestSplitMode | str = TestSplitMode.FROM_DIR,
         test_split_ratio: float = 0.2,
         val_split_mode: ValSplitMode | str = ValSplitMode.SAME_AS_TEST,
@@ -109,6 +122,10 @@ def __init__(
             train_batch_size=train_batch_size,
             eval_batch_size=eval_batch_size,
             num_workers=num_workers,
+            train_augmentations=train_augmentations,
+            val_augmentations=val_augmentations,
+            test_augmentations=test_augmentations,
+            augmentations=augmentations,
             test_split_mode=test_split_mode,
             test_split_ratio=test_split_ratio,
             val_split_mode=val_split_mode,
diff --git a/src/anomalib/data/datamodules/image/datumaro.py b/src/anomalib/data/datamodules/image/datumaro.py
index fb37bc7ee7..b1a50533f7 100644
--- a/src/anomalib/data/datamodules/image/datumaro.py
+++ b/src/anomalib/data/datamodules/image/datumaro.py
@@ -8,6 +8,8 @@
 
 from pathlib import Path
 
+from torchvision.transforms.v2 import Transform
+
 from anomalib.data.datamodules.base import AnomalibDataModule
 from anomalib.data.datasets.image.datumaro import DatumaroDataset
 from anomalib.data.utils import Split, TestSplitMode, ValSplitMode
@@ -24,13 +26,15 @@ class Datumaro(AnomalibDataModule):
             Defaults to ``32``.
         num_workers (int): Number of workers for dataloaders.
             Defaults to ``8``.
-        image_size (tuple[int, int], optional): Size to which input images should be resized.
+        train_augmentations (Transform | None): Augmentations to apply to the training images.
             Defaults to ``None``.
-        transform (Transform, optional): Transforms that should be applied to the input images.
+        val_augmentations (Transform | None): Augmentations to apply to the validation images.
             Defaults to ``None``.
-        train_transform (Transform, optional): Transforms that should be applied to the input images during training.
+        test_augmentations (Transform | None): Augmentations to apply to the test images.
             Defaults to ``None``.
-        eval_transform (Transform, optional): Transforms that should be applied to the input images during evaluation.
+        augmentations (Transform | None): General augmentations to apply if stage-specific
+            augmentations are not provided.
+        image_size (tuple[int, int], optional): Size to which input images should be resized.
             Defaults to ``None``.
         test_split_mode (TestSplitMode): Setting that determines how the testing subset is obtained.
             Defaults to ``TestSplitMode.FROM_DIR``.
@@ -65,6 +69,10 @@ def __init__(
         train_batch_size: int = 32,
         eval_batch_size: int = 32,
         num_workers: int = 8,
+        train_augmentations: Transform | None = None,
+        val_augmentations: Transform | None = None,
+        test_augmentations: Transform | None = None,
+        augmentations: Transform | None = None,
         test_split_mode: TestSplitMode | str = TestSplitMode.FROM_DIR,
         test_split_ratio: float = 0.5,
         val_split_mode: ValSplitMode | str = ValSplitMode.FROM_TEST,
@@ -75,6 +83,10 @@ def __init__(
             train_batch_size=train_batch_size,
             eval_batch_size=eval_batch_size,
             num_workers=num_workers,
+            train_augmentations=train_augmentations,
+            val_augmentations=val_augmentations,
+            test_augmentations=test_augmentations,
+            augmentations=augmentations,
             val_split_mode=val_split_mode,
             val_split_ratio=val_split_ratio,
             test_split_mode=test_split_mode,
diff --git a/src/anomalib/data/datamodules/image/folder.py b/src/anomalib/data/datamodules/image/folder.py
index bd3c3fedd0..9802c55a31 100644
--- a/src/anomalib/data/datamodules/image/folder.py
+++ b/src/anomalib/data/datamodules/image/folder.py
@@ -9,6 +9,8 @@
 from collections.abc import Sequence
 from pathlib import Path
 
+from torchvision.transforms.v2 import Transform
+
 from anomalib.data.datamodules.base.image import AnomalibDataModule
 from anomalib.data.datasets.image.folder import FolderDataset
 from anomalib.data.utils import Split, TestSplitMode, ValSplitMode
@@ -42,6 +44,14 @@ class Folder(AnomalibDataModule):
             Defaults to ``32``.
         num_workers (int, optional): Number of workers.
             Defaults to ``8``.
+        train_augmentations (Transform | None): Augmentations to apply to the training images.
+            Defaults to ``None``.
+        val_augmentations (Transform | None): Augmentations to apply to the validation images.
+            Defaults to ``None``.
+        test_augmentations (Transform | None): Augmentations to apply to the test images.
+            Defaults to ``None``.
+        augmentations (Transform | None): General augmentations to apply if stage-specific
+            augmentations are not provided.
         test_split_mode (TestSplitMode): Setting that determines how the testing subset is obtained.
             Defaults to ``TestSplitMode.FROM_DIR``.
         test_split_ratio (float): Fraction of images from the train set that will be reserved for testing.
@@ -119,6 +129,10 @@ def __init__(
         train_batch_size: int = 32,
         eval_batch_size: int = 32,
         num_workers: int = 8,
+        train_augmentations: Transform | None = None,
+        val_augmentations: Transform | None = None,
+        test_augmentations: Transform | None = None,
+        augmentations: Transform | None = None,
         test_split_mode: TestSplitMode | str = TestSplitMode.FROM_DIR,
         test_split_ratio: float = 0.2,
         val_split_mode: ValSplitMode | str = ValSplitMode.FROM_TEST,
@@ -138,6 +152,10 @@ def __init__(
             train_batch_size=train_batch_size,
             eval_batch_size=eval_batch_size,
             num_workers=num_workers,
+            train_augmentations=train_augmentations,
+            val_augmentations=val_augmentations,
+            test_augmentations=test_augmentations,
+            augmentations=augmentations,
             test_split_mode=test_split_mode,
             test_split_ratio=test_split_ratio,
             val_split_mode=val_split_mode,
diff --git a/src/anomalib/data/datamodules/image/kolektor.py b/src/anomalib/data/datamodules/image/kolektor.py
index fe767c3a94..f592aa779d 100644
--- a/src/anomalib/data/datamodules/image/kolektor.py
+++ b/src/anomalib/data/datamodules/image/kolektor.py
@@ -20,6 +20,8 @@
 import logging
 from pathlib import Path
 
+from torchvision.transforms.v2 import Transform
+
 from anomalib.data.datamodules.base.image import AnomalibDataModule
 from anomalib.data.datasets.image.kolektor import KolektorDataset
 from anomalib.data.utils import DownloadInfo, Split, TestSplitMode, ValSplitMode, download_and_extract
@@ -45,6 +47,14 @@ class Kolektor(AnomalibDataModule):
             Defaults to ``32``.
         num_workers (int, optional): Number of workers.
             Defaults to ``8``.
+        train_augmentations (Transform | None): Augmentations to apply to the training images.
+            Defaults to ``None``.
+        val_augmentations (Transform | None): Augmentations to apply to the validation images.
+            Defaults to ``None``.
+        test_augmentations (Transform | None): Augmentations to apply to the test images.
+            Defaults to ``None``.
+        augmentations (Transform | None): General augmentations to apply if stage-specific
+            augmentations are not provided.
         test_split_mode (TestSplitMode): Setting that determines how the testing subset is obtained.
             Defaults to ``TestSplitMode.FROM_DIR``
         test_split_ratio (float): Fraction of images from the train set that will be reserved for testing.
@@ -63,6 +73,10 @@ def __init__(
         train_batch_size: int = 32,
         eval_batch_size: int = 32,
         num_workers: int = 8,
+        train_augmentations: Transform | None = None,
+        val_augmentations: Transform | None = None,
+        test_augmentations: Transform | None = None,
+        augmentations: Transform | None = None,
         test_split_mode: TestSplitMode | str = TestSplitMode.FROM_DIR,
         test_split_ratio: float = 0.2,
         val_split_mode: ValSplitMode | str = ValSplitMode.SAME_AS_TEST,
@@ -73,6 +87,10 @@ def __init__(
             train_batch_size=train_batch_size,
             eval_batch_size=eval_batch_size,
             num_workers=num_workers,
+            train_augmentations=train_augmentations,
+            val_augmentations=val_augmentations,
+            test_augmentations=test_augmentations,
+            augmentations=augmentations,
             test_split_mode=test_split_mode,
             test_split_ratio=test_split_ratio,
             val_split_mode=val_split_mode,
diff --git a/src/anomalib/data/datamodules/image/mvtec.py b/src/anomalib/data/datamodules/image/mvtec.py
index 9e7b2fce89..db84374252 100644
--- a/src/anomalib/data/datamodules/image/mvtec.py
+++ b/src/anomalib/data/datamodules/image/mvtec.py
@@ -28,6 +28,8 @@
 import logging
 from pathlib import Path
 
+from torchvision.transforms.v2 import Transform
+
 from anomalib.data.datamodules.base.image import AnomalibDataModule
 from anomalib.data.datasets.image.mvtec import MVTecDataset
 from anomalib.data.utils import DownloadInfo, Split, TestSplitMode, ValSplitMode, download_and_extract
@@ -57,6 +59,14 @@ class MVTec(AnomalibDataModule):
             Defaults to ``32``.
         num_workers (int, optional): Number of workers.
             Defaults to ``8``.
+        train_augmentations (Transform | None): Augmentations to apply to the training images.
+            Defaults to ``None``.
+        val_augmentations (Transform | None): Augmentations to apply to the validation images.
+            Defaults to ``None``.
+        test_augmentations (Transform | None): Augmentations to apply to the test images.
+            Defaults to ``None``.
+        augmentations (Transform | None): General augmentations to apply if stage-specific
+            augmentations are not provided.
         test_split_mode (TestSplitMode): Setting that determines how the testing subset is obtained.
             Defaults to ``TestSplitMode.FROM_DIR``.
         test_split_ratio (float): Fraction of images from the train set that will be reserved for testing.
@@ -105,6 +115,10 @@ def __init__(
         train_batch_size: int = 32,
         eval_batch_size: int = 32,
         num_workers: int = 8,
+        train_augmentations: Transform | None = None,
+        val_augmentations: Transform | None = None,
+        test_augmentations: Transform | None = None,
+        augmentations: Transform | None = None,
         test_split_mode: TestSplitMode | str = TestSplitMode.FROM_DIR,
         test_split_ratio: float = 0.2,
         val_split_mode: ValSplitMode | str = ValSplitMode.SAME_AS_TEST,
@@ -115,6 +129,10 @@ def __init__(
             train_batch_size=train_batch_size,
             eval_batch_size=eval_batch_size,
             num_workers=num_workers,
+            train_augmentations=train_augmentations,
+            val_augmentations=val_augmentations,
+            test_augmentations=test_augmentations,
+            augmentations=augmentations,
             test_split_mode=test_split_mode,
             test_split_ratio=test_split_ratio,
             val_split_mode=val_split_mode,
diff --git a/src/anomalib/data/datamodules/image/visa.py b/src/anomalib/data/datamodules/image/visa.py
index 553d0dcc03..0b8b54238e 100644
--- a/src/anomalib/data/datamodules/image/visa.py
+++ b/src/anomalib/data/datamodules/image/visa.py
@@ -28,6 +28,7 @@
 from pathlib import Path
 
 import cv2
+from torchvision.transforms.v2 import Transform
 
 from anomalib.data.datamodules.base.image import AnomalibDataModule
 from anomalib.data.datasets.image.visa import VisaDataset
@@ -56,6 +57,14 @@ class Visa(AnomalibDataModule):
             Defaults to ``32``.
         num_workers (int, optional): Number of workers.
             Defaults to ``8``.
+        train_augmentations (Transform | None): Augmentations to apply to the training images.
+            Defaults to ``None``.
+        val_augmentations (Transform | None): Augmentations to apply to the validation images.
+            Defaults to ``None``.
+        test_augmentations (Transform | None): Augmentations to apply to the test images.
+            Defaults to ``None``.
+        augmentations (Transform | None): General augmentations to apply if stage-specific
+            augmentations are not provided.
         test_split_mode (TestSplitMode): Setting that determines how the testing subset is obtained.
             Defaults to ``TestSplitMode.FROM_DIR``.
         test_split_ratio (float): Fraction of images from the train set that will be reserved for testing.
@@ -75,6 +84,10 @@ def __init__(
         train_batch_size: int = 32,
         eval_batch_size: int = 32,
         num_workers: int = 8,
+        train_augmentations: Transform | None = None,
+        val_augmentations: Transform | None = None,
+        test_augmentations: Transform | None = None,
+        augmentations: Transform | None = None,
         test_split_mode: TestSplitMode | str = TestSplitMode.FROM_DIR,
         test_split_ratio: float = 0.2,
         val_split_mode: ValSplitMode | str = ValSplitMode.SAME_AS_TEST,
@@ -85,6 +98,10 @@ def __init__(
             train_batch_size=train_batch_size,
             eval_batch_size=eval_batch_size,
             num_workers=num_workers,
+            train_augmentations=train_augmentations,
+            val_augmentations=val_augmentations,
+            test_augmentations=test_augmentations,
+            augmentations=augmentations,
             test_split_mode=test_split_mode,
             test_split_ratio=test_split_ratio,
             val_split_mode=val_split_mode,
diff --git a/src/anomalib/data/datamodules/video/avenue.py b/src/anomalib/data/datamodules/video/avenue.py
index 446b4b6c37..deb421e778 100644
--- a/src/anomalib/data/datamodules/video/avenue.py
+++ b/src/anomalib/data/datamodules/video/avenue.py
@@ -21,6 +21,7 @@
 
 import cv2
 import scipy.io
+from torchvision.transforms.v2 import Transform
 
 from anomalib.data.datamodules.base.video import AnomalibVideoDataModule
 from anomalib.data.datasets.base.video import VideoTargetFrame
@@ -61,6 +62,14 @@ class Avenue(AnomalibVideoDataModule):
             Defaults to ``32``.
         num_workers (int, optional): Number of workers.
             Defaults to ``8``.
+        train_augmentations (Transform | None): Augmentations to apply to the training images.
+            Defaults to ``None``.
+        val_augmentations (Transform | None): Augmentations to apply to the validation images.
+            Defaults to ``None``.
+        test_augmentations (Transform | None): Augmentations to apply to the test images.
+            Defaults to ``None``.
+        augmentations (Transform | None): General augmentations to apply if stage-specific
+            augmentations are not provided.
         val_split_mode (ValSplitMode): Setting that determines how the validation subset is obtained.
             Defaults to ``ValSplitMode.FROM_TEST``.
         val_split_ratio (float): Fraction of train or test images that will be reserved for validation.
@@ -124,6 +133,10 @@ def __init__(
         train_batch_size: int = 32,
         eval_batch_size: int = 32,
         num_workers: int = 8,
+        train_augmentations: Transform | None = None,
+        val_augmentations: Transform | None = None,
+        test_augmentations: Transform | None = None,
+        augmentations: Transform | None = None,
         val_split_mode: ValSplitMode | str = ValSplitMode.SAME_AS_TEST,
         val_split_ratio: float = 0.5,
         seed: int | None = None,
@@ -132,6 +145,10 @@ def __init__(
             train_batch_size=train_batch_size,
             eval_batch_size=eval_batch_size,
             num_workers=num_workers,
+            train_augmentations=train_augmentations,
+            val_augmentations=val_augmentations,
+            test_augmentations=test_augmentations,
+            augmentations=augmentations,
             val_split_mode=val_split_mode,
             val_split_ratio=val_split_ratio,
             seed=seed,
diff --git a/src/anomalib/data/datamodules/video/shanghaitech.py b/src/anomalib/data/datamodules/video/shanghaitech.py
index f5e5cd0036..a04c362add 100644
--- a/src/anomalib/data/datamodules/video/shanghaitech.py
+++ b/src/anomalib/data/datamodules/video/shanghaitech.py
@@ -20,6 +20,8 @@
 from pathlib import Path
 from shutil import move
 
+from torchvision.transforms.v2 import Transform
+
 from anomalib.data.datamodules.base.video import AnomalibVideoDataModule
 from anomalib.data.datasets.base.video import VideoTargetFrame
 from anomalib.data.datasets.video.shanghaitech import ShanghaiTechDataset
@@ -47,6 +49,14 @@ class ShanghaiTech(AnomalibVideoDataModule):
         train_batch_size (int, optional): Training batch size. Defaults to 32.
         eval_batch_size (int, optional): Test batch size. Defaults to 32.
         num_workers (int, optional): Number of workers. Defaults to 8.
+        train_augmentations (Transform | None): Augmentations to apply to the training images.
+            Defaults to ``None``.
+        val_augmentations (Transform | None): Augmentations to apply to the validation images.
+            Defaults to ``None``.
+        test_augmentations (Transform | None): Augmentations to apply to the test images.
+            Defaults to ``None``.
+        augmentations (Transform | None): General augmentations to apply if stage-specific
+            augmentations are not provided.
         val_split_mode (ValSplitMode): Setting that determines how the validation subset is obtained.
         val_split_ratio (float): Fraction of train or test images that will be reserved for validation.
         seed (int | None, optional): Seed which may be set to a fixed value for reproducibility.
@@ -62,6 +72,10 @@ def __init__(
         train_batch_size: int = 32,
         eval_batch_size: int = 32,
         num_workers: int = 8,
+        train_augmentations: Transform | None = None,
+        val_augmentations: Transform | None = None,
+        test_augmentations: Transform | None = None,
+        augmentations: Transform | None = None,
         val_split_mode: ValSplitMode = ValSplitMode.SAME_AS_TEST,
         val_split_ratio: float = 0.5,
         seed: int | None = None,
@@ -70,6 +84,10 @@ def __init__(
             train_batch_size=train_batch_size,
             eval_batch_size=eval_batch_size,
             num_workers=num_workers,
+            train_augmentations=train_augmentations,
+            val_augmentations=val_augmentations,
+            test_augmentations=test_augmentations,
+            augmentations=augmentations,
             val_split_mode=val_split_mode,
             val_split_ratio=val_split_ratio,
             seed=seed,
diff --git a/src/anomalib/data/datamodules/video/ucsd_ped.py b/src/anomalib/data/datamodules/video/ucsd_ped.py
index e08bfd1ca6..1ce9838edf 100644
--- a/src/anomalib/data/datamodules/video/ucsd_ped.py
+++ b/src/anomalib/data/datamodules/video/ucsd_ped.py
@@ -7,6 +7,8 @@
 from pathlib import Path
 from shutil import move
 
+from torchvision.transforms.v2 import Transform
+
 from anomalib.data.datamodules.base.video import AnomalibVideoDataModule
 from anomalib.data.datasets.base.video import VideoTargetFrame
 from anomalib.data.datasets.video.ucsd_ped import UCSDpedDataset
@@ -33,6 +35,14 @@ class UCSDped(AnomalibVideoDataModule):
         train_batch_size (int, optional): Training batch size. Defaults to 32.
         eval_batch_size (int, optional): Test batch size. Defaults to 32.
         num_workers (int, optional): Number of workers. Defaults to 8.
+        train_augmentations (Transform | None): Augmentations to apply to the training images.
+            Defaults to ``None``.
+        val_augmentations (Transform | None): Augmentations to apply to the validation images.
+            Defaults to ``None``.
+        test_augmentations (Transform | None): Augmentations to apply to the test images.
+            Defaults to ``None``.
+        augmentations (Transform | None): General augmentations to apply if stage-specific
+            augmentations are not provided.
         val_split_mode (ValSplitMode): Setting that determines how the validation subset is obtained.
         val_split_ratio (float): Fraction of train or test images that will be reserved for validation.
         seed (int | None, optional): Seed which may be set to a fixed value for reproducibility.
@@ -48,6 +58,10 @@ def __init__(
         train_batch_size: int = 8,
         eval_batch_size: int = 8,
         num_workers: int = 8,
+        train_augmentations: Transform | None = None,
+        val_augmentations: Transform | None = None,
+        test_augmentations: Transform | None = None,
+        augmentations: Transform | None = None,
         val_split_mode: ValSplitMode = ValSplitMode.SAME_AS_TEST,
         val_split_ratio: float = 0.5,
         seed: int | None = None,
@@ -56,6 +70,10 @@ def __init__(
             train_batch_size=train_batch_size,
             eval_batch_size=eval_batch_size,
             num_workers=num_workers,
+            train_augmentations=train_augmentations,
+            val_augmentations=val_augmentations,
+            test_augmentations=test_augmentations,
+            augmentations=augmentations,
             val_split_mode=val_split_mode,
             val_split_ratio=val_split_ratio,
             seed=seed,
diff --git a/src/anomalib/data/datasets/base/depth.py b/src/anomalib/data/datasets/base/depth.py
index 5dd4683b6c..83ec1c4774 100644
--- a/src/anomalib/data/datasets/base/depth.py
+++ b/src/anomalib/data/datasets/base/depth.py
@@ -23,14 +23,14 @@ class AnomalibDepthDataset(AnomalibDataset, ABC):
     """Base depth anomalib dataset class.
 
     Args:
-        transform (Transform, optional): Transforms that should be applied to the input images.
+        augmentations (Transform, optional): Augmentations that should be applied to the input images.
             Defaults to ``None``.
     """
 
-    def __init__(self, transform: Transform | None = None) -> None:
-        super().__init__(transform)
+    def __init__(self, augmentations: Transform | None = None) -> None:
+        super().__init__(augmentations=augmentations)
 
-        self.transform = transform
+        self.augmentations = augmentations
 
     def __getitem__(self, index: int) -> DepthItem:
         """Return rgb image, depth image and mask.
@@ -52,7 +52,7 @@ def __getitem__(self, index: int) -> DepthItem:
 
         if self.task == TaskType.CLASSIFICATION:
             item["image"], item["depth_image"] = (
-                self.transform(image, depth_image) if self.transform else (image, depth_image)
+                self.augmentations(image, depth_image) if self.augmentations else (image, depth_image)
             )
         elif self.task == TaskType.SEGMENTATION:
             # Only Anomalous (1) images have masks in anomaly datasets
@@ -63,7 +63,7 @@ def __getitem__(self, index: int) -> DepthItem:
                 else Mask(to_tensor(Image.open(mask_path)).squeeze())
             )
             item["image"], item["depth_image"], item["mask"] = (
-                self.transform(image, depth_image, mask) if self.transform else (image, depth_image, mask)
+                self.augmentations(image, depth_image, mask) if self.augmentations else (image, depth_image, mask)
             )
             item["mask_path"] = mask_path
 
diff --git a/src/anomalib/data/datasets/base/image.py b/src/anomalib/data/datasets/base/image.py
index 9bc8c45e74..42e6c79294 100644
--- a/src/anomalib/data/datasets/base/image.py
+++ b/src/anomalib/data/datasets/base/image.py
@@ -52,13 +52,13 @@ class AnomalibDataset(Dataset, ABC):
 
     Args:
         task (str): Task type, either 'classification' or 'segmentation'
-        transform (Transform, optional): Transforms that should be applied to the input images.
+        augmentations (Transform, optional): Augmentations that should be applied to the input images.
             Defaults to ``None``.
     """
 
-    def __init__(self, transform: Transform | None = None) -> None:
+    def __init__(self, augmentations: Transform | None = None) -> None:
         super().__init__()
-        self.transform = transform
+        self.augmentations = augmentations
         self._samples: DataFrame | None = None
         self._category: str | None = None
 
@@ -167,7 +167,7 @@ def __getitem__(self, index: int) -> DatasetItem:
         item = {"image_path": image_path, "gt_label": label_index}
 
         if self.task == TaskType.CLASSIFICATION:
-            item["image"] = self.transform(image) if self.transform else image
+            item["image"] = self.augmentations(image) if self.augmentations else image
         elif self.task == TaskType.SEGMENTATION:
             # Only Anomalous (1) images have masks in anomaly datasets
             # Therefore, create empty mask for Normal (0) images.
@@ -176,7 +176,7 @@ def __getitem__(self, index: int) -> DatasetItem:
                 if label_index == LabelName.NORMAL
                 else read_mask(mask_path, as_tensor=True)
             )
-            item["image"], item["gt_mask"] = self.transform(image, mask) if self.transform else (image, mask)
+            item["image"], item["gt_mask"] = self.augmentations(image, mask) if self.augmentations else (image, mask)
 
         else:
             msg = f"Unknown task type: {self.task}"
diff --git a/src/anomalib/data/datasets/base/video.py b/src/anomalib/data/datasets/base/video.py
index 4b8366aae4..28e1796767 100644
--- a/src/anomalib/data/datasets/base/video.py
+++ b/src/anomalib/data/datasets/base/video.py
@@ -37,7 +37,7 @@ class AnomalibVideoDataset(AnomalibDataset, ABC):
     Args:
         clip_length_in_frames (int): Number of video frames in each clip.
         frames_between_clips (int): Number of frames between each consecutive video clip.
-        transform (Transform, optional): Transforms that should be applied to the input clips.
+        augmentations (Transform, optional): Augmentations that should be applied to the input clips.
             Defaults to ``None``.
         target_frame (VideoTargetFrame): Specifies the target frame in the video clip, used for ground truth retrieval.
             Defaults to ``VideoTargetFrame.LAST``.
@@ -47,14 +47,14 @@ def __init__(
         self,
         clip_length_in_frames: int,
         frames_between_clips: int,
-        transform: Transform | None = None,
+        augmentations: Transform | None = None,
         target_frame: VideoTargetFrame = VideoTargetFrame.LAST,
     ) -> None:
-        super().__init__(transform)
+        super().__init__(augmentations=augmentations)
 
         self.clip_length_in_frames = clip_length_in_frames
         self.frames_between_clips = frames_between_clips
-        self.transform = transform
+        self.augmentations = augmentations
 
         self.indexer: ClipsIndexer | None = None
         self.indexer_cls: Callable | None = None
@@ -152,11 +152,11 @@ def __getitem__(self, index: int) -> VideoItem:
 
         # apply transforms
         if item.gt_mask is not None:
-            if self.transform:
-                item.image, item.gt_mask = self.transform(item.image, Mask(item.gt_mask))
+            if self.augmentations:
+                item.image, item.gt_mask = self.augmentations(item.image, Mask(item.gt_mask))
             item.gt_label = torch.Tensor([1 in frame for frame in item.gt_mask]).int().squeeze(0)
-        elif self.transform:
-            item.image = self.transform(item.image)
+        elif self.augmentations:
+            item.image = self.augmentations(item.image)
 
         # squeeze temporal dimensions in case clip length is 1
         item.image = item.image.squeeze(0)
diff --git a/src/anomalib/data/datasets/depth/folder_3d.py b/src/anomalib/data/datasets/depth/folder_3d.py
index 0e5247c7bc..a1510de834 100644
--- a/src/anomalib/data/datasets/depth/folder_3d.py
+++ b/src/anomalib/data/datasets/depth/folder_3d.py
@@ -43,7 +43,7 @@ class Folder3DDataset(AnomalibDepthDataset):
         normal_test_depth_dir (str | Path | None, optional): Path to the directory containing
             normal depth images for the test dataset. Normal test images will be a split of `normal_dir`
             if `None`. Defaults to ``None``.
-        transform (Transform, optional): Transforms that should be applied to the input images.
+        augmentations (Transform, optional): Augmentations that should be applied to the input images.
             Defaults to ``None``.
         split (str | Split | None): Fixed subset split that follows from folder structure on file system.
             Choose from [Split.FULL, Split.TRAIN, Split.TEST]
@@ -63,11 +63,11 @@ def __init__(
         normal_depth_dir: str | Path | None = None,
         abnormal_depth_dir: str | Path | None = None,
         normal_test_depth_dir: str | Path | None = None,
-        transform: Transform | None = None,
+        augmentations: Transform | None = None,
         split: str | Split | None = None,
         extensions: tuple[str, ...] | None = None,
     ) -> None:
-        super().__init__(transform)
+        super().__init__(augmentations=augmentations)
 
         self._name = name
         self.split = split
diff --git a/src/anomalib/data/datasets/depth/mvtec_3d.py b/src/anomalib/data/datasets/depth/mvtec_3d.py
index 6dd8ed3752..01045f96e4 100644
--- a/src/anomalib/data/datasets/depth/mvtec_3d.py
+++ b/src/anomalib/data/datasets/depth/mvtec_3d.py
@@ -41,7 +41,7 @@ class MVTec3DDataset(AnomalibDepthDataset):
         Defaults to ``"./datasets/MVTec3D"``.
         category (str): Sub-category of the dataset, e.g. 'bagel'
             Defaults to ``"bagel"``.
-        transform (Transform, optional): Transforms that should be applied to the input images.
+        augmentations (Transform, optional): Augmentations that should be applied to the input images.
             Defaults to ``None``.
         split (str | Split | None): Split of the dataset, usually Split.TRAIN or Split.TEST
             Defaults to ``None``.
     """
@@ -51,10 +51,10 @@ def __init__(
         self,
         root: Path | str = "./datasets/MVTec3D",
         category: str = "bagel",
-        transform: Transform | None = None,
+        augmentations: Transform | None = None,
         split: str | Split | None = None,
     ) -> None:
-        super().__init__(transform=transform)
+        super().__init__(augmentations=augmentations)
 
         self.root_category = Path(root) / Path(category)
         self.split = split
diff --git a/src/anomalib/data/datasets/image/btech.py b/src/anomalib/data/datasets/image/btech.py
index 3078c99e12..b5ed37afe7 100644
--- a/src/anomalib/data/datasets/image/btech.py
+++ b/src/anomalib/data/datasets/image/btech.py
@@ -65,10 +65,10 @@ def __init__(
         self,
         root: str | Path,
         category: str,
-        transform: Transform | None = None,
+        augmentations: Transform | None = None,
         split: str | Split | None = None,
     ) -> None:
-        super().__init__(transform)
+        super().__init__(augmentations=augmentations)
 
         self.root_category = Path(root) / category
         self.split = split
diff --git a/src/anomalib/data/datasets/image/datumaro.py b/src/anomalib/data/datasets/image/datumaro.py
index 9335f0a4b4..6d4438a231 100644
--- a/src/anomalib/data/datasets/image/datumaro.py
+++ b/src/anomalib/data/datasets/image/datumaro.py
@@ -119,9 +119,9 @@ class DatumaroDataset(AnomalibDataset):
     def __init__(
         self,
         root: str | Path,
-        transform: Transform | None = None,
+        augmentations: Transform | None = None,
         split: str | Split | None = None,
     ) -> None:
-        super().__init__(transform)
+        super().__init__(augmentations=augmentations)
         self.split = split
         self.samples = make_datumaro_dataset(root, split)
diff --git a/src/anomalib/data/datasets/image/folder.py b/src/anomalib/data/datasets/image/folder.py
index 08e01d85c2..d9b3be6e36 100644
--- a/src/anomalib/data/datasets/image/folder.py
+++ b/src/anomalib/data/datasets/image/folder.py
@@ -68,7 +68,7 @@ def __init__(
         self,
         name: str,
         normal_dir: str | Path | Sequence[str | Path],
-        transform: Transform | None = None,
+        augmentations: Transform | None = None,
         root: str | Path | None = None,
         abnormal_dir: str | Path | Sequence[str | Path] | None = None,
         normal_test_dir: str | Path | Sequence[str | Path] | None = None,
@@ -76,7 +76,7 @@ def __init__(
         split: str | Split | None = None,
         extensions: tuple[str, ...] | None = None,
     ) -> None:
-        super().__init__(transform)
+        super().__init__(augmentations=augmentations)
 
         self._name = name
         self.split = split
diff --git a/src/anomalib/data/datasets/image/kolektor.py b/src/anomalib/data/datasets/image/kolektor.py
index 410d2191cf..062e0f7978 100644
--- a/src/anomalib/data/datasets/image/kolektor.py
+++ b/src/anomalib/data/datasets/image/kolektor.py
@@ -46,10 +46,10 @@ class KolektorDataset(AnomalibDataset):
     def __init__(
         self,
         root: Path | str = "./datasets/kolektor",
-        transform: Transform | None = None,
+        augmentations: Transform | None = None,
         split: str | Split | None = None,
     ) -> None:
-        super().__init__(transform=transform)
+        super().__init__(augmentations=augmentations)
 
         self.root = root
         self.split = split
diff --git a/src/anomalib/data/datasets/image/mvtec.py b/src/anomalib/data/datasets/image/mvtec.py
index c07cdf34e4..cb133fd605 100644
--- a/src/anomalib/data/datasets/image/mvtec.py
+++ b/src/anomalib/data/datasets/image/mvtec.py
@@ -63,7 +63,7 @@ class MVTecDataset(AnomalibDataset):
         Defaults to ``./datasets/MVTec``.
         category (str): Sub-category of the dataset, e.g. 'bottle'
             Defaults to ``bottle``.
-        transform (Transform, optional): Transforms that should be applied to the input images.
+        augmentations (Transform, optional): Augmentations that should be applied to the input images.
             Defaults to ``None``.
         split (str | Split | None): Split of the dataset, usually Split.TRAIN or Split.TEST
             Defaults to ``None``.
@@ -107,10 +107,10 @@ def __init__(
         self,
         root: Path | str = "./datasets/MVTec",
         category: str = "bottle",
-        transform: Transform | None = None,
+        augmentations: Transform | None = None,
         split: str | Split | None = None,
     ) -> None:
-        super().__init__(transform=transform)
+        super().__init__(augmentations=augmentations)
 
         self.root_category = Path(root) / Path(category)
         self.category = category
diff --git a/src/anomalib/data/datasets/image/visa.py b/src/anomalib/data/datasets/image/visa.py
index 70ee5352aa..05f1d6ffa8 100644
--- a/src/anomalib/data/datasets/image/visa.py
+++ b/src/anomalib/data/datasets/image/visa.py
@@ -82,10 +82,10 @@ def __init__(
         self,
         root: str | Path,
         category: str,
-        transform: Transform | None = None,
+        augmentations: Transform | None = None,
         split: str | Split | None = None,
     ) -> None:
-        super().__init__(transform=transform)
+        super().__init__(augmentations=augmentations)
 
         self.root_category = Path(root) / category
         self.split = split
diff --git a/src/anomalib/data/datasets/video/avenue.py b/src/anomalib/data/datasets/video/avenue.py
index 03c07404a5..7a91ef90f5 100644
--- a/src/anomalib/data/datasets/video/avenue.py
+++ b/src/anomalib/data/datasets/video/avenue.py
@@ -45,7 +45,7 @@ class AvenueDataset(AnomalibVideoDataset):
             Defaults to ``1``.
         target_frame (VideoTargetFrame): Specifies the target frame in the video clip, used for ground truth retrieval.
             Defaults to ``VideoTargetFrame.LAST``.
-        transform (Transform, optional): Transforms that should be applied to the input images.
+        augmentations (Transform, optional): Augmentations that should be applied to the input images.
             Defaults to ``None``.
 
     Examples:
@@ -92,14 +92,14 @@ def __init__(
         gt_dir: Path | str = "./datasets/avenue/ground_truth_demo",
         clip_length_in_frames: int = 2,
         frames_between_clips: int = 1,
-        transform: Transform | None = None,
+        augmentations: Transform | None = None,
         target_frame: VideoTargetFrame = VideoTargetFrame.LAST,
     ) -> None:
         super().__init__(
             clip_length_in_frames=clip_length_in_frames,
             frames_between_clips=frames_between_clips,
             target_frame=target_frame,
-            transform=transform,
+            augmentations=augmentations,
         )
 
         self.root = root if isinstance(root, Path) else Path(root)
diff --git a/src/anomalib/data/datasets/video/shanghaitech.py b/src/anomalib/data/datasets/video/shanghaitech.py
index 424a13e9e6..04b0b4108d 100644
--- a/src/anomalib/data/datasets/video/shanghaitech.py
+++ b/src/anomalib/data/datasets/video/shanghaitech.py
@@ -40,7 +40,7 @@ class ShanghaiTechDataset(AnomalibVideoDataset):
         clip_length_in_frames (int, optional): Number of video frames in each clip.
         frames_between_clips (int, optional): Number of frames between each consecutive video clip.
         target_frame (VideoTargetFrame): Specifies the target frame in the video clip, used for ground truth retrieval.
-        transform (Transform, optional): Transforms that should be applied to the input images.
+        augmentations (Transform, optional): Augmentations that should be applied to the input images.
             Defaults to ``None``.
""" @@ -52,13 +52,13 @@ def __init__( clip_length_in_frames: int = 2, frames_between_clips: int = 1, target_frame: VideoTargetFrame = VideoTargetFrame.LAST, - transform: Transform | None = None, + augmentations: Transform | None = None, ) -> None: super().__init__( clip_length_in_frames=clip_length_in_frames, frames_between_clips=frames_between_clips, target_frame=target_frame, - transform=transform, + augmentations=augmentations, ) self.root = Path(root) diff --git a/src/anomalib/data/datasets/video/ucsd_ped.py b/src/anomalib/data/datasets/video/ucsd_ped.py index 5a619be3f1..7835e75b10 100644 --- a/src/anomalib/data/datasets/video/ucsd_ped.py +++ b/src/anomalib/data/datasets/video/ucsd_ped.py @@ -31,7 +31,7 @@ class UCSDpedDataset(AnomalibVideoDataset): clip_length_in_frames (int, optional): Number of video frames in each clip. frames_between_clips (int, optional): Number of frames between each consecutive video clip. target_frame (VideoTargetFrame): Specifies the target frame in the video clip, used for ground truth retrieval. - transform (Transform, optional): Transforms that should be applied to the input images. + augmentations (Transform, optional): Augmentations that should be applied to the input images. Defaults to ``None``. """ @@ -43,13 +43,13 @@ def __init__( clip_length_in_frames: int = 2, frames_between_clips: int = 10, target_frame: VideoTargetFrame = VideoTargetFrame.LAST, - transform: Transform | None = None, + augmentations: Transform | None = None, ) -> None: super().__init__( clip_length_in_frames=clip_length_in_frames, frames_between_clips=frames_between_clips, target_frame=target_frame, - transform=transform, + augmentations=augmentations, ) self.root_category = Path(root) / category diff --git a/src/anomalib/data/utils/synthetic.py b/src/anomalib/data/utils/synthetic.py index c4b52d5b35..51dcd9372f 100644 --- a/src/anomalib/data/utils/synthetic.py +++ b/src/anomalib/data/utils/synthetic.py @@ -16,7 +16,7 @@ import cv2 import pandas as pd from pandas import DataFrame, Series -from torchvision.transforms.v2 import Compose +from torchvision.transforms.v2 import Transform from anomalib.data.datasets.base.image import AnomalibDataset from anomalib.data.utils import Split, read_image @@ -113,12 +113,12 @@ class SyntheticAnomalyDataset(AnomalibDataset): """Dataset which reads synthetically generated anomalous images from a temporary folder. Args: - transform (A.Compose): Transform object describing the transforms that are applied to the inputs. + augmentations (Transform | None): Transform object describing the input data augmentations. source_samples (DataFrame): Normal samples to which the anomalous augmentations will be applied. """ - def __init__(self, transform: Compose, source_samples: DataFrame) -> None: - super().__init__(transform) + def __init__(self, augmentations: Transform | None, source_samples: DataFrame) -> None: + super().__init__(augmentations=augmentations) self.source_samples = source_samples @@ -145,7 +145,7 @@ def from_dataset(cls: type["SyntheticAnomalyDataset"], dataset: AnomalibDataset) dataset (AnomalibDataset): Dataset consisting of only normal images that will be converrted to a synthetic anomalous dataset with a 50/50 normal anomalous split. 
""" - return cls(transform=dataset.transform, source_samples=dataset.samples) + return cls(augmentations=dataset.augmentations, source_samples=dataset.samples) def __copy__(self) -> "SyntheticAnomalyDataset": """Return a shallow copy of the dataset object and prevents cleanup when original object is deleted.""" diff --git a/tests/unit/data/datamodule/depth/test_folder_3d.py b/tests/unit/data/datamodule/depth/test_folder_3d.py index 71adef7b12..9deec32b9c 100644 --- a/tests/unit/data/datamodule/depth/test_folder_3d.py +++ b/tests/unit/data/datamodule/depth/test_folder_3d.py @@ -6,6 +6,7 @@ from pathlib import Path import pytest +from torchvision.transforms.v2 import Resize from anomalib.data import Folder3D from tests.unit.data.datamodule.base.depth import _TestAnomalibDepthDatamodule @@ -31,6 +32,7 @@ def datamodule(dataset_path: Path) -> Folder3D: train_batch_size=4, eval_batch_size=4, num_workers=0, + augmentations=Resize((256, 256)), ) _datamodule.prepare_data() _datamodule.setup() diff --git a/tests/unit/data/datamodule/depth/test_mvtec_3d.py b/tests/unit/data/datamodule/depth/test_mvtec_3d.py index 2a90822763..f07266c56a 100644 --- a/tests/unit/data/datamodule/depth/test_mvtec_3d.py +++ b/tests/unit/data/datamodule/depth/test_mvtec_3d.py @@ -6,6 +6,7 @@ from pathlib import Path import pytest +from torchvision.transforms.v2 import Resize from anomalib.data import MVTec3D from tests.unit.data.datamodule.base.depth import _TestAnomalibDepthDatamodule @@ -24,6 +25,7 @@ def datamodule(dataset_path: Path) -> MVTec3D: train_batch_size=4, eval_batch_size=4, num_workers=0, + augmentations=Resize((256, 256)), ) _datamodule.prepare_data() _datamodule.setup() diff --git a/tests/unit/data/datamodule/image/test_btech.py b/tests/unit/data/datamodule/image/test_btech.py index fb559641c1..6dcb7969a5 100644 --- a/tests/unit/data/datamodule/image/test_btech.py +++ b/tests/unit/data/datamodule/image/test_btech.py @@ -6,6 +6,7 @@ from pathlib import Path import pytest +from torchvision.transforms.v2 import Resize from anomalib.data import BTech from tests.unit.data.datamodule.base.image import _TestAnomalibImageDatamodule @@ -23,6 +24,7 @@ def datamodule(dataset_path: Path) -> BTech: category="dummy", train_batch_size=4, eval_batch_size=4, + augmentations=Resize((256, 256)), ) _datamodule.prepare_data() diff --git a/tests/unit/data/datamodule/image/test_datumaro.py b/tests/unit/data/datamodule/image/test_datumaro.py index e10009a71c..9b527bd864 100644 --- a/tests/unit/data/datamodule/image/test_datumaro.py +++ b/tests/unit/data/datamodule/image/test_datumaro.py @@ -6,6 +6,7 @@ from pathlib import Path import pytest +from torchvision.transforms.v2 import Resize from anomalib.data import Datumaro from tests.unit.data.datamodule.base.image import _TestAnomalibImageDatamodule @@ -22,6 +23,7 @@ def datamodule(dataset_path: Path) -> Datumaro: root=dataset_path / "datumaro", train_batch_size=4, eval_batch_size=4, + augmentations=Resize((256, 256)), ) _datamodule.setup() diff --git a/tests/unit/data/datamodule/image/test_folder.py b/tests/unit/data/datamodule/image/test_folder.py index e564b5a5e3..9c32239008 100644 --- a/tests/unit/data/datamodule/image/test_folder.py +++ b/tests/unit/data/datamodule/image/test_folder.py @@ -6,6 +6,7 @@ from pathlib import Path import pytest +from torchvision.transforms.v2 import Resize from anomalib.data import Folder from tests.unit.data.datamodule.base.image import _TestAnomalibImageDatamodule @@ -35,6 +36,7 @@ def datamodule(dataset_path: Path) -> Folder: train_batch_size=4, 
eval_batch_size=4, num_workers=0, + augmentations=Resize((256, 256)), ) _datamodule.setup() diff --git a/tests/unit/data/datamodule/image/test_kolektor.py b/tests/unit/data/datamodule/image/test_kolektor.py index 3d6b896d50..b0456c05fd 100644 --- a/tests/unit/data/datamodule/image/test_kolektor.py +++ b/tests/unit/data/datamodule/image/test_kolektor.py @@ -6,6 +6,7 @@ from pathlib import Path import pytest +from torchvision.transforms.v2 import Resize from anomalib.data import Kolektor from tests.unit.data.datamodule.base.image import _TestAnomalibImageDatamodule @@ -22,6 +23,7 @@ def datamodule(dataset_path: Path) -> Kolektor: root=dataset_path / "kolektor", train_batch_size=4, eval_batch_size=4, + augmentations=Resize((256, 256)), ) _datamodule.prepare_data() diff --git a/tests/unit/data/datamodule/image/test_mvtec.py b/tests/unit/data/datamodule/image/test_mvtec.py index 8f40c9e38a..537fa9c4e0 100644 --- a/tests/unit/data/datamodule/image/test_mvtec.py +++ b/tests/unit/data/datamodule/image/test_mvtec.py @@ -6,6 +6,7 @@ from pathlib import Path import pytest +from torchvision.transforms.v2 import Resize from anomalib.data import MVTec from tests.unit.data.datamodule.base.image import _TestAnomalibImageDatamodule @@ -23,6 +24,7 @@ def datamodule(dataset_path: Path) -> MVTec: category="dummy", train_batch_size=4, eval_batch_size=4, + augmentations=Resize((256, 256)), ) _datamodule.prepare_data() _datamodule.setup() diff --git a/tests/unit/data/datamodule/image/test_visa.py b/tests/unit/data/datamodule/image/test_visa.py index b24b1d42c0..5f3968b531 100644 --- a/tests/unit/data/datamodule/image/test_visa.py +++ b/tests/unit/data/datamodule/image/test_visa.py @@ -6,6 +6,7 @@ from pathlib import Path import pytest +from torchvision.transforms.v2 import Resize from anomalib.data import Visa from tests.unit.data.datamodule.base.image import _TestAnomalibImageDatamodule @@ -24,6 +25,7 @@ def datamodule(dataset_path: Path) -> Visa: train_batch_size=4, eval_batch_size=4, num_workers=0, + augmentations=Resize((256, 256)), ) _datamodule.prepare_data() _datamodule.setup() diff --git a/tests/unit/data/datamodule/video/test_avenue.py b/tests/unit/data/datamodule/video/test_avenue.py index f63e240e15..e7c3e8546e 100644 --- a/tests/unit/data/datamodule/video/test_avenue.py +++ b/tests/unit/data/datamodule/video/test_avenue.py @@ -6,6 +6,7 @@ from pathlib import Path import pytest +from torchvision.transforms.v2 import Resize from anomalib.data import Avenue from tests.unit.data.datamodule.base.video import _TestAnomalibVideoDatamodule @@ -31,6 +32,7 @@ def datamodule(dataset_path: Path, clip_length_in_frames: int) -> Avenue: num_workers=0, train_batch_size=4, eval_batch_size=4, + augmentations=Resize((256, 256)), ) _datamodule.prepare_data() diff --git a/tests/unit/data/datamodule/video/test_shanghaitech.py b/tests/unit/data/datamodule/video/test_shanghaitech.py index e1dc1ba3c3..dfee8ca519 100644 --- a/tests/unit/data/datamodule/video/test_shanghaitech.py +++ b/tests/unit/data/datamodule/video/test_shanghaitech.py @@ -6,6 +6,7 @@ from pathlib import Path import pytest +from torchvision.transforms.v2 import Resize from anomalib.data import ShanghaiTech from tests.unit.data.datamodule.base.video import _TestAnomalibVideoDatamodule @@ -31,6 +32,7 @@ def datamodule(dataset_path: Path, clip_length_in_frames: int) -> ShanghaiTech: train_batch_size=4, eval_batch_size=4, num_workers=0, + augmentations=Resize((256, 256)), ) _datamodule.prepare_data() diff --git a/tests/unit/data/datamodule/video/test_ucsdped.py 
b/tests/unit/data/datamodule/video/test_ucsdped.py index 3da6c076d1..f55347c3f2 100644 --- a/tests/unit/data/datamodule/video/test_ucsdped.py +++ b/tests/unit/data/datamodule/video/test_ucsdped.py @@ -6,6 +6,7 @@ from pathlib import Path import pytest +from torchvision.transforms.v2 import Resize from anomalib.data import UCSDped from tests.unit.data.datamodule.base.video import _TestAnomalibVideoDatamodule @@ -31,6 +32,7 @@ def datamodule(dataset_path: Path, clip_length_in_frames: int) -> UCSDped: train_batch_size=4, eval_batch_size=4, num_workers=0, + augmentations=Resize((256, 256)), ) _datamodule.prepare_data() _datamodule.setup() diff --git a/tests/unit/data/utils/test_synthetic.py b/tests/unit/data/utils/test_synthetic.py index 09cb77e777..5bab62f0a9 100644 --- a/tests/unit/data/utils/test_synthetic.py +++ b/tests/unit/data/utils/test_synthetic.py @@ -7,6 +7,7 @@ from pathlib import Path import pytest +from torchvision.transforms.v2 import Resize from anomalib.data.datasets.image.folder import FolderDataset from anomalib.data.utils.synthetic import SyntheticAnomalyDataset @@ -23,6 +24,7 @@ def folder_dataset(dataset_path: Path) -> FolderDataset: normal_test_dir="test/good", mask_dir="ground_truth/bad", split="train", + augmentations=Resize((256, 256)), ) @@ -36,7 +38,7 @@ def synthetic_dataset(folder_dataset: FolderDataset) -> SyntheticAnomalyDataset: def synthetic_dataset_from_samples(folder_dataset: FolderDataset) -> SyntheticAnomalyDataset: """Fixture that returns a SyntheticAnomalyDataset instance.""" return SyntheticAnomalyDataset( - transform=folder_dataset.transform, + augmentations=folder_dataset.augmentations, source_samples=folder_dataset.samples, ) diff --git a/tests/unit/pre_processing/test_pre_processing.py b/tests/unit/pre_processing/test_pre_processing.py index fb10ba4b3b..8962a328ef 100644 --- a/tests/unit/pre_processing/test_pre_processing.py +++ b/tests/unit/pre_processing/test_pre_processing.py @@ -3,11 +3,8 @@ # Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 -from unittest.mock import MagicMock - import pytest import torch -from torch.utils.data import DataLoader from torchvision.transforms.v2 import Compose, Resize, ToDtype, ToImage from torchvision.tv_tensors import Image, Mask @@ -90,40 +87,3 @@ def test_different_stage_transforms() -> None: processed_batch = pre_processor.test_transform(test_batch.image) assert isinstance(processed_batch, torch.Tensor) assert processed_batch.shape == (1, 3, 288, 288) - - @pytest.skip - def test_setup_transforms_from_dataloaders(self) -> None: - """Test setup method when transforms are obtained from dataloaders.""" - # Mock dataloader with dataset having a transform - dataloader = MagicMock() - dataloader.dataset.transform = self.common_transform - - pre_processor = PreProcessor() - pre_processor.setup_dataloader_transforms(dataloaders=[dataloader]) - - assert pre_processor.train_transform == self.common_transform - assert pre_processor.val_transform == self.common_transform - assert pre_processor.test_transform == self.common_transform - - @pytest.skip - def test_setup_transforms_priority(self) -> None: - """Test setup method prioritizes PreProcessor transforms over datamodule/dataloaders.""" - # Mock datamodule - datamodule = MagicMock() - datamodule.train_transform = Compose([Resize((128, 128)), ToImage(), ToDtype(torch.float32, scale=True)]) - datamodule.eval_transform = Compose([Resize((128, 128)), ToImage(), ToDtype(torch.float32, scale=True)]) - - # Mock dataloader - dataset_mock = MagicMock() 
- dataset_mock.transform = Compose([Resize((64, 64)), ToImage(), ToDtype(torch.float32, scale=True)]) - dataloader = MagicMock(spec=DataLoader) - dataloader.dataset = dataset_mock - - # Initialize PreProcessor with a custom transform - pre_processor = PreProcessor(transform=self.common_transform) - pre_processor.setup_datamodule_transforms(datamodule=datamodule) - - # Ensure PreProcessor's own transform is used - assert pre_processor.train_transform == self.common_transform - assert pre_processor.val_transform == self.common_transform - assert pre_processor.test_transform == self.common_transform diff --git a/tests/unit/pre_processing/utils/test_transform.py b/tests/unit/pre_processing/utils/test_transform.py index 6974bcdbc8..d4c416e00f 100644 --- a/tests/unit/pre_processing/utils/test_transform.py +++ b/tests/unit/pre_processing/utils/test_transform.py @@ -3,9 +3,6 @@ # Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 -import pytest -import torch -from torch.utils.data import DataLoader, TensorDataset from torchvision.transforms.v2 import CenterCrop, Compose, Resize, ToTensor from anomalib.data.transforms import ExportableCenterCrop @@ -13,36 +10,9 @@ convert_center_crop_transform, disable_antialiasing, get_exportable_transform, - set_dataloader_transform, ) -def test_set_dataloader_transform() -> None: - """Test the set_dataloader_transform function.""" - - # Test with single DataLoader - class TransformableDataset(TensorDataset): - def __init__(self, *tensors) -> None: - super().__init__(*tensors) - self.transform = None - - dataset = TransformableDataset(torch.randn(10, 3, 224, 224)) - dataloader = DataLoader(dataset) - transform = ToTensor() - set_dataloader_transform(dataloader, transform) - assert dataloader.dataset.transform == transform - - # Test with sequence of DataLoaders - dataloaders = [DataLoader(TransformableDataset(torch.randn(10, 3, 224, 224))) for _ in range(3)] - set_dataloader_transform(dataloaders, transform) - for dl in dataloaders: - assert dl.dataset.transform == transform - - # Test with unsupported type - with pytest.raises(TypeError): - set_dataloader_transform({"key": "value"}, transform) - - def test_get_exportable_transform() -> None: """Test the get_exportable_transform function.""" # Test with None transform From ec03d3dd117383907929d81eeba4f659c33a8654 Mon Sep 17 00:00:00 2001 From: Dick Ameln Date: Mon, 16 Dec 2024 23:22:36 +0100 Subject: [PATCH 03/22] update expected config for adapter tests --- tests/integration/tools/upgrade/expected_draem_v1.yaml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/integration/tools/upgrade/expected_draem_v1.yaml b/tests/integration/tools/upgrade/expected_draem_v1.yaml index feb059214d..0e65e8f49b 100644 --- a/tests/integration/tools/upgrade/expected_draem_v1.yaml +++ b/tests/integration/tools/upgrade/expected_draem_v1.yaml @@ -6,6 +6,10 @@ data: train_batch_size: 72 eval_batch_size: 32 num_workers: 8 + train_augmentations: null + val_augmentations: null + test_augmentations: null + augmentations: null test_split_mode: from_dir test_split_ratio: 0.2 val_split_mode: same_as_test From 1d45aa3c1b16e46226cd053543659ce490c50687 Mon Sep 17 00:00:00 2001 From: Dick Ameln Date: Tue, 17 Dec 2024 12:52:24 +0100 Subject: [PATCH 04/22] fix buffer issue --- src/anomalib/utils/visualization/image.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/anomalib/utils/visualization/image.py b/src/anomalib/utils/visualization/image.py index 16b852235f..707827cae7 100644 --- 
a/src/anomalib/utils/visualization/image.py +++ b/src/anomalib/utils/visualization/image.py @@ -290,7 +290,6 @@ def generate(self) -> np.ndarray: axis.title.set_text(image_dict["title"]) self.figure.canvas.draw() # convert canvas to numpy array to prepare for visualization with opencv - img = np.frombuffer(self.figure.canvas.tostring_rgb(), dtype=np.uint8) - img = img.reshape(self.figure.canvas.get_width_height()[::-1] + (3,)) + img = np.array(self.figure.canvas.buffer_rgba(), dtype=np.uint8)[..., :3] plt.close(self.figure) return img From fe48e6647ae9b9033f4968e91feab7322e3fa272 Mon Sep 17 00:00:00 2001 From: Dick Ameln Date: Tue, 17 Dec 2024 16:48:56 +0100 Subject: [PATCH 05/22] update data notebooks --- notebooks/100_datamodules/101_btech.ipynb | 31 +++------------------- notebooks/100_datamodules/102_mvtec.ipynb | 28 +++---------------- notebooks/100_datamodules/103_folder.ipynb | 28 ++----------------- 3 files changed, 9 insertions(+), 78 deletions(-) diff --git a/notebooks/100_datamodules/101_btech.ipynb b/notebooks/100_datamodules/101_btech.ipynb index cd980fc56e..19ac3277c2 100644 --- a/notebooks/100_datamodules/101_btech.ipynb +++ b/notebooks/100_datamodules/101_btech.ipynb @@ -39,7 +39,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -61,18 +61,16 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "# flake8: noqa\n", "import numpy as np\n", "from PIL import Image\n", - "from torchvision.transforms.v2 import Resize\n", "from torchvision.transforms.v2.functional import to_pil_image\n", "\n", - "from anomalib.data import BTech, BTechDataset\n", - "from anomalib import TaskType" + "from anomalib.data import BTech, BTechDataset" ] }, { @@ -99,7 +97,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ @@ -203,25 +201,6 @@ "BTechDataset??" ] }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We can add some transforms that will be applied to the images using torchvision. Let's add a transform that resizes the \n", - "input image to 256x256 pixels." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "image_size = (256, 256)\n", - "transform = Resize(image_size, antialias=True)" - ] - }, { "attachments": {}, "cell_type": "markdown", @@ -240,7 +219,6 @@ "btech_dataset_train = BTechDataset(\n", " root=dataset_root,\n", " category=\"01\",\n", - " transform=transform,\n", " split=\"train\",\n", ")\n", "print(len(btech_dataset_train))\n", @@ -268,7 +246,6 @@ "btech_dataset_test = BTechDataset(\n", " root=dataset_root,\n", " category=\"01\",\n", - " transform=transform,\n", " split=\"test\",\n", ")\n", "print(len(btech_dataset_test))\n", diff --git a/notebooks/100_datamodules/102_mvtec.ipynb b/notebooks/100_datamodules/102_mvtec.ipynb index cbc62f51dd..573c83f399 100644 --- a/notebooks/100_datamodules/102_mvtec.ipynb +++ b/notebooks/100_datamodules/102_mvtec.ipynb @@ -23,14 +23,13 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "# flake8: noqa\n", "import numpy as np\n", "from PIL import Image\n", - "from torchvision.transforms.v2 import Resize\n", "from torchvision.transforms.v2.functional import to_pil_image\n", "\n", "from anomalib.data import MVTec, MVTecDataset" @@ -48,7 +47,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -76,7 +75,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ @@ -180,25 +179,6 @@ "MVTecDataset??" ] }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We can add some transforms that will be applied to the images using torchvision. Let's add a transform that resizes the \n", - "input image to 256x256 pixels." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "image_size = (256, 256)\n", - "transform = Resize(image_size, antialias=True)" - ] - }, { "attachments": {}, "cell_type": "markdown", @@ -217,7 +197,6 @@ "mvtec_dataset_train = MVTecDataset(\n", " root=dataset_root,\n", " category=\"bottle\",\n", - " transform=transform,\n", " split=\"train\",\n", ")\n", "print(len(mvtec_dataset_train))\n", @@ -245,7 +224,6 @@ "mvtec_dataset_test = MVTecDataset(\n", " root=dataset_root,\n", " category=\"bottle\",\n", - " transform=transform,\n", " split=\"test\",\n", ")\n", "print(len(mvtec_dataset_test))\n", diff --git a/notebooks/100_datamodules/103_folder.ipynb b/notebooks/100_datamodules/103_folder.ipynb index e40b68a858..df9154f056 100644 --- a/notebooks/100_datamodules/103_folder.ipynb +++ b/notebooks/100_datamodules/103_folder.ipynb @@ -33,7 +33,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -63,14 +63,13 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "# flake8: noqa\n", "import numpy as np\n", "from PIL import Image\n", - "from torchvision.transforms.v2 import Resize\n", "from torchvision.transforms.v2.functional import to_pil_image\n", "\n", "from anomalib.data import Folder, FolderDataset" @@ -173,25 +172,6 @@ "FolderDataset??" ] }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We can add some transforms that will be applied to the images using torchvision. Let's add a transform that resizes the \n", - "input image to 256x256 pixels." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "image_size = (256, 256)\n", - "transform = Resize(image_size, antialias=True)" - ] - }, { "attachments": {}, "cell_type": "markdown", @@ -211,7 +191,6 @@ " normal_dir=dataset_root / \"good\",\n", " abnormal_dir=dataset_root / \"crack\",\n", " split=\"train\",\n", - " transform=transform,\n", ")\n", "print(len(folder_dataset_train))\n", "sample = folder_dataset_train[0]\n", @@ -241,7 +220,6 @@ " normal_dir=dataset_root / \"good\",\n", " abnormal_dir=dataset_root / \"crack\",\n", " split=\"test\",\n", - " transform=transform,\n", ")\n", "print(len(folder_dataset_test))\n", "sample = folder_dataset_test[0]\n", @@ -270,7 +248,6 @@ " normal_dir=dataset_root / \"good\",\n", " abnormal_dir=dataset_root / \"crack\",\n", " split=\"train\",\n", - " transform=transform,\n", " mask_dir=dataset_root / \"mask\" / \"crack\",\n", ")\n", "print(len(folder_dataset_segmentation_train))\n", @@ -290,7 +267,6 @@ " normal_dir=dataset_root / \"good\",\n", " abnormal_dir=dataset_root / \"crack\",\n", " split=\"test\",\n", - " transform=transform,\n", " mask_dir=dataset_root / \"mask\" / \"crack\",\n", ")\n", "print(len(folder_dataset_segmentation_test))\n", From 151a17943df89a502ad498874decfeb91698dc92 Mon Sep 17 00:00:00 2001 From: Dick Ameln Date: Tue, 17 Dec 2024 17:10:14 +0100 Subject: [PATCH 06/22] reduce num workers in MLFlow notebook --- notebooks/600_loggers/601_mlflow_logging.ipynb | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/notebooks/600_loggers/601_mlflow_logging.ipynb b/notebooks/600_loggers/601_mlflow_logging.ipynb index c3cbe0fdb5..35aaced36f 100644 --- a/notebooks/600_loggers/601_mlflow_logging.ipynb +++ b/notebooks/600_loggers/601_mlflow_logging.ipynb @@ -126,7 +126,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ @@ -189,7 +189,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -198,7 +198,7 @@ " category=\"bottle\",\n", " train_batch_size=32,\n", " eval_batch_size=32,\n", - " num_workers=24,\n", + " num_workers=8,\n", ")" ] }, @@ -250,7 +250,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ @@ -273,7 +273,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 6, "metadata": {}, "outputs": [], "source": [ @@ -291,7 +291,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 7, "metadata": {}, "outputs": [], "source": [ From 544ac7eb79b29f1c82b81dda4776a934c8555e04 Mon Sep 17 00:00:00 2001 From: Dick Ameln Date: Wed, 18 Dec 2024 17:28:37 +0100 Subject: [PATCH 07/22] Revert "reduce num workers in MLFlow notebook" This reverts commit 151a17943df89a502ad498874decfeb91698dc92. 
--- notebooks/600_loggers/601_mlflow_logging.ipynb | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/notebooks/600_loggers/601_mlflow_logging.ipynb b/notebooks/600_loggers/601_mlflow_logging.ipynb index 35aaced36f..c3cbe0fdb5 100644 --- a/notebooks/600_loggers/601_mlflow_logging.ipynb +++ b/notebooks/600_loggers/601_mlflow_logging.ipynb @@ -126,7 +126,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -189,7 +189,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -198,7 +198,7 @@ " category=\"bottle\",\n", " train_batch_size=32,\n", " eval_batch_size=32,\n", - " num_workers=8,\n", + " num_workers=24,\n", ")" ] }, @@ -250,7 +250,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -273,7 +273,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -291,7 +291,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ From f15f0f775350cc3ec0c3db59db6027be9169fa10 Mon Sep 17 00:00:00 2001 From: Dick Ameln Date: Thu, 19 Dec 2024 14:43:56 +0100 Subject: [PATCH 08/22] match resize between augmentations and model transforms --- src/anomalib/data/datamodules/base/image.py | 112 ++++++++++++++++++-- 1 file changed, 103 insertions(+), 9 deletions(-) diff --git a/src/anomalib/data/datamodules/base/image.py b/src/anomalib/data/datamodules/base/image.py index b45160322c..783e151ddd 100644 --- a/src/anomalib/data/datamodules/base/image.py +++ b/src/anomalib/data/datamodules/base/image.py @@ -7,22 +7,22 @@ import logging from abc import ABC, abstractmethod from pathlib import Path -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any from lightning.pytorch import LightningDataModule from lightning.pytorch.trainer.states import TrainerFn from lightning.pytorch.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS from torch.utils.data.dataloader import DataLoader -from torchvision.transforms.v2 import Transform +from torchvision.transforms.v2 import Compose, Resize, Transform from anomalib import TaskType +from anomalib.data.datasets.base.image import AnomalibDataset from anomalib.data.utils import TestSplitMode, ValSplitMode, random_split, split_by_label from anomalib.data.utils.synthetic import SyntheticAnomalyDataset if TYPE_CHECKING: from pandas import DataFrame - from anomalib.data.datasets.base.image import AnomalibDataset logger = logging.getLogger(__name__) @@ -113,12 +113,106 @@ def setup(self, stage: str | None = None) -> None: # only set the flag if the stage is a TrainerFn, which means the setup has been called from a trainer self._is_setup = True - if hasattr(self, "train_data"): - self.train_data.augmentations = self.train_augmentations - if hasattr(self, "val_data"): - self.val_data.augmentations = self.val_augmentations - if hasattr(self, "test_data"): - self.test_data.augmentations = self.test_augmentations + self._update_augmentations() + + def _update_augmentations(self) -> None: + """Update the augmentations for each subset.""" + for subset_name in ["train", "val", "test"]: + subset = getattr(self, f"{subset_name}_data", None) + augmentations = getattr(self, f"{subset_name}_augmentations", None) + model_transform = self.get_nested_attr(self, 
f"trainer.model.pre_processor.{subset_name}_transform") + if subset and augmentations: + self._update_subset_augmentations(subset, augmentations, model_transform) + + def _update_subset_augmentations( + self, + dataset: AnomalibDataset, + augmentations: Transform, + model_transform: Transform, + ) -> None: + """Update the augmentations of the dataset. + + This method passes the user-specified augmentations to a dataset subset. If the model transforms contain + a Resize transform, it will be appended to the augmentations. This will ensure that resizing takes place + before collating, which reduces the usage of shared memory by the Dataloader workers. + + Args: + dataset (AnomalibDataset): Dataset to update. + augmentations (Transform): Augmentations to apply to the dataset. + model_transform (Transform): Transform object from the model PreProcessor. + """ + model_resizes = self.get_resize_transforms(model_transform) + + if model_resizes: + model_resize = model_resizes[0] + for aug_resize in self.get_resize_transforms(augmentations): # warn user if resizes inconsistent + if model_resize.size != aug_resize.size: + msg = f"Conflicting resize shapes found between augmentations and model transforms. You are using \ + a Resize transform in your input data augmentations. Please be aware that the model also \ + applies a Resize transform with a different output size. The final effective input size as \ + seen by the model will be determined by the model transforms, not the augmentations. To change \ + the effective input size, please change the model transforms in the PreProcessor module. \ + Augmentations: {aug_resize.size}, Model transforms: {model_transform.size}" + logger.warning(msg) + if model_resize.interpolation != aug_resize.interpolation: + msg = f"Conflicting interpolation method found between augmentations and model transforms. You are \ + using a Resize transform in your input data augmentations. Please be aware that the model also \ + applies a Resize transform with a different interpolation method. Using multiple interpolation \ + methods can lead to unexpected behaviour, so it is recommended to use the same interpolation \ + method between augmentations and model transforms. Augmentations: {aug_resize.interpolation}, \ + Model transforms: {model_resize.interpolation}" + logger.warning(msg) + if model_resize.antialias != aug_resize.antialias: + msg = f"Conflicting antialiasing setting found between augmentations and model transforms. You are \ + using a Resize transform in your input data augmentations. Please be aware that the model also \ + applies a Resize transform with a different antialising setting. Using conflicting \ + antialiasing settings can lead to unexpected behaviour, so it is recommended to use the same \ + antialiasing setting between augmentations and model transforms. Augmentations: \ + antialias={aug_resize.antialias}, Model transforms: antialias={model_resize.antialias}" + + # append model resize to augmentations + if isinstance(augmentations, Compose): + augmentations = Compose([*augmentations.transforms, model_resize]) + elif isinstance(augmentations, Transform): + augmentations = Compose([augmentations, model_resize]) + elif augmentations is None: + augmentations = model_resize + + dataset.augmentations = augmentations + + @staticmethod + def get_resize_transforms(transform: Transform | None) -> list[Resize]: + """Get a list of all the resize transforms present in the provided Transform. + + Args: + transform (Transform): Torchvision Transform instance. 
+ + Returns: + List[Resize]: List of Resize transform instances. + """ + if isinstance(transform, Resize): + return [transform] + if isinstance(transform, Compose): + return [transform for transform in transform.transforms if isinstance(transform, Resize)] + return [] + + @staticmethod + def get_nested_attr(obj: Any, attr_path: str, default: Any | None = None) -> Any: # noqa: ANN401 + """Safely retrieves a nested attribute from an object. + + Args: + obj: The object to retrieve the attribute from. + attr_path: A dot-separated string representing the attribute path. + default: The default value to return if any attribute in the path is missing. + + Returns: + The value of the nested attribute, or `default` if any attribute in the path is missing. + """ + for attr in attr_path.split("."): + obj = getattr(obj, attr, default) + if obj is default: + return default + return obj @abstractmethod def _setup(self, _stage: str | None = None) -> None: From 8ae05a3e4df720331f1cd1e16dbc61fcc356893d Mon Sep 17 00:00:00 2001 From: Dick Ameln Date: Thu, 19 Dec 2024 17:14:55 +0100 Subject: [PATCH 09/22] remove subset-specific transforms in preprocessor --- configs/model/cfa.yaml | 3 -- configs/model/cflow.yaml | 4 -- configs/model/csflow.yaml | 4 -- configs/model/draem.yaml | 4 -- configs/model/dsr.yaml | 4 -- configs/model/efficient_ad.yaml | 4 -- configs/model/fastflow.yaml | 4 -- configs/model/padim.yaml | 3 -- configs/model/reverse_distillation.yaml | 4 -- configs/model/stfpm.yaml | 4 -- configs/model/uflow.yaml | 4 -- src/anomalib/data/datamodules/base/image.py | 2 +- .../models/components/base/anomalib_module.py | 6 +-- .../models/image/winclip/lightning_model.py | 2 +- src/anomalib/pre_processing/pre_processing.py | 35 ++++--------- .../pre_processing/test_pre_processing.py | 51 ------------------- 16 files changed, 13 insertions(+), 125 deletions(-) diff --git a/configs/model/cfa.yaml b/configs/model/cfa.yaml index 457a7f5387..1f3ad7ec72 100644 --- a/configs/model/cfa.yaml +++ b/configs/model/cfa.yaml @@ -8,9 +8,6 @@ model: num_hard_negative_features: 3 radius: 1.0e-05 -metrics: - pixel: AUROC - trainer: max_epochs: 30 callbacks: diff --git a/configs/model/cflow.yaml b/configs/model/cflow.yaml index dc134278ce..3d7e53917e 100644 --- a/configs/model/cflow.yaml +++ b/configs/model/cflow.yaml @@ -15,10 +15,6 @@ model: permute_soft: false lr: 0.0001 -metrics: - pixel: - - AUROC - trainer: max_epochs: 50 callbacks: diff --git a/configs/model/csflow.yaml b/configs/model/csflow.yaml index cece0b379c..796490fe97 100644 --- a/configs/model/csflow.yaml +++ b/configs/model/csflow.yaml @@ -6,10 +6,6 @@ model: clamp: 3 num_channels: 3 -metrics: - pixel: - - AUROC - trainer: max_epochs: 240 callbacks: diff --git a/configs/model/draem.yaml b/configs/model/draem.yaml index 17d85220e4..04914e4282 100644 --- a/configs/model/draem.yaml +++ b/configs/model/draem.yaml @@ -6,10 +6,6 @@ model: sspcab_lambda: 0.1 anomaly_source_path: null -metrics: - pixel: - - AUROC - trainer: max_epochs: 700 callbacks: diff --git a/configs/model/dsr.yaml b/configs/model/dsr.yaml index 859438418a..7a2f84997d 100644 --- a/configs/model/dsr.yaml +++ b/configs/model/dsr.yaml @@ -4,10 +4,6 @@ model: latent_anomaly_strength: 0.2 upsampling_train_ratio: 0.7 -metrics: - pixel: - - AUROC - # PL Trainer Args. Don't add extra parameter here. 
trainer: max_epochs: 700 diff --git a/configs/model/efficient_ad.yaml b/configs/model/efficient_ad.yaml index 1d7f70b7eb..9e64851e5f 100644 --- a/configs/model/efficient_ad.yaml +++ b/configs/model/efficient_ad.yaml @@ -8,10 +8,6 @@ model: padding: false pad_maps: true -metrics: - pixel: - - AUROC - trainer: max_epochs: 1000 max_steps: 70000 diff --git a/configs/model/fastflow.yaml b/configs/model/fastflow.yaml index 13cdd69a3e..8bcde42c78 100644 --- a/configs/model/fastflow.yaml +++ b/configs/model/fastflow.yaml @@ -7,10 +7,6 @@ model: conv3x3_only: false hidden_ratio: 1.0 -metrics: - pixel: - - AUROC - trainer: max_epochs: 500 callbacks: diff --git a/configs/model/padim.yaml b/configs/model/padim.yaml index 3787897889..daeb806b86 100644 --- a/configs/model/padim.yaml +++ b/configs/model/padim.yaml @@ -8,6 +8,3 @@ model: backbone: resnet18 pre_trained: true n_features: null - -metrics: - pixel: AUROC diff --git a/configs/model/reverse_distillation.yaml b/configs/model/reverse_distillation.yaml index 97184b0aa0..523b303681 100644 --- a/configs/model/reverse_distillation.yaml +++ b/configs/model/reverse_distillation.yaml @@ -9,10 +9,6 @@ model: anomaly_map_mode: ADD pre_trained: true -metrics: - pixel: - - AUROC - trainer: callbacks: - class_path: lightning.pytorch.callbacks.EarlyStopping diff --git a/configs/model/stfpm.yaml b/configs/model/stfpm.yaml index c5e783baaa..40db04aec2 100644 --- a/configs/model/stfpm.yaml +++ b/configs/model/stfpm.yaml @@ -7,10 +7,6 @@ model: - layer2 - layer3 -metrics: - pixel: - - AUROC - trainer: max_epochs: 100 callbacks: diff --git a/configs/model/uflow.yaml b/configs/model/uflow.yaml index 6b6ccd81eb..00329b9b68 100644 --- a/configs/model/uflow.yaml +++ b/configs/model/uflow.yaml @@ -7,10 +7,6 @@ model: affine_subnet_channels_ratio: 1.0 backbone: mcait # official: mcait, other extractors tested: resnet18, wide_resnet50_2. Could use others... -metrics: - pixel: - - AUROC - # PL Trainer Args. Don't add extra parameter here. 
trainer: max_epochs: 200 diff --git a/src/anomalib/data/datamodules/base/image.py b/src/anomalib/data/datamodules/base/image.py index 783e151ddd..80443e3d9e 100644 --- a/src/anomalib/data/datamodules/base/image.py +++ b/src/anomalib/data/datamodules/base/image.py @@ -120,7 +120,7 @@ def _update_augmentations(self) -> None: for subset_name in ["train", "val", "test"]: subset = getattr(self, f"{subset_name}_data", None) augmentations = getattr(self, f"{subset_name}_augmentations", None) - model_transform = self.get_nested_attr(self, f"trainer.model.pre_processor.{subset_name}_transform") + model_transform = self.get_nested_attr(self, "trainer.model.pre_processor.transform") if subset and augmentations: self._update_subset_augmentations(subset, augmentations, model_transform) diff --git a/src/anomalib/models/components/base/anomalib_module.py b/src/anomalib/models/components/base/anomalib_module.py index 3fd5557032..70593e80f8 100644 --- a/src/anomalib/models/components/base/anomalib_module.py +++ b/src/anomalib/models/components/base/anomalib_module.py @@ -22,7 +22,6 @@ from anomalib.data import Batch, InferenceBatch from anomalib.metrics import AUROC, F1Score from anomalib.metrics.evaluator import Evaluator -from anomalib.metrics.threshold import Threshold from anomalib.post_processing import OneClassPostProcessor, PostProcessor from anomalib.pre_processing import PreProcessor from anomalib.visualization import ImageVisualizer, Visualizer @@ -368,7 +367,7 @@ def input_size(self) -> tuple[int, int] | None: The effective input size is the size of the input tensor after the transform has been applied. If the transform is not set, or if the transform does not change the shape of the input tensor, this method will return None. """ - transform = self.pre_processor.predict_transform if self.pre_processor else None + transform = self.pre_processor.transform if self.pre_processor else None if transform is None: return None dummy_input = torch.zeros(1, 3, 1, 1) @@ -418,9 +417,6 @@ def from_config( help="Path to a configuration file in json or yaml format.", ) model_parser.add_subclass_arguments(AnomalibModule, "model", required=False, fail_untyped=False) - model_parser.add_argument("--metrics.image", type=list[str] | str | None, default=["F1Score", "AUROC"]) - model_parser.add_argument("--metrics.pixel", type=list[str] | str | None, default=None, required=False) - model_parser.add_argument("--metrics.threshold", type=Threshold | str, default="F1AdaptiveThreshold") model_parser.add_class_arguments(Trainer, "trainer", fail_untyped=False, instantiate=False, sub_configs=True) args = ["--config", str(config_path)] for key, value in kwargs.items(): diff --git a/src/anomalib/models/image/winclip/lightning_model.py b/src/anomalib/models/image/winclip/lightning_model.py index 23a7cf23a1..31b41ef5e3 100644 --- a/src/anomalib/models/image/winclip/lightning_model.py +++ b/src/anomalib/models/image/winclip/lightning_model.py @@ -199,7 +199,7 @@ def configure_pre_processor(cls, image_size: tuple[int, int] | None = None) -> P Resize((240, 240), antialias=True, interpolation=InterpolationMode.BICUBIC), Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711)), ]) - return PreProcessor(val_transform=transform, test_transform=transform) + return PreProcessor(transform=transform) @staticmethod def configure_post_processor() -> OneClassPostProcessor: diff --git a/src/anomalib/pre_processing/pre_processing.py b/src/anomalib/pre_processing/pre_processing.py index 812540860b..74ff51df4e 
100644 --- a/src/anomalib/pre_processing/pre_processing.py +++ b/src/anomalib/pre_processing/pre_processing.py @@ -77,33 +77,18 @@ class PreProcessor(nn.Module, Callback): def __init__( self, - train_transform: Transform | None = None, - val_transform: Transform | None = None, - test_transform: Transform | None = None, transform: Transform | None = None, ) -> None: super().__init__() - if transform and any([train_transform, val_transform, test_transform]): - msg = ( - "`transforms` cannot be used together with `train_transform`, `val_transform`, `test_transform`.\n" - "If you want to apply the same transform to the training, validation and test data, " - "use only `transforms`. \n" - "Otherwise, specify transforms for training, validation and test individually." - ) - raise ValueError(msg) - - self.train_transform = train_transform or transform - self.val_transform = val_transform or transform - self.test_transform = test_transform or transform - self.predict_transform = self.test_transform - self.export_transform = get_exportable_transform(self.test_transform) + self.transform = transform + self.export_transform = get_exportable_transform(self.transform) def on_train_batch_start(self, trainer: Trainer, pl_module: LightningModule, batch: Batch, batch_idx: int) -> None: """Apply transforms to the batch of tensors during training.""" del trainer, pl_module, batch_idx # Unused - if self.train_transform: - batch.image, batch.gt_mask = self.train_transform(batch.image, batch.gt_mask) + if self.transform: + batch.image, batch.gt_mask = self.transform(batch.image, batch.gt_mask) def on_validation_batch_start( self, @@ -114,8 +99,8 @@ def on_validation_batch_start( ) -> None: """Apply transforms to the batch of tensors during validation.""" del trainer, pl_module, batch_idx # Unused - if self.val_transform: - batch.image, batch.gt_mask = self.val_transform(batch.image, batch.gt_mask) + if self.transform: + batch.image, batch.gt_mask = self.transform(batch.image, batch.gt_mask) def on_test_batch_start( self, @@ -127,8 +112,8 @@ def on_test_batch_start( ) -> None: """Apply transforms to the batch of tensors during testing.""" del trainer, pl_module, batch_idx, dataloader_idx # Unused - if self.test_transform: - batch.image, batch.gt_mask = self.test_transform(batch.image, batch.gt_mask) + if self.transform: + batch.image, batch.gt_mask = self.transform(batch.image, batch.gt_mask) def on_predict_batch_start( self, @@ -140,8 +125,8 @@ def on_predict_batch_start( ) -> None: """Apply transforms to the batch of tensors during prediction.""" del trainer, pl_module, batch_idx, dataloader_idx # Unused - if self.predict_transform: - batch.image, batch.gt_mask = self.predict_transform(batch.image, batch.gt_mask) + if self.transform: + batch.image, batch.gt_mask = self.transform(batch.image, batch.gt_mask) def forward(self, batch: torch.Tensor) -> torch.Tensor: """Apply transforms to the batch of tensors for inference. 
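A minimal sketch of the single-transform API that remains after this change; Padim is an arbitrary example, and the `pre_processor` constructor argument is assumed from the `AnomalibModule` interface touched above:

    from torchvision.transforms.v2 import Compose, Normalize, Resize

    from anomalib.models import Padim
    from anomalib.pre_processing import PreProcessor

    # One transform now covers training, validation, testing and prediction alike.
    transform = Compose([
        Resize((256, 256)),
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    model = Padim(pre_processor=PreProcessor(transform=transform))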
diff --git a/tests/unit/pre_processing/test_pre_processing.py b/tests/unit/pre_processing/test_pre_processing.py index 8962a328ef..dbc677f66a 100644 --- a/tests/unit/pre_processing/test_pre_processing.py +++ b/tests/unit/pre_processing/test_pre_processing.py @@ -23,26 +23,6 @@ def setup(self) -> None: self.dummy_batch = ImageBatch(image=image, gt_mask=gt_mask) self.common_transform = Compose([Resize((224, 224)), ToImage(), ToDtype(torch.float32, scale=True)]) - def test_init(self) -> None: - """Test the initialization of the PreProcessor class.""" - # Test with stage-specific transforms - train_transform = Compose([Resize((224, 224)), ToImage(), ToDtype(torch.float32, scale=True)]) - val_transform = Compose([Resize((256, 256)), ToImage(), ToDtype(torch.float32, scale=True)]) - pre_processor = PreProcessor(train_transform=train_transform, val_transform=val_transform) - assert pre_processor.train_transform == train_transform - assert pre_processor.val_transform == val_transform - assert pre_processor.test_transform is None - - # Test with single transform for all stages - pre_processor = PreProcessor(transform=self.common_transform) - assert pre_processor.train_transform == self.common_transform - assert pre_processor.val_transform == self.common_transform - assert pre_processor.test_transform == self.common_transform - - # Test error case: both transform and stage-specific transform - with pytest.raises(ValueError, match="`transforms` cannot be used together with"): - PreProcessor(transform=self.common_transform, train_transform=train_transform) - def test_forward(self) -> None: """Test the forward method of the PreProcessor class.""" pre_processor = PreProcessor(transform=self.common_transform) @@ -56,34 +36,3 @@ def test_no_transform(self) -> None: processed_batch = pre_processor(self.dummy_batch.image) assert isinstance(processed_batch, torch.Tensor) assert processed_batch.shape == (1, 3, 256, 256) - - @staticmethod - def test_different_stage_transforms() -> None: - """Test different stage transforms.""" - train_transform = Compose([Resize((224, 224)), ToImage(), ToDtype(torch.float32, scale=True)]) - val_transform = Compose([Resize((256, 256)), ToImage(), ToDtype(torch.float32, scale=True)]) - test_transform = Compose([Resize((288, 288)), ToImage(), ToDtype(torch.float32, scale=True)]) - - pre_processor = PreProcessor( - train_transform=train_transform, - val_transform=val_transform, - test_transform=test_transform, - ) - - # Test train transform - test_batch = ImageBatch(image=Image(torch.rand(3, 256, 256)), gt_mask=Mask(torch.zeros(256, 256))) - processed_batch = pre_processor.train_transform(test_batch.image) - assert isinstance(processed_batch, torch.Tensor) - assert processed_batch.shape == (1, 3, 224, 224) - - # Test validation transform - test_batch = ImageBatch(image=Image(torch.rand(3, 256, 256)), gt_mask=Mask(torch.zeros(256, 256))) - processed_batch = pre_processor.val_transform(test_batch.image) - assert isinstance(processed_batch, torch.Tensor) - assert processed_batch.shape == (1, 3, 256, 256) - - # Test test transform - test_batch = ImageBatch(image=Image(torch.rand(3, 256, 256)), gt_mask=Mask(torch.zeros(256, 256))) - processed_batch = pre_processor.test_transform(test_batch.image) - assert isinstance(processed_batch, torch.Tensor) - assert processed_batch.shape == (1, 3, 288, 288) From 4849be2b56a8d5ab5e33a7f9a26f444f307f331d Mon Sep 17 00:00:00 2001 From: Dick Ameln Date: Thu, 19 Dec 2024 17:31:53 +0100 Subject: [PATCH 10/22] move nested attr helper to utils --- 
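Note: a rough usage sketch of the relocated helper, with `SimpleNamespace` standing in for the real trainer and model objects; the attribute chain mirrors the `trainer.model.pre_processor.transform` lookup in the datamodule:

    from types import SimpleNamespace

    from anomalib.utils import get_nested_attr

    # Stand-in for datamodule.trainer.model.pre_processor.transform.
    obj = SimpleNamespace(
        trainer=SimpleNamespace(
            model=SimpleNamespace(pre_processor=SimpleNamespace(transform="resize")),
        ),
    )

    assert get_nested_attr(obj, "trainer.model.pre_processor.transform") == "resize"
    # Any missing link in the chain yields the default instead of raising AttributeError.
    assert get_nested_attr(obj, "trainer.model.missing.transform") is None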
src/anomalib/data/datamodules/base/image.py | 23 +++----------------- src/anomalib/utils/__init__.py | 4 ++++ src/anomalib/utils/attrs.py | 24 +++++++++++++++++++++ 3 files changed, 31 insertions(+), 20 deletions(-) create mode 100644 src/anomalib/utils/attrs.py diff --git a/src/anomalib/data/datamodules/base/image.py b/src/anomalib/data/datamodules/base/image.py index 80443e3d9e..a4fe294c5c 100644 --- a/src/anomalib/data/datamodules/base/image.py +++ b/src/anomalib/data/datamodules/base/image.py @@ -7,7 +7,7 @@ import logging from abc import ABC, abstractmethod from pathlib import Path -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING from lightning.pytorch import LightningDataModule from lightning.pytorch.trainer.states import TrainerFn @@ -19,6 +19,7 @@ from anomalib.data.datasets.base.image import AnomalibDataset from anomalib.data.utils import TestSplitMode, ValSplitMode, random_split, split_by_label from anomalib.data.utils.synthetic import SyntheticAnomalyDataset +from anomalib.utils.attrs import get_nested_attr if TYPE_CHECKING: from pandas import DataFrame @@ -120,7 +121,7 @@ def _update_augmentations(self) -> None: for subset_name in ["train", "val", "test"]: subset = getattr(self, f"{subset_name}_data", None) augmentations = getattr(self, f"{subset_name}_augmentations", None) - model_transform = self.get_nested_attr(self, "trainer.model.pre_processor.transform") + model_transform = get_nested_attr(self, "trainer.model.pre_processor.transform") if subset and augmentations: self._update_subset_augmentations(subset, augmentations, model_transform) @@ -196,24 +197,6 @@ def get_resize_transforms(transform: Transform | None) -> list[Resize]: return [transform for transform in transform.transforms if isinstance(transform, Resize)] return [] - @staticmethod - def get_nested_attr(obj: Any, attr_path: str, default: Any | None = None) -> Any: # noqa: ANN401 - """Safely retrieves a nested attribute from an object. - - Args: - obj: The object to retrieve the attribute from. - attr_path: A dot-separated string representing the attribute path. - default: The default value to return if any attribute in the path is missing. - - Returns: - The value of the nested attribute, or `default` if any attribute in the path is missing. - """ - for attr in attr_path.split("."): - obj = getattr(obj, attr, default) - if obj is default: - return default - return obj - @abstractmethod def _setup(self, _stage: str | None = None) -> None: """Set up the datasets and perform dynamic subset splitting. diff --git a/src/anomalib/utils/__init__.py b/src/anomalib/utils/__init__.py index 8ffe7654fe..e46b3da31b 100644 --- a/src/anomalib/utils/__init__.py +++ b/src/anomalib/utils/__init__.py @@ -2,3 +2,7 @@ # Copyright (C) 2022 Intel Corporation # SPDX-License-Identifier: Apache-2.0 + +from .attrs import get_nested_attr + +__all__ = ["get_nested_attr"] diff --git a/src/anomalib/utils/attrs.py b/src/anomalib/utils/attrs.py new file mode 100644 index 0000000000..35f3971485 --- /dev/null +++ b/src/anomalib/utils/attrs.py @@ -0,0 +1,24 @@ +"""Utility functions for working with attributes.""" + +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any + + +def get_nested_attr(obj: Any, attr_path: str, default: Any | None = None) -> Any: # noqa: ANN401 + """Safely retrieves a nested attribute from an object. + + Args: + obj: The object to retrieve the attribute from. + attr_path: A dot-separated string representing the attribute path. 
+ default: The default value to return if any attribute in the path is missing.
+
+    Returns:
+        The value of the nested attribute, or `default` if any attribute in the path is missing.
+    """
+    for attr in attr_path.split("."):
+        obj = getattr(obj, attr, default)
+        if obj is default:
+            return default
+    return obj
From f8882a11a4e598a88f9b83f51ce058f626865a62 Mon Sep 17 00:00:00 2001
From: Dick Ameln
Date: Thu, 19 Dec 2024 17:46:46 +0100
Subject: [PATCH 11/22] move transform retrieve function to transform utils

---
 src/anomalib/data/datamodules/base/image.py | 19 ++----------------
 src/anomalib/data/transforms/utils.py       | 23 +++++++++++++++++++++
 2 files changed, 25 insertions(+), 17 deletions(-)
 create mode 100644 src/anomalib/data/transforms/utils.py

diff --git a/src/anomalib/data/datamodules/base/image.py b/src/anomalib/data/datamodules/base/image.py
index a4fe294c5c..09e482e2f4 100644
--- a/src/anomalib/data/datamodules/base/image.py
+++ b/src/anomalib/data/datamodules/base/image.py
@@ -17,6 +17,7 @@
 from anomalib import TaskType
 from anomalib.data.datasets.base.image import AnomalibDataset
+from anomalib.data.transforms.utils import get_transforms_of_type
 from anomalib.data.utils import TestSplitMode, ValSplitMode, random_split, split_by_label
 from anomalib.data.utils.synthetic import SyntheticAnomalyDataset
 from anomalib.utils.attrs import get_nested_attr
@@ -142,7 +143,7 @@ def _update_subset_augmentations(
             augmentations (Transform): Augmentations to apply to the dataset.
             model_transform (Transform): Transform object from the model PreProcessor.
         """
-        model_resizes = self.get_resize_transforms(model_transform)
+        model_resizes = get_transforms_of_type(model_transform, Resize)
 
         if model_resizes:
             model_resize = model_resizes[0]
@@ -181,22 +182,6 @@ def _update_subset_augmentations(
 
         dataset.augmentations = augmentations
 
-    @staticmethod
-    def get_resize_transforms(transform: Transform | None) -> list[Resize]:
-        """Get a list of all the resize transforms present in the provided Transform.
-
-        Args:
-            transform (Transform): Torchvision Transform instance.
-
-        Returns:
-            List[Resize]: List of Resize transform instances.
-        """
-        if isinstance(transform, Resize):
-            return [transform]
-        if isinstance(transform, Compose):
-            return [transform for transform in transform.transforms if isinstance(transform, Resize)]
-        return []
-
     @abstractmethod
     def _setup(self, _stage: str | None = None) -> None:
         """Set up the datasets and perform dynamic subset splitting.
diff --git a/src/anomalib/data/transforms/utils.py b/src/anomalib/data/transforms/utils.py
new file mode 100644
index 0000000000..b1e21687a5
--- /dev/null
+++ b/src/anomalib/data/transforms/utils.py
@@ -0,0 +1,23 @@
+"""Utility functions for working with Torchvision Transforms."""
+
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+from torchvision.transforms.v2 import Compose, Transform
+
+
+def get_transforms_of_type(input_transform: Transform | None, transform_type: type[Transform]) -> list[Transform]:
+    """Retrieves all transforms of a given type from a transform or transform composition.
+
+    Args:
+        input_transform (Transform): Torchvision Transform instance.
+        transform_type (Type[Transform]): Type of transform to retrieve.
+
+    Returns:
+        List[Transform]: List of transform instances of the given type.
+ """ + if isinstance(input_transform, transform_type): + return [input_transform] + if isinstance(input_transform, Compose): + return [transform for transform in input_transform.transforms if isinstance(transform, transform_type)] + return [] From ceabaf730182f537067aa88b228c7b969f5af05f Mon Sep 17 00:00:00 2001 From: Dick Ameln Date: Thu, 19 Dec 2024 17:50:29 +0100 Subject: [PATCH 12/22] update efficientad transform validation --- .../models/image/efficient_ad/lightning_model.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/anomalib/models/image/efficient_ad/lightning_model.py b/src/anomalib/models/image/efficient_ad/lightning_model.py index aa99d6a439..6084691ff1 100644 --- a/src/anomalib/models/image/efficient_ad/lightning_model.py +++ b/src/anomalib/models/image/efficient_ad/lightning_model.py @@ -19,6 +19,7 @@ from anomalib import LearningType from anomalib.data import Batch +from anomalib.data.transforms.utils import get_transforms_of_type from anomalib.data.utils import DownloadInfo, download_and_extract from anomalib.metrics import Evaluator from anomalib.models.components import AnomalibModule @@ -267,11 +268,9 @@ def on_train_start(self) -> None: msg = "train_batch_size for EfficientAd should be 1." raise ValueError(msg) - if self.pre_processor and self.pre_processor.train_transform: - transforms = self.pre_processor.train_transform.transforms - if transforms and any(isinstance(transform, Normalize) for transform in transforms): - msg = "Transforms for EfficientAd should not contain Normalize." - raise ValueError(msg) + if self.pre_processor and get_transforms_of_type(self.pre_processor.transform, Normalize): + msg = "Transforms for EfficientAd should not contain Normalize." + raise ValueError(msg) sample = next(iter(self.trainer.train_dataloader)) image_size = sample.image.shape[-2:] From 133822beecaab58b0a7fc112adc8dcf9b429150c Mon Sep 17 00:00:00 2001 From: Dick Ameln Date: Thu, 19 Dec 2024 18:07:11 +0100 Subject: [PATCH 13/22] formatting --- src/anomalib/pre_processing/pre_processing.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/anomalib/pre_processing/pre_processing.py b/src/anomalib/pre_processing/pre_processing.py index 74ff51df4e..f6b42a82d2 100644 --- a/src/anomalib/pre_processing/pre_processing.py +++ b/src/anomalib/pre_processing/pre_processing.py @@ -84,7 +84,13 @@ def __init__( self.transform = transform self.export_transform = get_exportable_transform(self.transform) - def on_train_batch_start(self, trainer: Trainer, pl_module: LightningModule, batch: Batch, batch_idx: int) -> None: + def on_train_batch_start( + self, + trainer: Trainer, + pl_module: LightningModule, + batch: Batch, + batch_idx: int, + ) -> None: """Apply transforms to the batch of tensors during training.""" del trainer, pl_module, batch_idx # Unused if self.transform: From a6ecb3af42382a25f1984e69bdee7bd55a1394d8 Mon Sep 17 00:00:00 2001 From: Dick Ameln Date: Thu, 19 Dec 2024 18:17:22 +0100 Subject: [PATCH 14/22] update preprocessor docstring --- src/anomalib/pre_processing/pre_processing.py | 55 +++++-------------- 1 file changed, 15 insertions(+), 40 deletions(-) diff --git a/src/anomalib/pre_processing/pre_processing.py b/src/anomalib/pre_processing/pre_processing.py index f6b42a82d2..37a6e9ef35 100644 --- a/src/anomalib/pre_processing/pre_processing.py +++ b/src/anomalib/pre_processing/pre_processing.py @@ -19,60 +19,35 @@ class PreProcessor(nn.Module, Callback): """Anomalib pre-processor. 
    This class serves as both a PyTorch module and a Lightning callback, handling
-    the application of transforms to data batches during different stages of
-    training, validation, testing, and prediction.
+    the application of transforms to data batches as a pre-processing step.
 
     Args:
-        train_transform (Transform | None): Transform to apply during training.
-        val_transform (Transform | None): Transform to apply during validation.
-        test_transform (Transform | None): Transform to apply during testing.
-        transform (Transform | None): General transform to apply if stage-specific
-            transforms are not provided.
-
-    Raises:
-        ValueError: If both `transform` and any of the stage-specific transforms
-            are provided simultaneously.
-
-    Notes:
-        If only `transform` is provided, it will be used for all stages (train, val, test).
-
-        Priority of transforms:
-        1. Explicitly set PreProcessor transforms (highest priority)
-        2. Datamodule transforms (if PreProcessor has no transforms)
-        3. Dataloader transforms (if neither PreProcessor nor datamodule have transforms)
-        4. Default transforms (lowest priority)
+        transform (Transform | None): Transform to apply to the data before passing it to the model.
 
     Examples:
-        >>> from torchvision.transforms.v2 import Compose, Resize, ToTensor
+        >>> from torchvision.transforms.v2 import Compose, Normalize, Resize
         >>> from anomalib.pre_processing import PreProcessor
 
-        >>> # Define transforms
-        >>> train_transform = Compose([Resize((224, 224)), ToTensor()])
-        >>> val_transform = Compose([Resize((256, 256)), CenterCrop((224, 224)), ToTensor()])
-
-        >>> # Create PreProcessor with stage-specific transforms
-        >>> pre_processor = PreProcessor(
-        ...     train_transform=train_transform,
-        ...     val_transform=val_transform
-        ... )
+        >>> # Define a custom set of transforms
+        >>> transform = Compose([Resize((224, 224)), Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
 
-        >>> # Create PreProcessor with a single transform for all stages
-        >>> common_transform = Compose([Resize((224, 224)), ToTensor()])
-        >>> pre_processor_common = PreProcessor(transform=common_transform)
+        >>> # Pass the custom set of transforms to a model
+        >>> pre_processor = PreProcessor(transform=transform)
+        >>> model = MyModel(pre_processor=pre_processor)
 
-        >>> # Use in a Lightning module
+        >>> # Advanced use: configure the default pre-processing behaviour of a Lightning module
         >>> class MyModel(LightningModule):
         ...     def __init__(self):
         ...         super().__init__()
-        ...         self.pre_processor = PreProcessor(...)
+        ...         ...
         ...
-        ...     def configure_callbacks(self):
-        ...         return [self.pre_processor]
+        ...     def configure_pre_processor(self):
+        ...         transform = Compose([
+        ...             Resize((224, 224)),
+        ...             Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+        ...         ])
+        ...         return PreProcessor(transform)
         ...
-        ...     def training_step(self, batch, batch_idx):
-        ...         # The pre_processor will automatically apply the correct transform
-        ...         processed_batch = self.pre_processor(batch)
-        ...         # Rest of the training step
     """

     def __init__(

From 11edd815006b9d9f7c5c3c4288389e267713b12f Mon Sep 17 00:00:00 2001
From: Dick Ameln
Date: Thu, 19 Dec 2024 18:46:02 +0100
Subject: [PATCH 15/22] fix data notebook

---
 notebooks/200_models/201_fastflow.ipynb | 71 ++++++++----------------
 1 file changed, 22 insertions(+), 49 deletions(-)

diff --git a/notebooks/200_models/201_fastflow.ipynb b/notebooks/200_models/201_fastflow.ipynb
index 2e5872db60..dbace61ec9 100644
--- a/notebooks/200_models/201_fastflow.ipynb
+++ b/notebooks/200_models/201_fastflow.ipynb
@@ -35,7 +35,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 2,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -73,9 +73,8 @@
    "from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint\n",
    "from matplotlib import pyplot as plt\n",
    "from PIL import Image\n",
-    "from torch.utils.data import DataLoader\n",
    "\n",
-    "from anomalib.data import MVTec, PredictDataset\n",
+    "from anomalib.data import MVTec\n",
    "from anomalib.engine import Engine\n",
    "from anomalib.models import Fastflow\n",
    "from anomalib.utils.post_processing import superimpose_anomaly_map"
@@ -97,7 +96,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 4,
    "metadata": {
     "pycharm": {
      "name": "#%%\n"
@@ -170,7 +169,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 7,
    "metadata": {
     "pycharm": {
      "name": "#%%\n"
@@ -209,7 +208,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 8,
    "metadata": {
     "pycharm": {
      "name": "#%%\n"
@@ -292,35 +291,7 @@
    "source": [
     "## Inference\n",
     "\n",
-    "Since we have a trained model, we could infer the model on an individual image or folder of images. Anomalib has an `PredictDataset` to let you create an inference dataset. So let's try it.\n"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {
-    "pycharm": {
-     "name": "#%%\n"
-    }
-   },
-   "outputs": [],
-   "source": [
-    "pre_processor = Fastflow.configure_pre_processor()\n",
-    "transform = pre_processor.predict_transform\n",
-    "inference_dataset = PredictDataset(path=dataset_root / \"bottle/test/broken_large/000.png\", transform=transform)\n",
-    "inference_dataloader = DataLoader(dataset=inference_dataset, collate_fn=inference_dataset.collate_fn)"
-   ]
-  },
-  {
-   "attachments": {},
-   "cell_type": "markdown",
-   "metadata": {
-    "pycharm": {
-     "name": "#%% md\n"
-    }
-   },
-   "source": [
-    "We could utilize `Trainer`'s `predict` method to infer, and get the outputs to visualize\n"
+    "Since we have a trained model, we can run inference on an individual image or a folder of images by simply passing the path to the engine's `predict` method.\n"
    ]
   },
   {
@@ -333,7 +304,9 @@
    },
    "outputs": [],
    "source": [
-    "predictions = engine.predict(model=model, dataloaders=inference_dataloader)[0]"
+    "data_path = dataset_root / \"bottle/test/broken_large/000.png\"\n",
+    "predictions = engine.predict(model=model, data_path=data_path)\n",
+    "prediction = predictions[0]  # Get the first and only prediction"
    ]
   },
   {
@@ -345,7 +318,7 @@
    }
   },
   "source": [
-    "`predictions` contain image, anomaly maps, predicted scores, labels and masks. These are all stored in a dictionary. We could check this by printing the `prediction` keys.\n"
+    "`prediction` contains image, anomaly maps, predicted scores, labels and masks. These are all stored in a dictionary. We could check this by printing the `prediction` keys.\n"
   ]
  },
  {
@@ -359,9 +332,9 @@
    "outputs": [],
    "source": [
     "print(\n",
-    "    f\"Image Shape: {predictions.image.shape},\\n\"\n",
-    "    f\"Anomaly Map Shape: {predictions.anomaly_map.shape}, \\n\"\n",
-    "    f\"Predicted Mask Shape: {predictions.pred_mask.shape}\",\n",
+    "    f\"Image Shape: {prediction.image.shape},\\n\"\n",
+    "    f\"Anomaly Map Shape: {prediction.anomaly_map.shape}, \\n\"\n",
+    "    f\"Predicted Mask Shape: {prediction.pred_mask.shape}\",\n",
     ")"
    ]
   },
@@ -393,17 +366,17 @@
   "cell_type": "markdown",
   "metadata": {},
   "source": [
-    "Let's first show the input image. To do so, we will use `image_path` key from the `predictions` dictionary, and read the image from path. Note that `predictions` dictionary already contains `image`. However, this is the normalized image with pixel values between 0 and 1. We will use the original image to visualize the input image."
+    "Let's first show the input image. To do so, we will use the `image_path` key from the `prediction` dictionary, and read the image from the path. Note that the `prediction` dictionary already contains `image`. However, this is the normalized image with pixel values between 0 and 1. We will use the original image to visualize the input image."
   ]
  },
 {
  "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 16,
  "metadata": {},
  "outputs": [],
  "source": [
-    "image_path = predictions.image_path[0]\n",
-    "image_size = predictions.image.shape[-2:]\n",
+    "image_path = prediction.image_path[0]\n",
+    "image_size = prediction.image.shape[-2:]\n",
     "image = np.array(Image.open(image_path).resize(image_size))"
  ]
 },
@@ -429,7 +402,7 @@
    },
    "outputs": [],
    "source": [
-    "anomaly_map = predictions.anomaly_map[0]\n",
+    "anomaly_map = prediction.anomaly_map[0]\n",
     "anomaly_map = anomaly_map.cpu().numpy().squeeze()\n",
     "plt.imshow(anomaly_map)"
    ]
@@ -469,7 +442,7 @@
    }
   },
   "source": [
-    "`predictions` also contains prediction scores and labels.\n"
+    "`prediction` also contains prediction scores and labels.\n"
   ]
  },
@@ -482,8 +455,8 @@
    },
    "outputs": [],
    "source": [
-    "pred_score = predictions.pred_score[0]\n",
-    "pred_labels = predictions.pred_label[0]\n",
+    "pred_score = prediction.pred_score[0]\n",
+    "pred_labels = prediction.pred_label[0]\n",
     "print(pred_score, pred_labels)"
   ]
  },
@@ -509,7 +482,7 @@
    },
    "outputs": [],
    "source": [
-    "pred_masks = predictions.pred_mask[0].squeeze().cpu().numpy()\n",
+    "pred_masks = prediction.pred_mask[0].squeeze().cpu().numpy()\n",
     "plt.imshow(pred_masks)"
    ]
   },

From 74f565bbcdfbb8bbde2f937bd86fcd07bb7ef638 Mon Sep 17 00:00:00 2001
From: Dick Ameln
Date: Thu, 19 Dec 2024 19:03:28 +0100
Subject: [PATCH 16/22] fix logic in _update_augmentations

---
 src/anomalib/data/datamodules/base/image.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/src/anomalib/data/datamodules/base/image.py b/src/anomalib/data/datamodules/base/image.py
index 09e482e2f4..adb68c8338 100644
--- a/src/anomalib/data/datamodules/base/image.py
+++ b/src/anomalib/data/datamodules/base/image.py
@@ -123,13 +123,13 @@ def _update_augmentations(self) -> None:
             subset = getattr(self, f"{subset_name}_data", None)
             augmentations = getattr(self, f"{subset_name}_augmentations", None)
             model_transform = get_nested_attr(self, "trainer.model.pre_processor.transform")
-            if subset and augmentations:
+            if subset and model_transform:
                 self._update_subset_augmentations(subset, augmentations, model_transform)

+    @staticmethod
     def _update_subset_augmentations(
-        self,
         dataset: AnomalibDataset,
-        
augmentations: Transform, + augmentations: Transform | None, model_transform: Transform, ) -> None: """Update the augmentations of the dataset. @@ -147,7 +147,7 @@ def _update_subset_augmentations( if model_resizes: model_resize = model_resizes[0] - for aug_resize in self.get_resize_transforms(augmentations): # warn user if resizes inconsistent + for aug_resize in get_transforms_of_type(augmentations, Resize): # warn user if resizes inconsistent if model_resize.size != aug_resize.size: msg = f"Conflicting resize shapes found between augmentations and model transforms. You are using \ a Resize transform in your input data augmentations. Please be aware that the model also \ From 84d31208b2a62cb5b68a2706b80950787f0997ae Mon Sep 17 00:00:00 2001 From: Dick Ameln Date: Fri, 20 Dec 2024 12:26:32 +0100 Subject: [PATCH 17/22] add unit tests for updating augmentations in data module --- src/anomalib/data/datamodules/base/image.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/anomalib/data/datamodules/base/image.py b/src/anomalib/data/datamodules/base/image.py index 9e296cbfb6..0c57d90984 100644 --- a/src/anomalib/data/datamodules/base/image.py +++ b/src/anomalib/data/datamodules/base/image.py @@ -109,7 +109,7 @@ def __init__( self.num_workers = num_workers self.test_split_mode = TestSplitMode(test_split_mode) if test_split_mode else TestSplitMode.NONE self.test_split_ratio = test_split_ratio or 0.5 - self.val_split_mode = ValSplitMode(val_split_mode) + self.val_split_mode = ValSplitMode(val_split_mode) if val_split_mode else ValSplitMode.NONE self.val_split_ratio = val_split_ratio or 0.5 self.seed = seed @@ -210,9 +210,12 @@ def _update_subset_augmentations( antialiasing settings can lead to unexpected behaviour, so it is recommended to use the same \ antialiasing setting between augmentations and model transforms. 
Augmentations: \
                        antialias={aug_resize.antialias}, Model transforms: antialias={model_resize.antialias}"
+                    logger.warning(msg)
 
         # append model resize to augmentations
-        if isinstance(augmentations, Compose):
+        if isinstance(augmentations, Resize):
+            augmentations = model_resize
+        elif isinstance(augmentations, Compose):
             augmentations = Compose([*augmentations.transforms, model_resize])
         elif isinstance(augmentations, Transform):
             augmentations = Compose([augmentations, model_resize])

From 1d9ea319f81790a242bd22a1952b2f814f56396b Mon Sep 17 00:00:00 2001
From: Dick Ameln
Date: Fri, 20 Dec 2024 13:10:56 +0100
Subject: [PATCH 18/22] add unit tests for updating augmentations in data
 module

---
 .../datamodule/test_update_augmentations.py   | 122 ++++++++++++++++++
 1 file changed, 122 insertions(+)
 create mode 100644 tests/unit/data/datamodule/test_update_augmentations.py

diff --git a/tests/unit/data/datamodule/test_update_augmentations.py b/tests/unit/data/datamodule/test_update_augmentations.py
new file mode 100644
index 0000000000..fb9ed3e206
--- /dev/null
+++ b/tests/unit/data/datamodule/test_update_augmentations.py
@@ -0,0 +1,122 @@
+"""Tests for the _update_subset_augmentations method in AnomalibDataModule."""
+
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+import logging
+
+import pytest
+from torchvision.transforms.v2 import (
+    Compose,
+    InterpolationMode,
+    Normalize,
+    RandomHorizontalFlip,
+    RandomVerticalFlip,
+    Resize,
+)
+
+from anomalib.data import AnomalibDataModule, AnomalibDataset
+
+
+class DummyDataset(AnomalibDataset):
+    """Dummy dataset class for testing."""
+
+    def __init__(self) -> None:
+        pass
+
+
+class TestUpdateAugmentations:
+    """Tests for the _update_subset_augmentations method in AnomalibDataModule."""
+
+    @staticmethod
+    def test_conflicting_shape(caplog: pytest.LogCaptureFixture) -> None:
+        """Test that a warning is logged if resize shapes mismatch."""
+        dataset = DummyDataset()
+        model_transform = Resize((224, 224))
+        augmentations = Resize((256, 256))
+
+        with caplog.at_level(logging.WARNING):
+            AnomalibDataModule._update_subset_augmentations(dataset, augmentations, model_transform=model_transform)  # noqa: SLF001
+
+        # check if a warning was logged
+        assert any(record.levelname == "WARNING" for record in caplog.records)
+        assert "Conflicting resize shape" in caplog.text
+
+        # check if augmentations were overwritten by model transform
+        assert dataset.augmentations == model_transform
+
+    @staticmethod
+    def test_conflicting_interpolation(caplog: pytest.LogCaptureFixture) -> None:
+        """Test that a warning is logged if interpolation methods mismatch."""
+        dataset = DummyDataset()
+        model_transform = Resize((224, 224), interpolation=InterpolationMode.BILINEAR)
+        augmentations = Resize((224, 224), interpolation=InterpolationMode.NEAREST)
+
+        with caplog.at_level(logging.WARNING):
+            AnomalibDataModule._update_subset_augmentations(dataset, augmentations, model_transform=model_transform)  # noqa: SLF001
+
+        # check if a warning was logged
+        assert any(record.levelname == "WARNING" for record in caplog.records)
+        assert "Conflicting interpolation method" in caplog.text
+
+        # check if augmentations were overwritten by model transform
+        assert dataset.augmentations == model_transform
+
+    @staticmethod
+    def test_conflicting_antialias(caplog: pytest.LogCaptureFixture) -> None:
+        """Test that a warning is logged if antialiasing settings mismatch."""
+        dataset = DummyDataset()
+        model_transform = Resize((224, 224), antialias=True)
+        augmentations = Resize((224, 224), antialias=False)
+
+        with caplog.at_level(logging.WARNING):
+            AnomalibDataModule._update_subset_augmentations(dataset, augmentations, model_transform=model_transform)  # noqa: SLF001
+
+        # check if a warning was logged
+        assert any(record.levelname == "WARNING" for record in caplog.records)
+        assert "Conflicting antialiasing setting" in caplog.text
+
+        # check if augmentations were overwritten by model transform
+        assert dataset.augmentations == model_transform
+
+    @staticmethod
+    def test_augmentations_as_compose() -> None:
+        """Test that the Resize transform is added to the augmentations if augmentations is a Compose object."""
+        dataset = DummyDataset()
+        model_transform = Resize((224, 224))
+        augmentations = Compose([RandomHorizontalFlip(), RandomVerticalFlip()])
+
+        AnomalibDataModule._update_subset_augmentations(dataset, augmentations, model_transform=model_transform)  # noqa: SLF001
+        assert dataset.augmentations.transforms[-1] == model_transform
+
+    @staticmethod
+    def test_augmentations_as_transform() -> None:
+        """Test that the Resize transform is added to the augmentations if augmentations is a single transform."""
+        dataset = DummyDataset()
+        model_transform = Resize((224, 224))
+        augmentations = RandomHorizontalFlip()
+
+        AnomalibDataModule._update_subset_augmentations(dataset, augmentations, model_transform=model_transform)  # noqa: SLF001
+        assert dataset.augmentations.transforms[-1] == model_transform
+
+    @staticmethod
+    def test_model_transform_as_compose() -> None:
+        """Test that the Resize transform is added to the augmentations if model_transform is a Compose object."""
+        dataset = DummyDataset()
+        model_transform = Compose([Resize((224, 224)), Normalize(mean=[0.5, 0.5, 0.5], std=[1.0, 1.0, 1.0])])
+        augmentations = Compose([RandomHorizontalFlip(), RandomVerticalFlip()])
+
+        AnomalibDataModule._update_subset_augmentations(dataset, augmentations, model_transform=model_transform)  # noqa: SLF001
+        assert dataset.augmentations.transforms[-1] == model_transform.transforms[0]
+
+    @staticmethod
+    def test_no_model_transforms() -> None:
+        """Test that the augmentations are added but not modified if model_transform is None."""
+        dataset = DummyDataset()
+        augmentations = Compose([RandomHorizontalFlip(), RandomVerticalFlip()])
+
+        AnomalibDataModule._update_subset_augmentations(dataset, augmentations, model_transform=None)  # noqa: SLF001
+        assert dataset.augmentations == augmentations
+
+    @staticmethod
+    def test_no_augmentations() -> None:
+        """Test that the model_transform resize is added to the augmentations if augmentations is None."""
+        dataset = DummyDataset()
+        model_transform = Resize((224, 224))
+
+        AnomalibDataModule._update_subset_augmentations(dataset, augmentations=None, model_transform=model_transform)  # noqa: SLF001
+        assert dataset.augmentations == model_transform

From d85e25d6e5c2b278eaa656370f5e0561cb3e4210 Mon Sep 17 00:00:00 2001
From: Dick Ameln
Date: Fri, 20 Dec 2024 13:36:36 +0100
Subject: [PATCH 19/22] add test for collate method

---
 tests/unit/data/dataclasses/test_collate.py | 45 +++++++++++++++++++++
 1 file changed, 45 insertions(+)
 create mode 100644 tests/unit/data/dataclasses/test_collate.py

diff --git a/tests/unit/data/dataclasses/test_collate.py b/tests/unit/data/dataclasses/test_collate.py
new file mode 100644
index 0000000000..aaf83c87ff
--- /dev/null
+++ b/tests/unit/data/dataclasses/test_collate.py
@@ -0,0 +1,45 @@
+"""Tests for collating DatasetItems into Batches."""
+
+# Copyright (C) 2024 Intel Corporation
+# 
SPDX-License-Identifier: Apache-2.0 + +from dataclasses import dataclass + +import torch +from torchvision.tv_tensors import Image, Mask + +from anomalib.data.dataclasses.generic import BatchIterateMixin + + +@dataclass +class DummyDatasetItem: + """Dummy dataset item with image and mask.""" + + image: Image + mask: Mask + + +@dataclass +class DummyBatch(BatchIterateMixin[DummyDatasetItem]): + """Dummy batch with image and mask.""" + + item_class = DummyDatasetItem + image: Image + mask: Mask + + +def test_collate_heterogeneous_shapes() -> None: + """Test collating items with different shapes.""" + items = [ + DummyDatasetItem( + image=Image(torch.rand((3, 256, 256))), + mask=Mask(torch.ones((256, 256))), + ), + DummyDatasetItem( + image=Image(torch.rand((3, 224, 224))), + mask=Mask(torch.ones((224, 224))), + ), + ] + batch = DummyBatch.collate(items) + # the collated batch should have the shape of the largest item + assert batch.image.shape == (2, 3, 256, 256) From 03ce13479a5a23e5aa639a6949ab20c4042ba8d7 Mon Sep 17 00:00:00 2001 From: Dick Ameln Date: Fri, 27 Dec 2024 12:22:33 +0100 Subject: [PATCH 20/22] copy transform before converting --- src/anomalib/pre_processing/utils/transform.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/anomalib/pre_processing/utils/transform.py b/src/anomalib/pre_processing/utils/transform.py index bc6ecd97ee..3b6918abb5 100644 --- a/src/anomalib/pre_processing/utils/transform.py +++ b/src/anomalib/pre_processing/utils/transform.py @@ -7,6 +7,8 @@ # Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 +import copy + from torchvision.transforms.v2 import CenterCrop, Compose, Resize, Transform from anomalib.data.transforms import ExportableCenterCrop @@ -45,6 +47,7 @@ def get_exportable_transform(transform: Transform | None) -> Transform | None: """ if transform is None: return None + transform = copy.deepcopy(transform) transform = disable_antialiasing(transform) return convert_center_crop_transform(transform) From 15116217af2995eccd16521ad3c7bd43cba311fb Mon Sep 17 00:00:00 2001 From: Dick Ameln Date: Fri, 3 Jan 2025 18:33:51 +0100 Subject: [PATCH 21/22] update docstring --- src/anomalib/utils/attrs.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/src/anomalib/utils/attrs.py b/src/anomalib/utils/attrs.py index 35f3971485..af53ca01e3 100644 --- a/src/anomalib/utils/attrs.py +++ b/src/anomalib/utils/attrs.py @@ -9,6 +9,9 @@ def get_nested_attr(obj: Any, attr_path: str, default: Any | None = None) -> Any: # noqa: ANN401 """Safely retrieves a nested attribute from an object. + This function helps reduce boilerplate code when working with nested attributes, by allowing you to retrieve a + nested attribute with a single function call instead of multiple nested calls to `getattr`. + Args: obj: The object to retrieve the attribute from. attr_path: A dot-separated string representing the attribute path. @@ -16,6 +19,31 @@ def get_nested_attr(obj: Any, attr_path: str, default: Any | None = None) -> Any Returns: The value of the nested attribute, or `default` if any attribute in the path is missing. + + Example: + >>> class A: + ... def __init__(self, b): + ... self.b = b + >>> + >>> class B: + ... def __init__(self, c): + ... self.c = c + >>> + >>> class C: + ... def __init__(self, d): + ... 
self.d = d
+    >>>
+    >>> d = 42
+    >>> c = C(d)
+    >>> b = B(c)
+    >>> a = A(b)
+    >>> get_nested_attr(a, "b.c.d")  # 42
+    >>> # this is equivalent to:
+    >>> # getattr(getattr(getattr(a, "b", None), "c", None), "d", None)
+    >>>
+    >>> get_nested_attr(a, "b.c.foo")  # None
+    >>> get_nested_attr(a, "b.c.foo", "bar")  # "bar"
+    >>> get_nested_attr(a, "b.d.c")  # None
     """
     for attr in attr_path.split("."):
         obj = getattr(obj, attr, default)

From bebe1ea819d8b8bd2d67ae7c1cbb4e15116217af Mon Sep 17 00:00:00 2001
From: Dick Ameln
Date: Fri, 3 Jan 2025 18:40:24 +0100
Subject: [PATCH 22/22] rename function

---
 src/anomalib/data/datamodules/base/image.py               | 6 +++---
 src/anomalib/data/transforms/utils.py                     | 7 +++++--
 src/anomalib/models/image/efficient_ad/lightning_model.py | 4 ++--
 3 files changed, 10 insertions(+), 7 deletions(-)

diff --git a/src/anomalib/data/datamodules/base/image.py b/src/anomalib/data/datamodules/base/image.py
index 0c57d90984..a2e163a3bd 100644
--- a/src/anomalib/data/datamodules/base/image.py
+++ b/src/anomalib/data/datamodules/base/image.py
@@ -39,7 +39,7 @@
 
 from anomalib import TaskType
 from anomalib.data.datasets.base.image import AnomalibDataset
-from anomalib.data.transforms.utils import get_transforms_of_type
+from anomalib.data.transforms.utils import extract_transforms_by_type
 from anomalib.data.utils import TestSplitMode, ValSplitMode, random_split, split_by_label
 from anomalib.data.utils.synthetic import SyntheticAnomalyDataset
 from anomalib.utils.attrs import get_nested_attr
@@ -182,11 +182,11 @@ def _update_subset_augmentations(
             augmentations (Transform): Augmentations to apply to the dataset.
             model_transform (Transform): Transform object from the model PreProcessor.
         """
-        model_resizes = get_transforms_of_type(model_transform, Resize)
+        model_resizes = extract_transforms_by_type(model_transform, Resize)
 
         if model_resizes:
             model_resize = model_resizes[0]
-            for aug_resize in get_transforms_of_type(augmentations, Resize):  # warn user if resizes inconsistent
+            for aug_resize in extract_transforms_by_type(augmentations, Resize):  # warn user if resizes inconsistent
                 if model_resize.size != aug_resize.size:
                     msg = f"Conflicting resize shapes found between augmentations and model transforms. You are using \
                         a Resize transform in your input data augmentations. Please be aware that the model also \
diff --git a/src/anomalib/data/transforms/utils.py b/src/anomalib/data/transforms/utils.py
index b1e21687a5..5ef1e9b0ec 100644
--- a/src/anomalib/data/transforms/utils.py
+++ b/src/anomalib/data/transforms/utils.py
@@ -6,8 +6,11 @@
 from torchvision.transforms.v2 import Compose, Transform
 
 
-def get_transforms_of_type(input_transform: Transform | None, transform_type: type[Transform]) -> list[Transform]:
-    """Retrieves all transforms of a given type from a transform or transform composition.
+def extract_transforms_by_type(
+    input_transform: Transform | None,
+    transform_type: type[Transform],
+) -> list[Transform]:
+    """Extracts all transforms of a given type from a transform or transform composition.
 
     Args:
         input_transform (Transform): Torchvision Transform instance.
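To illustrate the behaviour of the renamed helper, here is a minimal usage sketch. It is illustrative only and not part of the patch series; it assumes a checkout with this series applied, and the variable names are invented for the example:

    from torchvision.transforms.v2 import CenterCrop, Compose, Normalize, Resize

    from anomalib.data.transforms.utils import extract_transforms_by_type

    # A composition containing a single Resize among other transforms
    transform = Compose([
        Resize((256, 256)),
        CenterCrop((224, 224)),
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    # Collects the direct children of the Compose that match the requested type
    assert len(extract_transforms_by_type(transform, Resize)) == 1

    # A bare matching transform is wrapped in a single-element list, while
    # None or a non-matching transform yields an empty list
    assert len(extract_transforms_by_type(Resize((64, 64)), Resize)) == 1
    assert extract_transforms_by_type(None, Resize) == []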
diff --git a/src/anomalib/models/image/efficient_ad/lightning_model.py b/src/anomalib/models/image/efficient_ad/lightning_model.py index 8f18245d8f..5a0c6c5ee3 100644 --- a/src/anomalib/models/image/efficient_ad/lightning_model.py +++ b/src/anomalib/models/image/efficient_ad/lightning_model.py @@ -49,7 +49,7 @@ from anomalib import LearningType from anomalib.data import Batch -from anomalib.data.transforms.utils import get_transforms_of_type +from anomalib.data.transforms.utils import extract_transforms_by_type from anomalib.data.utils import DownloadInfo, download_and_extract from anomalib.metrics import Evaluator from anomalib.models.components import AnomalibModule @@ -366,7 +366,7 @@ def on_train_start(self) -> None: msg = "train_batch_size for EfficientAd should be 1." raise ValueError(msg) - if self.pre_processor and get_transforms_of_type(self.pre_processor.transform, Normalize): + if self.pre_processor and extract_transforms_by_type(self.pre_processor.transform, Normalize): msg = "Transforms for EfficientAd should not contain Normalize." raise ValueError(msg)
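To see what the updated EfficientAd validation catches, a short sketch follows. It is illustrative only; it assumes this series is applied and that EfficientAd accepts a pre_processor argument, as the PreProcessor docstring above shows for models in general:

    from torchvision.transforms.v2 import Compose, Normalize, Resize

    from anomalib.models import EfficientAd
    from anomalib.pre_processing import PreProcessor

    # The check in on_train_start now finds the Normalize even when it is
    # nested inside a Compose, via extract_transforms_by_type.
    transform = Compose([
        Resize((256, 256)),
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    model = EfficientAd(pre_processor=PreProcessor(transform=transform))
    # Training this model, e.g. engine.fit(model, datamodule=...), would raise:
    #   ValueError: Transforms for EfficientAd should not contain Normalize.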