Removes separate VISSL caching and adds file_name to torch.hub.load_state_dict_from_url everywhere #179
Codecov Report
Attention: Patch coverage is
Additional details and impacted files
@@ Coverage Diff @@
## master #179 +/- ##
==========================================
- Coverage 76.30% 76.26% -0.04%
==========================================
Files 40 40
Lines 2055 2056 +1
Branches 262 263 +1
==========================================
Hits 1568 1568
Misses 402 402
- Partials 85 86 +1
@jonasd4 Could you provide a meaningful description of the PR? It doesn't have to be long; a one-liner is fine.
Docstring and filename extension clarification requested
@@ -350,15 +350,14 @@ def __init__(
             device=device,
         )

-    def _download_and_save_model(self, model_url: str,
-                                 output_model_filepath: str, unique_model_id: str):
+    def _load_vissl_state_dict(self, model_url: str, unique_model_filename: str):
         """
         Downloads the model in vissl format, converts it to torchvision format and
Maybe adapt the docstring and mention that `load_state_dict_from_url` uses a cached version if one is available.
I adapted the docstring, please have a look if that's clear now!
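For readers following the thread, the "use the cached version if available" behavior the docstring should describe can be sketched roughly as below. This is a standard-library illustration only, not the thingsvision or torch.hub implementation; `load_with_cache`, `fake_download`, and the example URL are hypothetical names introduced for the sketch.

```python
import os
import tempfile


def load_with_cache(url, cache_dir, file_name, download):
    """Return the path of the cached file, downloading only on a cache miss.

    `download` is a hypothetical callable taking (url, destination_path).
    """
    os.makedirs(cache_dir, exist_ok=True)
    cached_file = os.path.join(cache_dir, file_name)
    if not os.path.exists(cached_file):
        download(url, cached_file)
    return cached_file


calls = []


def fake_download(url, dst):
    # Stand-in for a real download so the sketch runs offline.
    calls.append(url)
    with open(dst, "wb") as f:
        f.write(b"weights")


with tempfile.TemporaryDirectory() as tmp:
    url = "https://example.com/model.pth"  # hypothetical URL
    first = load_with_cache(url, tmp, "model.pth", fake_download)
    second = load_with_cache(url, tmp, "model.pth", fake_download)
```

The second call finds the file already on disk, so the downloader runs only once.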
@@ -394,25 +392,25 @@ def load_model_from_source(self) -> None:
         Otherwise, loads it from the cache directory.
         """
         if self.model_name in SSLExtractor.MODELS:

             # unique model id name for all models
             unique_model_filename = f'thingsvision_ssl_v0_{self.model_name}'
The `unique_model_filename` no longer has an extension because it is set neither in `_load_vissl_state_dict` nor in `load_state_dict_from_url`. Is this the desired behavior?
Yes, it is better with an extension! Added it now.
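A minimal sketch of what a filename with an extension could look like, building on the `thingsvision_ssl_v0_` prefix from the diff. The `.pth` extension, the helper name `make_unique_model_filename`, and the example model name are assumptions for illustration; the thread does not say which extension was actually added.

```python
def make_unique_model_filename(model_name, extension=".pth"):
    """Build a cache filename from the prefix in the diff plus an extension.

    The default extension is an assumption, not taken from the PR.
    """
    return f"thingsvision_ssl_v0_{model_name}{extension}"


# Example model name used purely for illustration.
example_filename = make_unique_model_filename("simclr-rn50")
```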
@jonasd4 LGTM but please add a short description before I approve.
Added the description!
LGTM
This PR fixes further issues with `torch.hub.load_state_dict_from_url`:
- … `torch.hub.load_state_dict_from_url` does not perform any caching)
- … `barlowtwins` and `vicreg` models, as their respective `hubconf.py` also does not set the `file_name` argument and the filename in the URL was the same (e.g. https://github.com/facebookresearch/barlowtwins/blob/main/hubconf.py). These models are now loaded directly from the respective URL, with the `file_name` argument provided in the load function.
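The filename clash described above can be illustrated with a small standard-library sketch. To my understanding, `torch.hub.load_state_dict_from_url` derives the cache filename from the URL's basename unless `file_name` is passed; `cached_path` below only mirrors that naming rule and is not the library's code. The URLs, the cache directory, and the explicit filenames are hypothetical.

```python
import os
from urllib.parse import urlparse


def cached_path(url, model_dir, file_name=None):
    """Pick a cache path: URL basename by default, or an explicit file_name."""
    filename = os.path.basename(urlparse(url).path)
    if file_name is not None:
        filename = file_name
    return os.path.join(model_dir, filename)


# Hypothetical URLs that end in the same basename.
BARLOW_URL = "https://example.com/barlowtwins/resnet50.pth"
VICREG_URL = "https://example.com/vicreg/resnet50.pth"
MODEL_DIR = os.path.join("hub", "checkpoints")  # assumed cache directory

# Without file_name, both URLs map to the same cache file.
collides = cached_path(BARLOW_URL, MODEL_DIR) == cached_path(VICREG_URL, MODEL_DIR)

# With distinct file_name values, each model gets its own cache file.
barlow_file = cached_path(BARLOW_URL, MODEL_DIR, "thingsvision_ssl_v0_barlowtwins.pth")
vicreg_file = cached_path(VICREG_URL, MODEL_DIR, "thingsvision_ssl_v0_vicreg.pth")
```

When two checkpoints share a cache path, whichever is downloaded first would be reused for the other model, which is why a distinct `file_name` per model matters.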