Skip to content

Commit

Permalink
Merge pull request #124 from ViCCo-Group/torchtyping
Browse files (browse the repository at this point in the history)
added torchtyping
  • Loading branch information
LukasMut authored Jan 1, 2023
2 parents 959a42e + 1a9a3ce commit b4b53dc
Show file tree
Hide file tree
Showing 7 changed files with 135 additions and 84 deletions.
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,12 @@ torchvision==0.13.*
tqdm==4.64.0
numba==0.56.*
matplotlib==3.5.2
torchtyping
regex
scipy
h5py
tensorflow==2.9.* ; sys_platform != 'darwin' or platform_machine != 'arm64'
tensorflow-macos==2.9.* ; sys_platform == 'darwin' and platform_machine == 'arm64'
timm==0.6.*
open_clip_torch==2.0.2
git+https://github.com/openai/CLIP.git
git+https://github.com/openai/CLIP.git
3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
"open_clip_torch==2.0.*",
"tqdm==4.64.0",
"timm==0.6.*",
"torchtyping",
"regex",
"scikit-image==0.19.3",
"scikit-learn==1.1.*",
Expand Down Expand Up @@ -50,4 +51,4 @@
entry_points={"console_scripts": ["thingsvision = thingsvision.thingsvision:main"]},
python_requires=">=3.8",
dependency_links=["git+https://github.com/openai/CLIP.git"],
)
)
2 changes: 1 addition & 1 deletion thingsvision/_version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "2.2.15"
__version__ = "2.2.17"
4 changes: 3 additions & 1 deletion thingsvision/core/extraction/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import numpy as np
from tqdm import tqdm

Array = np.ndarray


