
Commit 68add70

[feat] Add pytorchvideo encoder wrapper (#1156)
Summary:
Pull Request resolved: #1156

Add an encoder class that constructs any pytorchvideo model from config and uses that model for its forward pass. It can load pretrained or randomly initialized models, based on config.

Test Plan:
Tested through unit tests on slowfast_r50 and mvit. Will be tested end-to-end when datasets and transformers are available in mmf.

```
(torchvideo) ryanjiang@learnfair5083:~/copy/mmf$ pytest tests/models/test_mmf_transformer.py
================================================== test session starts ==================================================
platform linux -- Python 3.7.11, pytest-6.2.5, py-1.10.0, pluggy-1.0.0
rootdir: /private/home/ryanjiang/copy/mmf
plugins: forked-1.3.0, timeout-1.4.2, hydra-core-1.1.1, xdist-2.4.0, dash-2.0.0
collected 15 items

tests/models/test_mmf_transformer.py ...............                                                             [100%]

(torchvideo) ryanjiang@learnfair5083:~/copy/mmf$ pytest tests/modules/test_encoders.py
================================================== test session starts ==================================================
platform linux -- Python 3.7.11, pytest-6.2.5, py-1.10.0, pluggy-1.0.0
rootdir: /private/home/ryanjiang/copy/mmf
plugins: forked-1.3.0, timeout-1.4.2, hydra-core-1.1.1, xdist-2.4.0, dash-2.0.0
collected 12 items

tests/modules/test_encoders.py ............                                                                       [100%]
```

Reviewed By: apsdehal

Differential Revision: D32631207

Pulled By: Ryan-Qiyu-Jiang

fbshipit-source-id: 6b549162f7ae9ccea162563e48ed910618a6da54
1 parent ee19bd9 commit 68add70
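
For orientation before the diff, this is roughly how the new encoder is used on its own (a minimal sketch, assuming pytorchvideo is installed; the default config values and the 2304-d feature size are taken from the unit tests added in this commit):

```python
import torch
from omegaconf import OmegaConf
from mmf.modules import encoders

# Defaults: pretrained slowfast_r50 from torch hub, final (head) layer
# dropped (drop_last_n_layers=-1), identity pooler.
config = OmegaConf.structured(encoders.PytorchVideoEncoder.Config())
encoder = encoders.PytorchVideoEncoder(config)

# SlowFast takes a [slow, fast] pathway list of (bs, channels, time, h, w) clips.
slow = torch.rand((1, 3, 8, 224, 224))
fast = torch.rand((1, 3, 32, 224, 224))
output = encoder([slow, fast])
# per tests/modules/test_encoders.py, dim 1 of the output is the 2304-d feature
assert output.size(1) == 2304
```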

File tree

- mmf/modules/encoders.py
- tests/models/test_mmf_transformer.py
- tests/modules/test_encoders.py
- tests/test_utils.py

4 files changed: +217 -4 lines changed

mmf/modules/encoders.py

Lines changed: 89 additions & 2 deletions

```diff
@@ -1,10 +1,12 @@
 # Copyright (c) Facebook, Inc. and its affiliates.
+import importlib
+import logging
 import os
 import pickle
 import re
 from collections import OrderedDict
 from copy import deepcopy
-from dataclasses import dataclass
+from dataclasses import asdict, dataclass
 from enum import Enum
 from typing import Any
 
@@ -25,13 +27,15 @@
 from transformers.configuration_auto import AutoConfig
 from transformers.modeling_auto import AutoModel
 
-
 try:
     from detectron2.modeling import ShapeSpec, build_resnet_backbone
 except ImportError:
     pass
 
 
+logger = logging.getLogger()
+
+
 class Encoder(nn.Module):
     @dataclass
     class Config:
@@ -688,6 +692,89 @@ def forward(self, x: Tensor) -> Tensor:
         return out
 
 
+@registry.register_encoder("pytorchvideo")
+class PytorchVideoEncoder(Encoder):
+    """A thin wrapper around pytorchvideo models.
+    This class is responsible for integrating pytorchvideo models as encoders.
+    This class attempts to construct a pytorchvideo model from torch hub.
+    If this fails for a random weight model, and the pytorchvideo package is available,
+    build the model with random weights from pytorchvideo.models.
+
+    Config:
+        name (str): Always 'pytorchvideo'. Used for build_encoder()
+        random_init (bool): Flag to load pretrained weights
+        model_name (str): Name of the pytorchvideo model to use
+        drop_last_n_layers (int):
+            <=0 value for the number of layers to drop off the end
+        pooler_name (str): Name of pooler used on model output
+
+    Raises:
+        ImportError:
+            The constructor raises an ImportError if pytorchvideo is not installed.
+    """
+
+    @dataclass
+    class Config(Encoder.Config):
+        name: str = "pytorchvideo"
+        random_init: bool = False
+        model_name: str = "slowfast_r50"
+        drop_last_n_layers: int = -1
+        pooler_name: str = "identity"
+
+    PYTORCHVIDEO_REPO = "facebookresearch/pytorchvideo:main"
+
+    def __init__(self, config: Config):
+        super().__init__()
+        config = OmegaConf.create({**asdict(self.Config()), **config})
+        if config.random_init:
+            params = dict(**OmegaConf.to_container(config))
+            params = {
+                k: v
+                for k, v in params.items()
+                if k not in PytorchVideoEncoder.Config().__dict__
+            }
+            try:
+                model = torch.hub.load(
+                    PytorchVideoEncoder.PYTORCHVIDEO_REPO,
+                    model=config.model_name,
+                    pretrained=False,
+                    **params,
+                )
+            except BaseException as err:
+                pytorchvideo_spec = importlib.util.find_spec("pytorchvideo")
+                if pytorchvideo_spec is None:
+                    raise err
+                import pytorchvideo.models.hub as hub
+
+                model_create_fn = getattr(hub, config.model_name)
+                model = model_create_fn(pretrained=False, **params)
+        else:
+            # load weights from TorchHub
+            model = torch.hub.load(
+                PytorchVideoEncoder.PYTORCHVIDEO_REPO,
+                model=config.model_name,
+                pretrained=True,
+            )
+        encoder_list = []
+        if config.drop_last_n_layers == 0:
+            encoder_list += [model]
+        else:
+            modules_list = list(model.children())
+            if len(modules_list) == 1:
+                modules_list = list(modules_list[0].children())
+            modules = modules_list[: config.drop_last_n_layers]
+            encoder_list += modules
+
+        pooler = registry.get_pool_class(config.pooler_name)()
+        encoder_list += [pooler]
+        self.encoder = nn.Sequential(*encoder_list)
+
+    def forward(self, *args, **kwargs):
+        # pass along input to model
+        # assumes caller obeys the dynamic model signature
+        return self.encoder(*args, **kwargs)
+
+
 @registry.register_encoder("r2plus1d_18")
 class R2Plus1D18VideoEncoder(PooledEncoder):
     """
```

tests/models/test_mmf_transformer.py

Lines changed: 60 additions & 1 deletion

```diff
@@ -21,7 +21,9 @@
 from mmf.utils.configuration import Configuration
 from mmf.utils.env import setup_imports, teardown_imports
 from omegaconf import OmegaConf
-
+from tests.test_utils import (
+    skip_if_no_pytorchvideo,
+)
 
 BERT_VOCAB_SIZE = 30255
 ROBERTA_VOCAB_SIZE = 50265
@@ -444,6 +446,63 @@ def test_preprocessing_with_resnet_encoder(self):
         test_utils.compare_tensors(segment_ids["image"], torch.tensor([[0], [0]]))
         test_utils.compare_tensors(segment_ids["text"], torch.ones((2, 128)).long())
 
+    @skip_if_no_pytorchvideo
+    def test_preprocessing_with_mvit_encoder(self):
+        encoder_config = OmegaConf.create(
+            {
+                "name": "pytorchvideo",
+                "model_name": "mvit_base_32x3",
+                "random_init": True,
+                "drop_last_n_layers": 0,
+                "pooler_name": "cls",
+                "spatial_size": 224,
+                "temporal_size": 8,
+                "head": None,
+                "embed_dim_mul": [[1, 2.0], [3, 2.0], [14, 2.0]],
+                "atten_head_mul": [[1, 2.0], [3, 2.0], [14, 2.0]],
+                "pool_q_stride_size": [[1, 1, 2, 2], [3, 1, 2, 2], [14, 1, 2, 2]],
+                "pool_kv_stride_adaptive": [1, 8, 8],
+                "pool_kvq_kernel": [3, 3, 3],
+            }
+        )
+        self._image_modality_config = MMFTransformerModalityConfig(
+            type="image",
+            key="image",
+            embedding_dim=768,
+            position_dim=1,
+            segment_id=0,
+            encoder=encoder_config,
+        )
+        modalities_config = [self._image_modality_config, self._text_modality_config]
+        config = MMFTransformer.Config(modalities=modalities_config, num_labels=2)
+        mmft = build_model(config)
+
+        sample_list = SampleList()
+        sample_list.image = torch.rand((2, 3, 8, 224, 224))
+        sample_list.text = torch.randint(0, 512, (2, 128))
+
+        transformer_input = mmft.preprocess_sample(sample_list)
+        input_ids = transformer_input["input_ids"]
+        self.assertEqual(input_ids["image"].dim(), 3)
+        self.assertEqual(list(input_ids["image"].size()), [2, 1, 768])
+
+        self.assertEqual(input_ids["text"].dim(), 2)
+        self.assertEqual(list(input_ids["text"].size()), [2, 128])
+
+        position_ids = transformer_input["position_ids"]
+        test_utils.compare_tensors(position_ids["image"], torch.tensor([[0], [0]]))
+        test_utils.compare_tensors(
+            position_ids["text"], torch.arange(0, 128).unsqueeze(0).expand((2, 128))
+        )
+
+        masks = transformer_input["masks"]
+        test_utils.compare_tensors(masks["image"], torch.tensor([[1], [1]]))
+        test_utils.compare_tensors(masks["text"], torch.ones((2, 128)).long())
+
+        segment_ids = transformer_input["segment_ids"]
+        test_utils.compare_tensors(segment_ids["image"], torch.tensor([[0], [0]]))
+        test_utils.compare_tensors(segment_ids["text"], torch.ones((2, 128)).long())
+
     def test_tie_mlm_head_weight_to_encoder(self):
         self._text_modality_config = MMFTransformerModalityConfig(
             type="text",
```

tests/modules/test_encoders.py

Lines changed: 61 additions & 1 deletion

```diff
@@ -6,7 +6,11 @@
 import torch
 from mmf.modules import encoders
 from omegaconf import OmegaConf
-from tests.test_utils import setup_proxy, skip_if_old_transformers
+from tests.test_utils import (
+    setup_proxy,
+    skip_if_old_transformers,
+    skip_if_no_pytorchvideo,
+)
 from torch import nn
 
 
@@ -102,3 +106,59 @@ def test_vit_encoder(self):
         x = torch.rand(32, 197, 768)
         output, _ = encoder(x)
         self.assertEqual(output.size(-1), config.out_dim)
+
+    @skip_if_no_pytorchvideo
+    def test_pytorchvideo_slowfast_r50_encoder(self):
+        # instantiate video encoder from pytorchvideo
+        # default model is slowfast_r50
+        config = OmegaConf.structured(encoders.PytorchVideoEncoder.Config())
+        encoder = encoders.PytorchVideoEncoder(config)
+        fast = torch.rand((1, 3, 32, 224, 224))
+        slow = torch.rand((1, 3, 8, 224, 224))
+        output = encoder([slow, fast])
+        # check output tensor is the expected feature dim size
+        # (bs, feature_dim)
+        self.assertEqual(output.size(1), 2304)
+
+    @skip_if_no_pytorchvideo
+    def test_mvit_encoder(self):
+        config = {
+            "name": "pytorchvideo",
+            "model_name": "mvit_base_32x3",
+            "random_init": True,
+            "drop_last_n_layers": 0,
+            "pooler_name": "cls",
+            "spatial_size": 224,
+            "temporal_size": 8,
+            "head": None,
+            "embed_dim_mul": [[1, 2.0], [3, 2.0], [14, 2.0]],
+            "atten_head_mul": [[1, 2.0], [3, 2.0], [14, 2.0]],
+            "pool_q_stride_size": [[1, 1, 2, 2], [3, 1, 2, 2], [14, 1, 2, 2]],
+            "pool_kv_stride_adaptive": [1, 8, 8],
+            "pool_kvq_kernel": [3, 3, 3],
+        }
+        # test bert cls pooler
+        encoder = encoders.PytorchVideoEncoder(OmegaConf.create(config))
+        x = torch.rand((1, 3, 8, 224, 224))
+        output = encoder(x)
+        # check output tensor is the expected feature dim size
+        # based on pooled attention configs
+        # for more details consult https://arxiv.org/pdf/2104.11227
+        # and https://github.com/facebookresearch/pytorchvideo/
+        # (bs, feature_dim)
+        self.assertEqual(output.shape, torch.Size([1, 768]))
+
+        # test avg pooler
+        encoder = encoders.PytorchVideoEncoder(
+            OmegaConf.create(dict(config, pooler_name="avg"))
+        )
+        output = encoder(x)
+        self.assertEqual(output.shape, torch.Size([1, 768]))
+
+        # test no pooling
+        encoder = encoders.PytorchVideoEncoder(
+            OmegaConf.create(dict(config, pooler_name="identity"))
+        )
+        output = encoder(x)
+        # (bs, num_features, feature_dim)
+        self.assertEqual(output.shape, torch.Size([1, 197, 768]))
```

tests/test_utils.py

Lines changed: 7 additions & 0 deletions

```diff
@@ -102,6 +102,13 @@ def wrap(testfn, reason="Requires newer version of transformers"):
     return wrap
 
 
+def skip_if_no_pytorchvideo(testfn, reason="Requires pytorchvideo"):
+    import importlib
+
+    pytorchvideo_spec = importlib.util.find_spec("pytorchvideo")
+    return unittest.skipIf(pytorchvideo_spec is None, reason)(testfn)
+
+
 def compare_state_dicts(a, b):
     same = True
     same = same and (list(a.keys()) == list(b.keys()))
```
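
The new helper is applied as a plain decorator on test methods, exactly as in the two test files above; a minimal hedged sketch with a hypothetical test class:

```python
import unittest

from tests.test_utils import skip_if_no_pytorchvideo


class TestNeedsPytorchvideo(unittest.TestCase):  # hypothetical class, for illustration
    @skip_if_no_pytorchvideo
    def test_can_import_hub_models(self):
        # Only runs when pytorchvideo is importable; otherwise the
        # decorator marks the test as skipped.
        import pytorchvideo.models.hub as hub

        self.assertTrue(hasattr(hub, "slowfast_r50"))
```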
