Skip to content

Commit 81e2e43

Browse files
[feat] Add pytorchvideo encoder wrapper
Add an encoder class that constructs any pytorchvideo model from config, and uses this model for its forward pass. Can load pretrained or random init models, based on config. ghstack-source-id: 2eb33e0 Pull Request resolved: #1156
1 parent b6a5804 commit 81e2e43

File tree

4 files changed

+128
-3
lines changed

4 files changed

+128
-3
lines changed

mmf/modules/encoders.py

Lines changed: 87 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
# Copyright (c) Facebook, Inc. and its affiliates.
2+
import importlib
3+
import inspect
4+
import logging
25
import os
36
import pickle
47
import re
58
from collections import OrderedDict
69
from copy import deepcopy
7-
from dataclasses import dataclass
10+
from dataclasses import asdict, dataclass
811
from enum import Enum
912
from typing import Any
1013

@@ -25,13 +28,15 @@
2528
from transformers.configuration_auto import AutoConfig
2629
from transformers.modeling_auto import AutoModel
2730

28-
2931
try:
3032
from detectron2.modeling import ShapeSpec, build_resnet_backbone
3133
except ImportError:
3234
pass
3335

3436

37+
logger = logging.getLogger()
38+
39+
3540
class Encoder(nn.Module):
3641
@dataclass
3742
class Config:
@@ -688,6 +693,86 @@ def forward(self, x: Tensor) -> Tensor:
688693
return out
689694

690695

696+
@registry.register_encoder("torchvideo")
697+
class TorchVideoEncoder(Encoder):
698+
"""
699+
Wrapper around importing torchvideo models
700+
as encoders.
701+
"""
702+
703+
@dataclass
704+
class Config(Encoder.Config):
705+
name: str = "torchvideo"
706+
random_init: bool = False
707+
model_name: str = "slowfast_r50"
708+
cls_layer_num: int = 1
709+
710+
def __init__(self, config: Config):
711+
pytorchvideo_spec = importlib.util.find_spec("pytorchvideo")
712+
if pytorchvideo_spec is None:
713+
raise ImportError("pytorchvideo required for using TorchVideoEncoder")
714+
import pytorchvideo.models as models
715+
716+
super().__init__()
717+
config = OmegaConf.create({**asdict(self.Config()), **config})
718+
if config.random_init:
719+
model_create_fn_name = f"create_{config.model_name}"
720+
model_create_fn = getattr(models, model_create_fn_name)
721+
params = dict(**config)
722+
params.pop("random_init")
723+
params.pop("model_name")
724+
params.pop("cls_layer_num")
725+
726+
accepted_params, ignored_params = self.filter_dict_to_signature(
727+
model_create_fn, params
728+
)
729+
if ignored_params:
730+
ignored_params_str = " ".join(ignored_params.keys())
731+
logger.warning(
732+
"The following model constructor params were ignored"
733+
+ " because they don't match a named param in the constructor: "
734+
+ ignored_params_str
735+
)
736+
model = model_create_fn(**accepted_params)
737+
else:
738+
# load weights from TorchHub
739+
model = torch.hub.load(
740+
"facebookresearch/pytorchvideo:main",
741+
model=config.model_name,
742+
pretrained=True,
743+
)
744+
745+
if config.cls_layer_num == 0:
746+
self.encoder = model
747+
return
748+
749+
modules_list = list(model.children())
750+
if len(modules_list) == 1:
751+
modules_list = list(modules_list[0].children())
752+
modules = modules_list[: -config.cls_layer_num]
753+
self.encoder = nn.Sequential(*modules)
754+
755+
def forward(self, *args, **kwargs):
756+
# pass along input to model
757+
# assumes caller obeys the dynamic model signature
758+
return self.encoder(*args, **kwargs)
759+
760+
def filter_dict_to_signature(self, callable, params):
761+
constructor_signature = inspect.signature(callable) # Signature obj
762+
constructor_params = constructor_signature.parameters
763+
accepted_params = {
764+
param_name: params[param_name]
765+
for param_name in params
766+
if param_name in constructor_params
767+
}
768+
ignored_params = {
769+
param_name: params[param_name]
770+
for param_name in params
771+
if param_name not in constructor_params
772+
}
773+
return accepted_params, ignored_params
774+
775+
691776
@registry.register_encoder("r2plus1d_18")
692777
class R2Plus1D18VideoEncoder(PooledEncoder):
693778
"""

requirements.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,3 +22,5 @@ pytorch-lightning @ git+https://github.com/PyTorchLightning/pytorch-lightning@fa
2222
torchaudio>=0.6.0, <=0.9.0
2323
psutil
2424
pillow==8.3.1
25+
av>=8.0.3
26+
pytorchvideo>=0.1.3

tests/modules/test_encoders.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,11 @@
66
import torch
77
from mmf.modules import encoders
88
from omegaconf import OmegaConf
9-
from tests.test_utils import setup_proxy, skip_if_old_transformers
9+
from tests.test_utils import (
10+
setup_proxy,
11+
skip_if_old_transformers,
12+
skip_if_no_pytorchvideo,
13+
)
1014
from torch import nn
1115

1216

@@ -102,3 +106,30 @@ def test_vit_encoder(self):
102106
x = torch.rand(32, 197, 768)
103107
output, _ = encoder(x)
104108
self.assertEqual(output.size(-1), config.out_dim)
109+
110+
@skip_if_no_pytorchvideo
111+
def test_torchvision_slowfast_r50_encoder(self):
112+
config = OmegaConf.structured(encoders.TorchVideoEncoder.Config())
113+
encoder = encoders.TorchVideoEncoder(config)
114+
fast = torch.rand((1, 3, 32, 224, 224))
115+
slow = torch.rand((1, 3, 8, 224, 224))
116+
output = encoder([slow, fast])
117+
self.assertEqual(output.size(1), 2304)
118+
119+
@skip_if_no_pytorchvideo
120+
def test_torchvision_mvit_encoder(self):
121+
config = OmegaConf.create(
122+
{
123+
"name": "torchvideo",
124+
"model_name": "multiscale_vision_transformers",
125+
"random_init": True,
126+
"cls_layer_num": 0,
127+
"spatial_size": 224,
128+
"temporal_size": 8,
129+
"head": None,
130+
}
131+
)
132+
encoder = encoders.TorchVideoEncoder(config)
133+
x = torch.rand((1, 3, 8, 224, 224))
134+
output = encoder(x)
135+
self.assertEqual(output.shape, torch.Size([1, 12545, 96]))

tests/test_utils.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,13 @@ def wrap(testfn, reason="Requires newer version of transformers"):
102102
return wrap
103103

104104

105+
def skip_if_no_pytorchvideo(testfn, reason="Requires pytorchvideo"):
106+
import importlib
107+
108+
pytorchvideo_spec = importlib.util.find_spec("pytorchvideo")
109+
return unittest.skipUnless(pytorchvideo_spec is not None, reason)(testfn)
110+
111+
105112
def compare_state_dicts(a, b):
106113
same = True
107114
same = same and (list(a.keys()) == list(b.keys()))

0 commit comments

Comments
 (0)