@dataclass(init=True, repr=True)
class BaseExtractor:
Expand Down Expand Up @@ -39,7 +41,7 @@ def extract_features(
flatten_acts: bool,
output_dir: str = None,
step_size: int = None,
) -> np.ndarray:
) -> Array:
"""Extract hidden unit activations (at specified layer) for every image in the database.
Parameters
Expand Down
129 changes: 71 additions & 58 deletions thingsvision/core/extraction/extractor.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
from dataclasses import dataclass
from typing import Any, Dict

Expand All @@ -7,7 +8,6 @@
import timm
import torch
import torchvision
import os

try:
from torch.hub import load_state_dict_from_url
Expand Down Expand Up @@ -37,13 +37,13 @@
@dataclass(repr=True)
class TorchvisionExtractor(BaseExtractor, PyTorchMixin):
def __init__(
self,
model_name: str,
pretrained: bool,
device: str,
model_path: str = None,
model_parameters: Dict = None,
preprocess: Any = None,
self,
model_name: str,
pretrained: bool,
device: str,
model_path: str = None,
model_parameters: Dict = None,
preprocess: Any = None,
) -> None:
model_parameters = (
model_parameters if model_parameters else {"weights": "DEFAULT"},
Expand Down Expand Up @@ -88,12 +88,12 @@ def load_model_from_source(self) -> None:
)

def get_default_transformation(
self,
mean,
std,
resize_dim: int = 256,
crop_dim: int = 224,
apply_center_crop: bool = True,
self,
mean,
std,
resize_dim: int = 256,
crop_dim: int = 224,
apply_center_crop: bool = True,
) -> Any:
if self.weights:
transforms = self.weights.transforms()
Expand All @@ -108,13 +108,13 @@ def get_default_transformation(
@dataclass(repr=True)
class TimmExtractor(BaseExtractor, PyTorchMixin):
def __init__(
self,
model_name: str,
pretrained: bool,
device: str,
model_path: str = None,
model_parameters: Dict = None,
preprocess: Any = None,
self,
model_name: str,
pretrained: bool,
device: str,
model_path: str = None,
model_parameters: Dict = None,
preprocess: Any = None,
) -> None:
super().__init__(
model_name=model_name,
Expand All @@ -138,13 +138,13 @@ def load_model_from_source(self) -> None:
@dataclass(repr=True)
class KerasExtractor(BaseExtractor, TensorFlowMixin):
def __init__(
self,
model_name: str,
pretrained: bool,
device: str,
model_path: str = None,
model_parameters: Dict = None,
preprocess: Any = None,
self,
model_name: str,
pretrained: bool,
device: str,
model_path: str = None,
model_parameters: Dict = None,
preprocess: Any = None,
) -> None:
model_parameters = (
model_parameters if model_parameters else {"weights": "imagenet"}
Expand Down Expand Up @@ -184,53 +184,53 @@ class SSLExtractor(BaseExtractor, PyTorchMixin):
"simclr-rn50": {
"url": "https://dl.fbaipublicfiles.com/vissl/model_zoo/simclr_rn50_800ep_simclr_8node_resnet_16_07_20.7e8feed1/model_final_checkpoint_phase799.torch",
"arch": "resnet50",
"type": "vissl"
"type": "vissl",
},
"mocov2-rn50": {
"url": "https://dl.fbaipublicfiles.com/vissl/model_zoo/moco_v2_1node_lr.03_step_b32_zero_init/model_final_checkpoint_phase199.torch",
"arch": "resnet50",
"type": "vissl"
"type": "vissl",
},
"jigsaw-rn50": {
"url": "https://dl.fbaipublicfiles.com/vissl/model_zoo/jigsaw_rn50_in1k_ep105_perm2k_jigsaw_8gpu_resnet_17_07_20.db174a43/model_final_checkpoint_phase104.torch",
"arch": "resnet50",
"type": "vissl"
"type": "vissl",
},
"rotnet-rn50": {
"url": "https://dl.fbaipublicfiles.com/vissl/model_zoo/rotnet_rn50_in1k_ep105_rotnet_8gpu_resnet_17_07_20.46bada9f/model_final_checkpoint_phase125.torch",
"arch": "resnet50",
"type": "vissl"
"type": "vissl",
},
"swav-rn50": {
"url": "https://dl.fbaipublicfiles.com/vissl/model_zoo/swav_in1k_rn50_800ep_swav_8node_resnet_27_07_20.a0a6b676/model_final_checkpoint_phase799.torch",
"arch": "resnet50",
"type": "vissl"
"type": "vissl",
},
"pirl-rn50": {
"url": "https://dl.fbaipublicfiles.com/vissl/model_zoo/pirl_jigsaw_4node_pirl_jigsaw_4node_resnet_22_07_20.34377f59/model_final_checkpoint_phase799.torch",
"arch": "resnet50",
"type": "vissl"
"type": "vissl",
},
"barlowtwins-rn50": {
"repository": "facebookresearch/barlowtwins:main",
"arch": "resnet50",
"type": "hub"
"type": "hub",
},
"vicreg-rn50": {
"repository": "facebookresearch/vicreg:main",
"arch": "resnet50",
"type": "hub"
}
"type": "hub",
},
}

def __init__(
self,
model_name: str,
pretrained: bool,
device: str,
model_path: str = None,
model_parameters: Dict = None,
preprocess: Any = None,
self,
model_name: str,
pretrained: bool,
device: str,
model_path: str = None,
model_parameters: Dict = None,
preprocess: Any = None,
) -> None:
super().__init__(
model_name=model_name,
Expand Down Expand Up @@ -260,16 +260,18 @@ def _download_and_save_model(self, model_url: str, output_model_filepath: str):
torch.save(converted_model, output_model_filepath)
return converted_model

def _replace_module_prefix(self, state_dict: Dict[str, Any],
prefix: str,
replace_with: str = ""):
def _replace_module_prefix(
self, state_dict: Dict[str, Any], prefix: str, replace_with: str = ""
):
"""
Remove prefixes in a state_dict needed when loading models that are not VISSL
trained models.
Specify the prefix in the keys that should be removed.
"""
state_dict = {
(key.replace(prefix, replace_with, 1) if key.startswith(prefix) else key): val
(
key.replace(prefix, replace_with, 1) if key.startswith(prefix) else key
): val
for (key, val) in state_dict.items()
}
return state_dict
Expand All @@ -279,9 +281,16 @@ def _get_torch_home(self):
Gets the torch home folder used as a cache directory for the vissl models.
"""
torch_home = os.path.expanduser(
os.getenv(SSLExtractor.ENV_TORCH_HOME,
os.path.join(os.getenv(SSLExtractor.ENV_XDG_CACHE_HOME,
SSLExtractor.DEFAULT_CACHE_DIR), "torch")))
os.getenv(
SSLExtractor.ENV_TORCH_HOME,
os.path.join(
os.getenv(
SSLExtractor.ENV_XDG_CACHE_HOME, SSLExtractor.DEFAULT_CACHE_DIR
),
"torch",
),
)
)
return torch_home

