|
1 | 1 | # Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 | +import importlib |
| 3 | +import inspect |
| 4 | +import logging |
2 | 5 | import os
|
3 | 6 | import pickle
|
4 | 7 | import re
|
5 | 8 | from collections import OrderedDict
|
6 | 9 | from copy import deepcopy
|
7 |
| -from dataclasses import dataclass |
| 10 | +from dataclasses import asdict, dataclass |
8 | 11 | from enum import Enum
|
9 | 12 | from typing import Any
|
10 | 13 |
|
|
25 | 28 | from transformers.configuration_auto import AutoConfig
|
26 | 29 | from transformers.modeling_auto import AutoModel
|
27 | 30 |
|
28 |
| - |
29 | 31 | try:
|
30 | 32 | from detectron2.modeling import ShapeSpec, build_resnet_backbone
|
31 | 33 | except ImportError:
|
32 | 34 | pass
|
33 | 35 |
|
34 | 36 |
|
# Module-scoped logger: records carry this module's name, so they can be
# filtered and attributed per module instead of all going through the root logger.
logger = logging.getLogger(__name__)
| 38 | + |
| 39 | + |
35 | 40 | class Encoder(nn.Module):
|
36 | 41 | @dataclass
|
37 | 42 | class Config:
|
@@ -688,6 +693,86 @@ def forward(self, x: Tensor) -> Tensor:
|
688 | 693 | return out
|
689 | 694 |
|
690 | 695 |
|
@registry.register_encoder("torchvideo")
class TorchVideoEncoder(Encoder):
    """
    Wrapper around importing torchvideo (pytorchvideo) models as encoders.

    When ``config.random_init`` is true the model is built by calling
    ``pytorchvideo.models.create_<model_name>``; otherwise it is loaded
    through ``torch.hub.load`` from ``facebookresearch/pytorchvideo:main``
    with ``pretrained=True``.

    When ``config.cls_layer_num`` is non-zero, that many trailing child
    modules are dropped and the remaining ones are wrapped in an
    ``nn.Sequential``; when it is zero the model is used unchanged.
    """

    @dataclass
    class Config(Encoder.Config):
        name: str = "torchvideo"
        # Build via ``create_<model_name>`` instead of loading from TorchHub.
        random_init: bool = False
        # Used as the TorchHub entry point and as the ``create_`` suffix.
        model_name: str = "slowfast_r50"
        # Number of trailing child modules to strip; 0 keeps the full model.
        cls_layer_num: int = 1

    def __init__(self, config: Config):
        """
        Build the wrapped model and store it as ``self.encoder``.

        Args:
            config: Mapping of config values; merged over the defaults of
                ``Config``. When ``random_init`` is set, keys other than
                ``random_init``, ``model_name`` and ``cls_layer_num`` are
                offered to the model create function as keyword arguments.

        Raises:
            ImportError: If the ``pytorchvideo`` package cannot be found.
        """
        # ``import importlib`` alone does not guarantee that the
        # ``importlib.util`` submodule is loaded, so import it explicitly.
        import importlib.util

        if importlib.util.find_spec("pytorchvideo") is None:
            raise ImportError("pytorchvideo required for using TorchVideoEncoder")
        import pytorchvideo.models as models

        super().__init__()
        # Values supplied by the caller override the dataclass defaults.
        config = OmegaConf.create({**asdict(self.Config()), **config})
        if config.random_init:
            model_create_fn = getattr(models, f"create_{config.model_name}")
            params = dict(**config)
            # These keys steer this wrapper and are never passed to the model.
            # NOTE(review): ``name`` is not removed here, so it is reported as
            # ignored below unless the create function accepts it - confirm
            # whether that is intended.
            for wrapper_key in ("random_init", "model_name", "cls_layer_num"):
                params.pop(wrapper_key)

            accepted_params, ignored_params = self.filter_dict_to_signature(
                model_create_fn, params
            )
            if ignored_params:
                logger.warning(
                    "The following model constructor params were ignored"
                    " because they don't match a named param in the constructor: %s",
                    " ".join(ignored_params.keys()),
                )
            model = model_create_fn(**accepted_params)
        else:
            # load weights from TorchHub
            model = torch.hub.load(
                "facebookresearch/pytorchvideo:main",
                model=config.model_name,
                pretrained=True,
            )

        if config.cls_layer_num == 0:
            # Nothing to strip: expose the model as-is.
            self.encoder = model
            return

        modules_list = list(model.children())
        if len(modules_list) == 1:
            # The model has a single child container; strip layers from the
            # children of that container instead.
            modules_list = list(modules_list[0].children())
        modules = modules_list[: -config.cls_layer_num]
        self.encoder = nn.Sequential(*modules)

    def forward(self, *args, **kwargs):
        """Forward all arguments unchanged to the wrapped model."""
        # pass along input to model
        # assumes caller obeys the dynamic model signature
        return self.encoder(*args, **kwargs)

    def filter_dict_to_signature(self, callable, params):
        """
        Split ``params`` by whether each key names a parameter of ``callable``.

        Args:
            callable: Function or class whose signature is inspected.
            params: Dict mapping candidate keyword names to values.

        Returns:
            Tuple ``(accepted_params, ignored_params)``: two dicts holding the
            entries whose keys are, respectively are not, parameter names in
            the signature of ``callable``.
        """
        constructor_params = inspect.signature(callable).parameters
        accepted_params = {}
        ignored_params = {}
        # Single pass: route every entry to exactly one of the two dicts.
        for param_name, value in params.items():
            if param_name in constructor_params:
                accepted_params[param_name] = value
            else:
                ignored_params[param_name] = value
        return accepted_params, ignored_params
| 774 | + |
| 775 | + |
691 | 776 | @registry.register_encoder("r2plus1d_18")
|
692 | 777 | class R2Plus1D18VideoEncoder(PooledEncoder):
|
693 | 778 | """
|
|
0 commit comments