
Merge pull request #178 from ViCCo-Group/fix/ssl-models-cache
Fixed a bug where different SSL models loaded from VISSL ended up with the same cached weights
LukasMut authored Sep 9, 2024
2 parents ad903f3 + e2a7892 commit 32b1a96
Showing 1 changed file with 47 additions and 46 deletions.
93 changes: 47 additions & 46 deletions thingsvision/core/extraction/extractors.py
@@ -9,10 +9,7 @@
import torch
import torchvision

-try:
-    from torch.hub import load_state_dict_from_url
-except ImportError:
-    from torch.utils.model_zoo import load_url as load_state_dict_from_url
+from torch.hub import load_state_dict_from_url

from thingsvision.utils.checkpointing import get_torch_home
from thingsvision.utils.models.dino import vit_base, vit_small, vit_tiny
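
Context on load_state_dict_from_url's caching behaviour (not part of the diff, but relevant to the fix further down): torch.hub.load_state_dict_from_url caches a downloaded checkpoint under a file name derived from the last component of the URL unless an explicit file_name is passed. Several VISSL release checkpoints appear to share the same basename, so two different models could resolve to the same cache file and silently reuse each other's weights. A minimal sketch of the collision, with hypothetical URLs:

    import os
    from urllib.parse import urlparse

    from torch.hub import get_dir

    # Hypothetical VISSL checkpoint URLs: different models, identical basename.
    URL_A = "https://example.com/simclr_rn50/model_final_checkpoint_phase799.torch"
    URL_B = "https://example.com/swav_rn50/model_final_checkpoint_phase799.torch"

    def default_cache_path(url: str) -> str:
        # Mirrors how load_state_dict_from_url picks its cache file when no
        # file_name is given: the URL basename inside <hub_dir>/checkpoints.
        filename = os.path.basename(urlparse(url).path)
        return os.path.join(get_dir(), "checkpoints", filename)

    print(default_cache_path(URL_A) == default_cache_path(URL_B))  # True -> collision

Passing a unique file_name, as the updated _download_and_save_model below does, gives every model its own cache file.
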
@@ -45,13 +42,13 @@

class TorchvisionExtractor(PyTorchExtractor):
    def __init__(
-        self,
-        model_name: str,
-        pretrained: bool,
-        device: str,
-        model_path: str = None,
-        model_parameters: Dict[str, Union[str, bool, List[str]]] = None,
-        preprocess: Optional[Callable] = None,
+        self,
+        model_name: str,
+        pretrained: bool,
+        device: str,
+        model_path: str = None,
+        model_parameters: Dict[str, Union[str, bool, List[str]]] = None,
+        preprocess: Optional[Callable] = None,
    ) -> None:
        model_parameters = (
            model_parameters if model_parameters else {"weights": "DEFAULT"}
@@ -96,12 +93,12 @@ def load_model_from_source(self) -> None:
            )

    def get_default_transformation(
-        self,
-        mean,
-        std,
-        resize_dim: int = 256,
-        crop_dim: int = 224,
-        apply_center_crop: bool = True,
+        self,
+        mean,
+        std,
+        resize_dim: int = 256,
+        crop_dim: int = 224,
+        apply_center_crop: bool = True,
    ) -> Any:
        if self.weights:
            warnings.warn(
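
An illustrative aside on the warning raised here (not code from this diff): when a torchvision model is constructed from a weights enum, the matching preprocessing pipeline can be obtained from the weights object itself, which is presumably why user-supplied resize/crop arguments are ignored. A minimal sketch:

    import torchvision

    # The pretrained weights carry their own preprocessing recipe.
    weights = torchvision.models.ResNet50_Weights.DEFAULT
    model = torchvision.models.resnet50(weights=weights)
    preprocess = weights.transforms()  # resize, center crop, and normalization presets
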
@@ -120,13 +117,13 @@ def get_default_transformation(

class TimmExtractor(PyTorchExtractor):
    def __init__(
-        self,
-        model_name: str,
-        pretrained: bool,
-        device: str,
-        model_path: str = None,
-        model_parameters: Dict[str, Union[str, bool, List[str]]] = None,
-        preprocess: Optional[Callable] = None,
+        self,
+        model_name: str,
+        pretrained: bool,
+        device: str,
+        model_path: str = None,
+        model_parameters: Dict[str, Union[str, bool, List[str]]] = None,
+        preprocess: Optional[Callable] = None,
    ) -> None:
        super().__init__(
            model_name=model_name,
@@ -147,12 +144,12 @@ def load_model_from_source(self) -> None:
            )

    def get_default_transformation(
-        self,
-        mean,
-        std,
-        resize_dim: int = 256,
-        crop_dim: int = 224,
-        apply_center_crop: bool = True,
+        self,
+        mean,
+        std,
+        resize_dim: int = 256,
+        crop_dim: int = 224,
+        apply_center_crop: bool = True,
    ) -> Any:
        warnings.warn(
            message="\nInput arguments are ignored because <timm> automatically infers transforms from model config.\n",
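
For context on this warning (a sketch of the underlying timm mechanism, not code from this diff): timm can resolve the preprocessing configuration directly from a model's pretrained config, so externally supplied mean/std or crop sizes are not needed.

    import timm
    from timm.data import create_transform, resolve_data_config

    model = timm.create_model("resnet50", pretrained=True)
    config = resolve_data_config({}, model=model)  # input size, mean, std, interpolation
    transform = create_transform(**config)
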
@@ -167,13 +164,13 @@ def get_default_transformation(

class KerasExtractor(TensorFlowExtractor):
    def __init__(
-        self,
-        model_name: str,
-        pretrained: bool,
-        device: str,
-        model_path: str = None,
-        model_parameters: Dict[str, Union[str, bool, List[str]]] = None,
-        preprocess: Optional[Callable] = None,
+        self,
+        model_name: str,
+        pretrained: bool,
+        device: str,
+        model_path: str = None,
+        model_parameters: Dict[str, Union[str, bool, List[str]]] = None,
+        preprocess: Optional[Callable] = None,
    ) -> None:
        model_parameters = (
            model_parameters if model_parameters else {"weights": "imagenet"}
@@ -336,13 +333,13 @@ class SSLExtractor(PyTorchExtractor):
    }

    def __init__(
-        self,
-        model_name: str,
-        pretrained: bool,
-        device: str,
-        model_path: str = None,
-        model_parameters: Dict[str, Union[str, bool, List[str]]] = None,
-        preprocess: Optional[Callable] = None,
+        self,
+        model_name: str,
+        pretrained: bool,
+        device: str,
+        model_path: str = None,
+        model_parameters: Dict[str, Union[str, bool, List[str]]] = None,
+        preprocess: Optional[Callable] = None,
    ) -> None:
        super().__init__(
            model_name=model_name,
@@ -353,12 +350,15 @@ def __init__(
            device=device,
        )

-    def _download_and_save_model(self, model_url: str, output_model_filepath: str):
+    def _download_and_save_model(self, model_url: str,
+                                 output_model_filepath: str, unique_model_id: str):
        """
        Downloads the model in vissl format, converts it to torchvision format and
        saves it under output_model_filepath.
        """
-        model = load_state_dict_from_url(model_url, map_location=torch.device("cpu"))
+        model = load_state_dict_from_url(model_url,
+                                         map_location=torch.device("cpu"),
+                                         file_name=f'{unique_model_id}.pt')

        # get the model trunk to rename
        if "classy_state_dict" in model.keys():
@@ -403,6 +403,7 @@ def load_model_from_source(self) -> None:
                    model_state_dict = self._download_and_save_model(
                        model_url=model_config["url"],
                        output_model_filepath=model_filepath,
+                        unique_model_id=f'thingsvision_vissl_{self.model_name}'
                    )
                else:
                    model_state_dict = torch.load(
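
A usage-level sketch of the effect of this fix (model names and attributes are assumptions based on thingsvision's documented SSL interface, not taken from this diff): each VISSL checkpoint is now cached under its own thingsvision_vissl_<model_name>.pt file, so two different SSL extractors no longer end up holding identical weights.

    import torch
    from thingsvision import get_extractor

    # Assumed names from thingsvision's SSL model registry; adjust if they differ.
    simclr = get_extractor(model_name="simclr-rn50", source="ssl", device="cpu", pretrained=True)
    swav = get_extractor(model_name="swav-rn50", source="ssl", device="cpu", pretrained=True)

    # Before the fix both extractors could receive the same cached weights;
    # afterwards their backbone parameters should differ.
    print(torch.allclose(simclr.model.conv1.weight, swav.model.conv1.weight))  # expected: False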