def load_model_from_source(self) -> None:
Expand All @@ -296,20 +305,24 @@ def load_model_from_source(self) -> None:
model_filepath = os.path.join(cache_dir, self.model_name + ".torch")
if not os.path.exists(model_filepath):
os.makedirs(cache_dir, exist_ok=True)
model_state_dict = self._download_and_save_model(model_url=model_config["url"],
output_model_filepath=model_filepath)
model_state_dict = self._download_and_save_model(
model_url=model_config["url"],
output_model_filepath=model_filepath,
)
else:
model_state_dict = torch.load(model_filepath, map_location=torch.device("cpu"))
model_state_dict = torch.load(
model_filepath, map_location=torch.device("cpu")
)
self.model = getattr(torchvision.models, model_config["arch"])()
self.model.fc = torch.nn.Identity()
self.model.load_state_dict(model_state_dict, strict=True)
elif model_config["type"] == "hub":
self.model = torch.hub.load(model_config["repository"], model_config["arch"])
self.model = torch.hub.load(
model_config["repository"], model_config["arch"]
)
self.model.fc = torch.nn.Identity()
else:
raise ValueError(
f"\nUnknown model type.\n"
)
raise ValueError(f"\nUnknown model type.\n")
else:
raise ValueError(
f"\nCould not find {self.model_name} in the SSLExtractor.\n"
Expand Down
31 changes: 22 additions & 9 deletions thingsvision/core/extraction/helpers.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
from typing import Any, Callable, Dict
from typing import Any, Callable, Dict, Union
from warnings import warn

import numpy as np
import torch

import thingsvision.custom_models as custom_models
import thingsvision.custom_models.cornet as cornet
import torch
from torchtyping import TensorType

from .base import BaseExtractor
from .extractor import KerasExtractor, TimmExtractor, TorchvisionExtractor, SSLExtractor
from .extractor import (KerasExtractor, SSLExtractor, TimmExtractor,
TorchvisionExtractor)
from .mixin import PyTorchMixin, TensorFlowMixin

Tensor = torch.Tensor
Array = np.ndarray
AxisError = np.AxisError

Expand Down Expand Up @@ -60,12 +60,19 @@ def show_model(self):
print("visual")

@staticmethod
def forward(batch: Tensor) -> Tensor:
def forward(batch: TensorType["b", "c", "h", "w"]) -> TensorType["b", "d"]:
img_features = model.encode_image(batch)
return img_features

@staticmethod
def flatten_acts(act: Tensor, batch: Tensor, module_name: str) -> Tensor:
def flatten_acts(
act: Union[
TensorType["b", "num_maps", "h_prime", "w_prime"],
TensorType["b", "t", "d"],
],
batch: TensorType["b", "c", "h", "w"],
module_name: str,
) -> TensorType["b", "p"]:
if module_name.endswith("attn"):
if isinstance(act, tuple):
act = act[0]
Expand All @@ -81,7 +88,9 @@ def flatten_acts(act: Tensor, batch: Tensor, module_name: str) -> Tensor:

if model_name == "OpenCLIP":

def forward(self, batch: Tensor) -> Tensor:
def forward(
self, batch: TensorType["b", "c", "h", "w"]
) -> TensorType["b", "d"]:
return self.model(batch, text=None)

CustomExtractor.forward = forward
Expand Down Expand Up @@ -179,7 +188,11 @@ def get_extractor(
elif source == "ssl":
return SSLExtractor(**model_args)
elif source == "vissl":
warn('The source "vissl" is deprecated. Use the source "ssl" instead.', DeprecationWarning, stacklevel=2)
warn(
'The source "vissl" is deprecated. Use the source "ssl" instead.',
DeprecationWarning,
stacklevel=2,
)
return SSLExtractor(**model_args)
else:
raise ValueError(
Expand Down
Loading

0 comments on commit b4b53dc

Please sign in to comment.