Skip to content

Commit

Permalink
Merge pull request #160 from ViCCo-Group/with_batch_extraction
Browse files Browse the repository at this point in the history
implemented mini-batch extraction with with statement for PyTorch
  • Loading branch information
LukasMut authored Apr 5, 2024
2 parents 4ceafff + 2bf4edb commit 9e6344b
Show file tree
Hide file tree
Showing 5 changed files with 127 additions and 43 deletions.
46 changes: 44 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ Neural networks come from different sources. With `thingsvision`, you can extrac

<!-- Setting up your environment -->
### :computer: Setting up your environment
#### Working locally.
#### Working locally
First, create a new virtual environment with Python version 3.8, 3.9, or 3.10, e.g., by using `conda`:

```bash
Expand Down Expand Up @@ -121,7 +121,7 @@ $ pip install dreamsim==0.1.2

See the [docs](https://vicco-group.github.io/thingsvision/AvailableModels.html#dreamsim) for which `DreamSim` models are available in `thingsvision`.

#### Google Colab.
#### Google Colab
Alternatively, you can use Google Colab to play around with `thingsvision` by uploading your image data to Google Drive (via directory mounting).
You can find the jupyter notebook using `PyTorch` [here](https://colab.research.google.com/github/ViCCo-Group/thingsvision/blob/master/notebooks/pytorch.ipynb) and the `TensorFlow` example [here](https://colab.research.google.com/github/ViCCo-Group/thingsvision/blob/master/notebooks/tensorflow.ipynb).
<p align="right">(<a href="#readme-top">back to top</a>)</p>
Expand Down Expand Up @@ -207,6 +207,48 @@ features = extractor.extract_features(
save_features(features, out_path='path/to/features', file_format='npy') # file_format can be set to "npy", "txt", "mat", "pt", or "hdf5"
```

#### Feature extraction with custom data pipeline

##### PyTorch

```python
module_name = 'visual'

# your custom dataset and dataloader classes come here (for example, a PyTorch data loader)
my_dataset = ...
my_dataloader = ...

with extractor.batch_extraction(module_name, output_type="tensor") as e:
for batch in my_dataloader:
... # whatever preprocessing you want to add to the batch
feature_batch = e.extract_batch(
batch=batch,
flatten_acts=True, # flatten 2D feature maps from an early convolutional or attention layer
)
... # whatever post-processing you want to add to the extracted features
```

##### TensorFlow / Keras

```python
module_name = 'visual'

# your custom dataset and dataloader classes come here (for example, TFRecords files)
my_dataset = ...
my_dataloader = ...

for batch in my_dataloader:
... # whatever preprocessing you want to add to the batch
feature_batch = extractor.extract_batch(
batch=batch,
module_name=module_name,
flatten_acts=True, # flatten 2D feature maps from an early convolutional or attention layer
)
... # whatever post-processing you want to add to the extracted features
```

#### Human alignment

