
Commit 74ad9d0

[docs] Add pytorchvideo docs
Add pytorchvideo tutorial docs for using a pytorchvideo model as an encoder with TorchVideoEncoder class.
ghstack-source-id: 9cc594f
Pull Request resolved: #1163
1 parent 81e2e43 commit 74ad9d0

File tree: 4 files changed, 185 additions & 6 deletions


mmf/modules/encoders.py

Lines changed: 26 additions & 1 deletion
```diff
@@ -9,7 +9,7 @@
 from copy import deepcopy
 from dataclasses import asdict, dataclass
 from enum import Enum
-from typing import Any
+from typing import Any, Optional

 import torch
 import torchvision
@@ -773,6 +773,31 @@ def filter_dict_to_signature(self, callable, params):
         return accepted_params, ignored_params


+@registry.register_encoder("mvit")
+class MViTEncoder(Encoder):
+    """
+    MVIT from pytorchvideo
+    """
+
+    @dataclass
+    class Config(Encoder.Config):
+        name: str = "mvit"
+        random_init: bool = False
+        model_name: str = "multiscale_vision_transformers"
+        spatial_size: int = 224
+        temporal_size: int = 8
+        head: Optional[Any] = None
+
+    def __init__(self, config: Config):
+        super().__init__()
+        self.encoder = TorchVideoEncoder(config)
+
+    def forward(self, *args, **kwargs):
+        output = self.encoder(*args, **kwargs)
+        output = output.permute(0, 2, 1)
+        return output[:, :1, :]
+
+
 @registry.register_encoder("r2plus1d_18")
 class R2Plus1D18VideoEncoder(PooledEncoder):
     """
```

tests/models/test_mmf_transformer.py

Lines changed: 54 additions & 1 deletion
```diff
@@ -21,7 +21,9 @@
 from mmf.utils.configuration import Configuration
 from mmf.utils.env import setup_imports, teardown_imports
 from omegaconf import OmegaConf
-
+from tests.test_utils import (
+    skip_if_no_pytorchvideo,
+)

 BERT_VOCAB_SIZE = 30255
 ROBERTA_VOCAB_SIZE = 50265
@@ -444,6 +446,57 @@ def test_preprocessing_with_resnet_encoder(self):
         test_utils.compare_tensors(segment_ids["image"], torch.tensor([[0], [0]]))
         test_utils.compare_tensors(segment_ids["text"], torch.ones((2, 128)).long())

+    @skip_if_no_pytorchvideo
+    def test_preprocessing_with_mvit_encoder(self):
+        encoder_config = OmegaConf.create(
+            {
+                "name": "mvit",
+                "model_name": "multiscale_vision_transformers",
+                "random_init": True,
+                "cls_layer_num": 0,
+                "spatial_size": 224,
+                "temporal_size": 8,
+                "head": None,
+            }
+        )
+        self._image_modality_config = MMFTransformerModalityConfig(
+            type="image",
+            key="image",
+            embedding_dim=12545,
+            position_dim=1,
+            segment_id=0,
+            encoder=encoder_config,
+        )
+        modalities_config = [self._image_modality_config, self._text_modality_config]
+        config = MMFTransformer.Config(modalities=modalities_config, num_labels=2)
+        mmft = build_model(config)
+
+        sample_list = SampleList()
+        sample_list.image = torch.rand((2, 3, 8, 224, 224))
+        sample_list.text = torch.randint(0, 512, (2, 128))
+
+        transformer_input = mmft.preprocess_sample(sample_list)
+        input_ids = transformer_input["input_ids"]
+        self.assertEqual(input_ids["image"].dim(), 3)
+        self.assertEqual(list(input_ids["image"].size()), [2, 1, 12545])
+
+        self.assertEqual(input_ids["text"].dim(), 2)
+        self.assertEqual(list(input_ids["text"].size()), [2, 128])
+
+        position_ids = transformer_input["position_ids"]
+        test_utils.compare_tensors(position_ids["image"], torch.tensor([[0], [0]]))
+        test_utils.compare_tensors(
+            position_ids["text"], torch.arange(0, 128).unsqueeze(0).expand((2, 128))
+        )
+
+        masks = transformer_input["masks"]
+        test_utils.compare_tensors(masks["image"], torch.tensor([[1], [1]]))
+        test_utils.compare_tensors(masks["text"], torch.ones((2, 128)).long())
+
+        segment_ids = transformer_input["segment_ids"]
+        test_utils.compare_tensors(segment_ids["image"], torch.tensor([[0], [0]]))
+        test_utils.compare_tensors(segment_ids["text"], torch.ones((2, 128)).long())
+
     def test_tie_mlm_head_weight_to_encoder(self):
         self._text_modality_config = MMFTransformerModalityConfig(
             type="text",
```

tests/modules/test_encoders.py

Lines changed: 4 additions & 4 deletions
```diff
@@ -117,10 +117,10 @@ def test_torchvision_slowfast_r50_encoder(self):
         self.assertEqual(output.size(1), 2304)

     @skip_if_no_pytorchvideo
-    def test_torchvision_mvit_encoder(self):
+    def test_mvit_encoder(self):
         config = OmegaConf.create(
             {
-                "name": "torchvideo",
+                "name": "mvit",
                 "model_name": "multiscale_vision_transformers",
                 "random_init": True,
                 "cls_layer_num": 0,
@@ -129,7 +129,7 @@ def test_torchvision_mvit_encoder(self):
                 "head": None,
             }
         )
-        encoder = encoders.TorchVideoEncoder(config)
+        encoder = encoders.MViTEncoder(config)
         x = torch.rand((1, 3, 8, 224, 224))
         output = encoder(x)
-        self.assertEqual(output.shape, torch.Size([1, 12545, 96]))
+        self.assertEqual(output.shape, torch.Size([1, 1, 12545]))
```

Lines changed: 101 additions & 0 deletions (new file)

---
id: pytorchvideo
title: Using Pytorchvideo
sidebar_label: Using Pytorchvideo
---

MMF is integrating with Pytorchvideo!

This means you'll be able to use Pytorchvideo models, datasets, and transforms in multimodal models from MMF.
Pytorchvideo datasets and transforms are coming soon!

If you find PyTorchVideo useful in your work, please use the following BibTeX entry for citation.
```
@inproceedings{fan2021pytorchvideo,
    author = {Haoqi Fan and Tullie Murrell and Heng Wang and Kalyan Vasudev Alwala and Yanghao Li and Yilei Li and Bo Xiong and Nikhila Ravi and Meng Li and Haichuan Yang and Jitendra Malik and Ross Girshick and Matt Feiszli and Aaron Adcock and Wan-Yen Lo and Christoph Feichtenhofer},
    title = {{PyTorchVideo}: A Deep Learning Library for Video Understanding},
    booktitle = {Proceedings of the 29th ACM International Conference on Multimedia},
    year = {2021},
    note = {\url{https://pytorchvideo.org/}},
}
```

## Setup

In order to use pytorchvideo in MMF, you need pytorchvideo installed.
You can install pytorchvideo by running
```
pip install pytorchvideo
```
For detailed instructions, consult https://github.com/facebookresearch/pytorchvideo/blob/main/INSTALL.md

## Using Pytorchvideo Models in MMF

Currently, Pytorchvideo models are supported as MMF encoders.
To use a Pytorchvideo model as the image encoder for your multimodal model,
use the MMF TorchVideoEncoder or write your own encoder that uses pytorchvideo directly.

The TorchVideoEncoder class is a wrapper around pytorchvideo models.
To instantiate a pytorchvideo model as an encoder, you can do the following:

```python
import torch
from mmf.modules import encoders
from omegaconf import OmegaConf

config = OmegaConf.create(
    {
        "name": "torchvideo",
        "model_name": "slowfast_r50",
        "random_init": True,
        "cls_layer_num": 1,
    }
)
encoder = encoders.TorchVideoEncoder(config)

# some video input; slowfast models take a list of [slow pathway, fast pathway] clips
fast = torch.rand((1, 3, 32, 224, 224))
slow = torch.rand((1, 3, 8, 224, 224))
output = encoder([slow, fast])
```

In our config object, we specify that we want to build the `torchvideo` (name) encoder,
that we want to use the pytorchvideo model `slowfast_r50` (model_name),
without pretrained weights (`random_init: True`),
and that we want to remove the last module of the network (the classification head, `cls_layer_num: 1`) so that we just get the hidden state.
Which layers to strip depends on which model you're using and what you need it for.

This encoder is usually configured from yaml through your model_config yaml, as sketched below.

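For illustration, here is a sketch of how that encoder block nests under a modality in a model config. The encoder keys are the ones used in the example above; the surrounding modality fields are only an assumed example (they mirror the modality config used in MMF's transformer tests), and the exact yaml layout depends on your model.

```python
from omegaconf import OmegaConf

# Illustrative nesting only: the "encoder" keys follow the slowfast_r50 example above,
# the modality fields around it are an assumed example (embedding_dim 2304 is the
# slowfast_r50 feature size checked in MMF's encoder tests).
model_config = OmegaConf.create(
    {
        "modalities": [
            {
                "type": "image",
                "key": "image",
                "embedding_dim": 2304,
                "position_dim": 1,
                "segment_id": 0,
                "encoder": {
                    "name": "torchvideo",
                    "model_name": "slowfast_r50",
                    "random_init": True,
                    "cls_layer_num": 1,
                },
            }
        ]
    }
)

# The same structure, rendered as the yaml you would place under model_config.
print(OmegaConf.to_yaml(model_config))
```
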
Suppose we want to use MViT as our image encoder and we only want the first hidden state.
As the MViT model in pytorchvideo returns hidden states in the format (batch size, feature dim, num features),
we want to permute the tensor and take the first feature.
To do this, we can write our own encoder class in encoders.py:

```python
@registry.register_encoder("mvit")
class MViTEncoder(Encoder):
    """
    MVIT from pytorchvideo
    """
    @dataclass
    class Config(Encoder.Config):
        name: str = "mvit"
        random_init: bool = False
        model_name: str = "multiscale_vision_transformers"
        spatial_size: int = 224
        temporal_size: int = 8
        head: Optional[Any] = None

    def __init__(self, config: Config):
        super().__init__()
        self.encoder = TorchVideoEncoder(config)

    def forward(self, *args, **kwargs):
        output = self.encoder(*args, **kwargs)
        output = output.permute(0, 2, 1)
        return output[:, :1, :]
```

Here we use the TorchVideoEncoder class to build our MViT model and transform its output to match what we need from an encoder.
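
As a quick sanity check, here is a minimal sketch (assuming pytorchvideo is installed) that instantiates the registered encoder and verifies the output shape expected by the unit tests:

```python
import torch
from omegaconf import OmegaConf

from mmf.modules import encoders

config = OmegaConf.create(
    {
        "name": "mvit",
        "model_name": "multiscale_vision_transformers",
        "random_init": True,
        "cls_layer_num": 0,
        "spatial_size": 224,
        "temporal_size": 8,
        "head": None,
    }
)
encoder = encoders.MViTEncoder(config)

x = torch.rand((1, 3, 8, 224, 224))  # (batch, channels, time, height, width)
output = encoder(x)
print(output.shape)  # torch.Size([1, 1, 12545]), as asserted in tests/modules/test_encoders.py
```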
