diff --git a/README.md b/README.md
index 7a04992..b79143b 100644
--- a/README.md
+++ b/README.md
@@ -91,7 +91,7 @@ Neural networks come from different sources. With `thingsvision`, you can extrac
 ### :computer: Setting up your environment

-#### Working locally.
+#### Working locally
 First, create a new `conda environment` with Python version 3.8, 3.9, or 3.10 e.g. by using `conda`:

 ```bash
@@ -121,7 +121,7 @@ $ pip install dreamsim==0.1.2
 See the [docs](https://vicco-group.github.io/thingsvision/AvailableModels.html#dreamsim) for which `DreamSim` models are available in `thingsvision`.

-#### Google Colab.
+#### Google Colab
 Alternatively, you can use Google Colab to play around with `thingsvision` by uploading your image data to Google Drive (via directory mounting). You can find the jupyter notebook using `PyTorch` [here](https://colab.research.google.com/github/ViCCo-Group/thingsvision/blob/master/notebooks/pytorch.ipynb) and the `TensorFlow` example [here](https://colab.research.google.com/github/ViCCo-Group/thingsvision/blob/master/notebooks/tensorflow.ipynb).

 (back to top)
@@ -207,6 +207,48 @@ features = extractor.extract_features(
 save_features(features, out_path='path/to/features', file_format='npy') # file_format can be set to "npy", "txt", "mat", "pt", or "hdf5"
 ```
+#### Feature extraction with custom data pipeline
+
+##### PyTorch
+
+```python
+module_name = 'visual'
+
+# your custom dataset and dataloader classes come here (for example, a PyTorch data loader)
+my_dataset = ...
+my_dataloader = ...
+
+with extractor.batch_extraction(module_name, output_type="tensor") as e:
+    for batch in my_dataloader:
+        ... # whatever preprocessing you want to add to the batch
+        feature_batch = e.extract_batch(
+            batch=batch,
+            flatten_acts=True, # flatten 2D feature maps from an early convolutional or attention layer
+        )
+        ... # whatever post-processing you want to add to the extracted features
+```
+
+##### TensorFlow / Keras
+
+```python
+module_name = 'visual'
+
+# your custom dataset and dataloader classes come here (for example, TFRecords files)
+my_dataset = ...
+my_dataloader = ...
+
+for batch in my_dataloader:
+    ... # whatever preprocessing you want to add to the batch
+    feature_batch = extractor.extract_batch(
+        batch=batch,
+        module_name=module_name,
+        flatten_acts=True, # flatten 2D feature maps from an early convolutional or attention layer
+    )
+    ... # whatever post-processing you want to add to the extracted features
+```
+
+#### Human alignment
+
 *Human alignment*: If you want to align the extracted features with human object similarity according to the approach introduced in *[Improving neural network representations using human similarity judgments](https://proceedings.neurips.cc/paper_files/paper/2023/hash/9febda1c8344cc5f2d51713964864e93-Abstract-Conference.html)* you can optionally `align` the extracted features using the following method:

 ```python
diff --git a/docs/GettingStarted.md b/docs/GettingStarted.md
index 393e813..e36f66e 100644
--- a/docs/GettingStarted.md
+++ b/docs/GettingStarted.md
@@ -5,7 +5,8 @@ nav_order: 2
 # Getting started

 ## Setting up your environment
-### Working locally.
+
+### Working locally
 First, create a new `conda environment` with Python version 3.8, 3.9, or 3.10 e.g. by using `conda`:

 ```bash
@@ -107,8 +108,9 @@ features = extractor.extract_features(
 save_features(features, out_path='path/to/features', file_format='npy') # file_format can be set to "npy", "txt", "mat", "pt", or "hdf5"
 ```
-### Extraction with custom data pipeline and training loop
+### Extraction with custom data pipeline

+#### PyTorch
 ```python
 module_name = 'visual'
@@ -117,14 +119,31 @@ module_name = 'visual'
 my_dataset = ...
 my_dataloader = ...

-# your custom training loop comes here
+with extractor.batch_extraction(module_name, output_type="tensor") as e:
+    for batch in my_dataloader:
+        ... # whatever preprocessing you want to add to the batch
+        feature_batch = e.extract_batch(
+            batch=batch,
+            flatten_acts=True, # flatten 2D feature maps from an early convolutional or attention layer
+        )
+        ... # whatever post-processing you want to add to the extracted features
+```
+
+#### TensorFlow / Keras
+
+```python
+module_name = 'visual'
+
+# your custom dataset and dataloader classes come here (for example, TFRecords files)
+my_dataset = ...
+my_dataloader = ...
+
 for batch in my_dataloader:
     ... # whatever preprocessing you want to add to the batch
     feature_batch = extractor.extract_batch(
         batch=batch,
         module_name=module_name,
         flatten_acts=True, # flatten 2D feature maps from an early convolutional or attention layer
-        output_type="tensor", # optionally set the output type of the feature matrix
     )
     ... # whatever post-processing you want to add to the extracted features
 ```
diff --git a/tests/extractor/extraction/test_torch_vs_tensorflow.py b/tests/extractor/extraction/test_torch_vs_tensorflow.py
index 41bc732..ee3c0f4 100644
--- a/tests/extractor/extraction/test_torch_vs_tensorflow.py
+++ b/tests/extractor/extraction/test_torch_vs_tensorflow.py
@@ -1,11 +1,11 @@
 import unittest

-import torch
 import numpy as np
+import torch
+
 import tests.helper as helper
-from thingsvision.utils.data import DataLoader
 import thingsvision.core.extraction.helpers as core_helpers
-
+from thingsvision.utils.data import DataLoader


 class ExtractionPTvsTFTestCase(unittest.TestCase):
@@ -45,8 +45,8 @@ def test_custom_torch_vs_tf_extraction(self):
         pt_model.backend = pt_backend

         layer_name = "relu"
-        expected_features_pt = torch.tensor([[2., 2.], [0., 0.]])
-        expected_features_tf = np.array([[2., 2.], [0, 0.]])
+        expected_features_tf = np.array([[2.0, 2.0], [0, 0.0]])
+        expected_features_pt = torch.tensor([[2.0, 2.0], [0.0, 0.0]])

         for i, batch in enumerate(tf_dl):
             tf_features = tf_model.extract_batch(
@@ -54,33 +54,32 @@ def test_custom_torch_vs_tf_extraction(self):
                 module_name=layer_name,
                 flatten_acts=False,
             )
-            np.testing.assert_allclose(tf_features, expected_features_tf[i][None,:])
-
-        for i, batch in enumerate(pt_dl):
-            pt_features = pt_model.extract_batch(
-                batch=batch,
-                module_name=layer_name,
-                flatten_acts=False,
-                output_type="tensor",
-            )
-            np.testing.assert_allclose(pt_features, expected_features_pt[i][None,:])
+            expected_features = expected_features_tf[i][None, :]
+            np.testing.assert_allclose(tf_features, expected_features)
+
+        with pt_model.batch_extraction(layer_name, output_type="tensor") as e:
+            for i, batch in enumerate(pt_dl):
+                pt_features = e.extract_batch(
+                    batch=batch,
+                    flatten_acts=False,
+                )
+                expected_features = expected_features_pt[i][None, :]
+                np.testing.assert_allclose(pt_features, expected_features)

         layer_name = "relu2"
-        expected_features = np.array([[4., 4.], [0., 0.]])
+        expected_features = np.array([[4.0, 4.0], [0.0, 0.0]])

         for i, batch in enumerate(tf_dl):
             tf_features = tf_model.extract_batch(
                 batch=batch,
                 module_name=layer_name,
                 flatten_acts=False,
             )
-            np.testing.assert_allclose(tf_features, expected_features[i][None,:])
-
-        for i, batch in enumerate(pt_dl):
-            pt_features = pt_model.extract_batch(
-                batch=batch,
-                module_name=layer_name,
-                flatten_acts=False,
-                output_type="ndarray",
-            )
-            np.testing.assert_allclose(pt_features, expected_features[i][None,:])
-
+            np.testing.assert_allclose(tf_features, expected_features[i][None, :])
+
+        with pt_model.batch_extraction(layer_name, output_type="ndarray") as e:
+            for i, batch in enumerate(pt_dl):
+                pt_features = e.extract_batch(
+                    batch=batch,
+                    flatten_acts=False,
+                )
+                np.testing.assert_allclose(pt_features, expected_features[i][None, :])
diff --git a/thingsvision/_version.py b/thingsvision/_version.py
index 7a2056f..667b52f 100644
--- a/thingsvision/_version.py
+++ b/thingsvision/_version.py
@@ -1 +1 @@
-__version__ = "2.5.1"
+__version__ = "2.5.2"
diff --git a/thingsvision/core/extraction/torch.py b/thingsvision/core/extraction/torch.py
index 2f4e183..1da012e 100644
--- a/thingsvision/core/extraction/torch.py
+++ b/thingsvision/core/extraction/torch.py
@@ -57,7 +57,7 @@ def hook(model, input, output) -> None:
         return hook

-    def register_hook(self, module_name: str) -> None:
+    def _register_hook(self, module_name: str) -> None:
         """Register a forward hook to store activations."""
         for n, m in self.model.named_modules():
             if n == module_name:
@@ -67,24 +67,26 @@ def register_hook(self, module_name: str) -> None:
     def _unregister_hook(self) -> None:
         self.hook_handle.remove()

+    def batch_extraction(self, module_name: str, output_type: str) -> object:
+        return BatchExtraction(
+            extractor=self, module_name=module_name, output_type=output_type
+        )
+
     def extract_batch(
         self,
         batch: TensorType["b", "c", "h", "w"],
-        module_name: str,
         flatten_acts: bool,
-        output_type: str = "tensor",
     ) -> Union[
         TensorType["b", "num_maps", "h_prime", "w_prime"],
         TensorType["b", "t", "d"],
         TensorType["b", "p"],
         TensorType["b", "d"],
     ]:
-        self._module_and_output_check(module_name, output_type)
-        self.register_hook(module_name=module_name)
-        act = self._extract_batch(batch, module_name, flatten_acts)
-        if output_type == "ndarray":
+        act = self._extract_batch(
+            batch=batch, module_name=self.module_name, flatten_acts=flatten_acts
+        )
+        if self.output_type == "ndarray":
             act = self._to_numpy(act)
-        self._unregister_hook()
         return act

     @torch.no_grad()
@@ -128,7 +130,7 @@ def extract_features(
     ):
         self.model = self.model.to(self.device)
         self.activations = {}
-        self.register_hook(module_name=module_name)
+        self._register_hook(module_name=module_name)
         features = super().extract_features(
             batches=batches,
             module_name=module_name,
@@ -230,3 +232,25 @@ def get_default_transformation(

     def get_backend(self) -> str:
         return "pt"
+
+
+class BatchExtraction(object):
+
+    def __init__(
+        self, extractor: PyTorchExtractor, module_name: str, output_type: str
+    ) -> None:
+        self.extractor = extractor
+        self.module_name = module_name
+        self.output_type = output_type
+
+    def __enter__(self) -> PyTorchExtractor:
+        self.extractor._module_and_output_check(self.module_name, self.output_type)
+        self.extractor._register_hook(self.module_name)
+        setattr(self.extractor, "module_name", self.module_name)
+        setattr(self.extractor, "output_type", self.output_type)
+        return self.extractor
+
+    def __exit__(self, *args):
+        self.extractor._unregister_hook()
+        delattr(self.extractor, "module_name")
+        delattr(self.extractor, "output_type")
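
Taken together, the `torch.py` changes replace the per-call hook management in `extract_batch` with a scoped `batch_extraction` context: the forward hook is registered once on `__enter__`, reused for every batch, and removed on `__exit__`. Below is a minimal end-to-end sketch of the new PyTorch usage. It assumes `get_extractor` as the usual `thingsvision` entry point; `my_dataset` and the AlexNet layer `"features.10"` are illustrative placeholders, not part of this diff.

```python
# Minimal usage sketch of the new context-manager API (thingsvision >= 2.5.2).
# ASSUMPTIONS: `my_dataset` is a hypothetical map-style dataset that yields
# already-preprocessed image tensors; any custom PyTorch DataLoader works here.
import torch
from torch.utils.data import DataLoader
from thingsvision import get_extractor

extractor = get_extractor(
    model_name="alexnet",
    source="torchvision",
    device="cpu",
    pretrained=True,
)
my_dataloader = DataLoader(my_dataset, batch_size=32)  # hypothetical dataset

features = []
# The hook is registered once on __enter__ and removed on __exit__,
# instead of being registered and unregistered for every single batch.
with extractor.batch_extraction("features.10", output_type="tensor") as e:
    for batch in my_dataloader:
        feature_batch = e.extract_batch(batch=batch, flatten_acts=True)
        features.append(feature_batch)
features = torch.cat(features, dim=0)
```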
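
One design note on `BatchExtraction`: because the cleanup lives in `__exit__`, the forward hook cannot leak even if preprocessing or extraction raises inside the `with` block, and exceptions still propagate since `__exit__` returns a falsy value. The same lifecycle could be expressed with `contextlib`; the sketch below only restates that contract for clarity and is not the implementation added in this diff.

```python
# Illustrative contextlib equivalent of BatchExtraction -- NOT the class-based
# implementation in this diff, just the same enter/exit lifecycle made explicit.
from contextlib import contextmanager

@contextmanager
def batch_extraction(extractor, module_name: str, output_type: str):
    # on entry: validate arguments, attach the forward hook, pin the state
    extractor._module_and_output_check(module_name, output_type)
    extractor._register_hook(module_name)
    extractor.module_name = module_name
    extractor.output_type = output_type
    try:
        yield extractor
    finally:
        # on exit: the hook must not outlive the block, or every later
        # forward pass would keep writing into the activations store
        extractor._unregister_hook()
        del extractor.module_name
        del extractor.output_type
```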