*Human alignment*: If you want to align the extracted features with human object similarity according to the approach introduced in *[Improving neural network representations using human similarity judgments](https://proceedings.neurips.cc/paper_files/paper/2023/hash/9febda1c8344cc5f2d51713964864e93-Abstract-Conference.html)*, you can optionally `align` the extracted features using the following method:

```python
Expand Down
27 changes: 23 additions & 4 deletions docs/GettingStarted.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@ nav_order: 2
# Getting started

## Setting up your environment
### Working locally.

### Working locally
First, create a new virtual environment with Python version 3.8, 3.9, or 3.10, e.g., by using `conda`:

```bash
Expand Down Expand Up @@ -107,8 +108,9 @@ features = extractor.extract_features(
save_features(features, out_path='path/to/features', file_format='npy') # file_format can be set to "npy", "txt", "mat", "pt", or "hdf5"
```

### Extraction with custom data pipeline and training loop
### Extraction with custom data pipeline

#### PyTorch

```python
module_name = 'visual'
Expand All @@ -117,14 +119,31 @@ module_name = 'visual'
my_dataset = ...
my_dataloader = ...

# your custom training loop comes here
with extractor.batch_extraction(module_name, output_type="tensor") as e:
for batch in my_dataloader:
... # whatever preprocessing you want to add to the batch
feature_batch = e.extract_batch(
batch=batch,
flatten_acts=True, # flatten 2D feature maps from an early convolutional or attention layer
)
... # whatever post-processing you want to add to the extracted features
```

#### TensorFlow / Keras

```python
module_name = 'visual'

# your custom dataset and dataloader classes come here (for example, TFRecords files)
my_dataset = ...
my_dataloader = ...

for batch in my_dataloader:
... # whatever preprocessing you want to add to the batch
feature_batch = extractor.extract_batch(
batch=batch,
module_name=module_name,
flatten_acts=True, # flatten 2D feature maps from an early convolutional or attention layer
output_type="tensor", # optionally set the output type of the feature matrix
)
... # whatever post-processing you want to add to the extracted features
```
Expand Down
53 changes: 26 additions & 27 deletions tests/extractor/extraction/test_torch_vs_tensorflow.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import unittest
import torch

import numpy as np
import torch

import tests.helper as helper
from thingsvision.utils.data import DataLoader
import thingsvision.core.extraction.helpers as core_helpers

from thingsvision.utils.data import DataLoader


class ExtractionPTvsTFTestCase(unittest.TestCase):
Expand Down Expand Up @@ -45,42 +45,41 @@ def test_custom_torch_vs_tf_extraction(self):
pt_model.backend = pt_backend

layer_name = "relu"
expected_features_pt = torch.tensor([[2., 2.], [0., 0.]])
expected_features_tf = np.array([[2., 2.], [0, 0.]])
expected_features_tf = np.array([[2.0, 2.0], [0, 0.0]])
expected_features_pt = torch.tensor([[2.0, 2.0], [0.0, 0.0]])

for i, batch in enumerate(tf_dl):
tf_features = tf_model.extract_batch(
batch=batch,
module_name=layer_name,
flatten_acts=False,
)
np.testing.assert_allclose(tf_features, expected_features_tf[i][None,:])

for i, batch in enumerate(pt_dl):
pt_features = pt_model.extract_batch(
batch=batch,
module_name=layer_name,
flatten_acts=False,
output_type="tensor",
)
np.testing.assert_allclose(pt_features, expected_features_pt[i][None,:])
expected_features = expected_features_tf[i][None, :]
np.testing.assert_allclose(tf_features, expected_features)

with pt_model.batch_extraction(layer_name, output_type="tensor") as e:
for i, batch in enumerate(pt_dl):
pt_features = e.extract_batch(
batch=batch,
flatten_acts=False,
)
expected_features = expected_features_pt[i][None, :]
np.testing.assert_allclose(pt_features, expected_features)

layer_name = "relu2"
expected_features = np.array([[4., 4.], [0., 0.]])
expected_features = np.array([[4.0, 4.0], [0.0, 0.0]])
for i, batch in enumerate(tf_dl):
tf_features = tf_model.extract_batch(
batch=batch,
module_name=layer_name,
flatten_acts=False,
)
np.testing.assert_allclose(tf_features, expected_features[i][None,:])

for i, batch in enumerate(pt_dl):
pt_features = pt_model.extract_batch(
batch=batch,
module_name=layer_name,
flatten_acts=False,
output_type="ndarray",
)
np.testing.assert_allclose(pt_features, expected_features[i][None,:])

np.testing.assert_allclose(tf_features, expected_features[i][None, :])

with pt_model.batch_extraction(layer_name, output_type="ndarray") as e:
for i, batch in enumerate(pt_dl):
pt_features = e.extract_batch(
batch=batch,
flatten_acts=False,
)
np.testing.assert_allclose(pt_features, expected_features[i][None, :])
2 changes: 1 addition & 1 deletion thingsvision/_version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "2.5.1"
__version__ = "2.5.2"
42 changes: 33 additions & 9 deletions thingsvision/core/extraction/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def hook(model, input, output) -> None:

return hook

def register_hook(self, module_name: str) -> None:
def _register_hook(self, module_name: str) -> None:
"""Register a forward hook to store activations."""
for n, m in self.model.named_modules():
if n == module_name:
Expand All @@ -67,24 +67,26 @@ def register_hook(self, module_name: str) -> None:
def _unregister_hook(self) -> None:
    """Remove the hook referenced by ``self.hook_handle``.

    NOTE(review): ``hook_handle`` is presumably assigned when the forward
    hook is registered; that assignment is not visible here -- confirm.
    """
    self.hook_handle.remove()

def batch_extraction(self, module_name: str, output_type: str) -> "BatchExtraction":
    """Create a context manager for extracting features batch by batch.

    Intended to be used in a ``with`` statement, e.g.
    ``with extractor.batch_extraction(name, output_type="tensor") as e:``.
    Entering the returned ``BatchExtraction`` validates the arguments,
    registers the hook for ``module_name`` and yields this extractor;
    leaving it removes the hook again.

    Parameters
    ----------
    module_name : str
        Name of the module whose activations should be extracted.
    output_type : str
        Format of the extracted features (e.g. ``"tensor"`` or ``"ndarray"``).

    Returns
    -------
    BatchExtraction
        Context manager bound to this extractor.
    """
    return BatchExtraction(
        extractor=self, module_name=module_name, output_type=output_type
    )

def extract_batch(
self,
batch: TensorType["b", "c", "h", "w"],
module_name: str,
flatten_acts: bool,
output_type: str = "tensor",
) -> Union[
TensorType["b", "num_maps", "h_prime", "w_prime"],
TensorType["b", "t", "d"],
TensorType["b", "p"],
TensorType["b", "d"],
]:
self._module_and_output_check(module_name, output_type)
self.register_hook(module_name=module_name)
act = self._extract_batch(batch, module_name, flatten_acts)
if output_type == "ndarray":
act = self._extract_batch(
batch=batch, module_name=self.module_name, flatten_acts=flatten_acts
)
if self.output_type == "ndarray":
act = self._to_numpy(act)
self._unregister_hook()
return act

@torch.no_grad()
Expand Down Expand Up @@ -128,7 +130,7 @@ def extract_features(
):
self.model = self.model.to(self.device)
self.activations = {}
self.register_hook(module_name=module_name)
self._register_hook(module_name=module_name)
features = super().extract_features(
batches=batches,
module_name=module_name,
Expand Down Expand Up @@ -230,3 +232,25 @@ def get_default_transformation(

def get_backend(self) -> str:
    """Return the backend identifier of this extractor, the string ``"pt"``.

    NOTE(review): ``"pt"`` presumably stands for PyTorch -- confirm against callers.
    """
    return "pt"


class BatchExtraction:
    """Context manager that prepares an extractor for mini-batch extraction.

    On entry the requested module and output type are validated, a hook is
    registered for the module, and ``module_name`` / ``output_type`` are
    stored temporarily on the extractor so that they can be read while the
    ``with`` block is active. On exit the hook is removed and the temporary
    attributes are deleted again.

    Parameters
    ----------
    extractor : PyTorchExtractor
        The extractor whose hook is managed.
    module_name : str
        Name of the module to extract activations from.
    output_type : str
        Desired format of the extracted features (e.g. ``"tensor"`` or
        ``"ndarray"``).
    """

    def __init__(
        self, extractor: "PyTorchExtractor", module_name: str, output_type: str
    ) -> None:
        self.extractor = extractor
        self.module_name = module_name
        self.output_type = output_type

    def __enter__(self) -> "PyTorchExtractor":
        """Validate the arguments, register the hook and return the extractor."""
        self.extractor._module_and_output_check(self.module_name, self.output_type)
        self.extractor._register_hook(self.module_name)
        # Expose the settings on the extractor only after the hook has been
        # registered successfully.
        self.extractor.module_name = self.module_name
        self.extractor.output_type = self.output_type
        return self.extractor

    def __exit__(self, *args) -> None:
        """Remove the hook and the temporary attributes from the extractor.

        Returns ``None``, so an exception raised inside the ``with`` block
        is propagated to the caller.
        """
        try:
            self.extractor._unregister_hook()
        finally:
            # Always drop the temporary state, even if removing the hook
            # failed; otherwise stale settings would linger on the extractor.
            del self.extractor.module_name
            del self.extractor.output_type

0 comments on commit 9e6344b

Please sign in to comment.