-
Notifications
You must be signed in to change notification settings - Fork 1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Enable passing output_hidden_states
#731
base: main
Are you sure you want to change the base?
Conversation
Update coca_model.py
# Conflicts: # src/open_clip/coca_model.py # src/open_clip/model.py # src/open_clip/transformer.py
@thepowerfuldeez thanks for the PR, will say that we need to do this one carefully as it impacts the output interface. I recognize that people want this, but it's been slow to be added because it's a bit of a mess when you consider all the details. First things first, I feel we should only allow this if dictionary output is enabled, having too many tuple variations as possible outputs is asking for trouble. Next, the internal typing has gotchas with torchscript when you alternate between Tuple and tensor outputs. Not quite sure what the needed combination of typing would be to have that pass. |
Hi @rwightman ! I’m on the same page with you, are it should be supported as a dict output. I couldn’t decide how to better use it considering dict output appears only in 1 place, where I needed to have output of VisionTransformer to output hidden states (I am not using CLIP class and hence implemented logic with setting attribute for transformer classes). |
Hi @thepowerfuldeez , thanks a lot for this PR, it has been really useful. However, I have some doubts when using the
But then, using that (1024,) sized CLS embedding is not resulting in good classification metrics for my task (~50% accuracy), while with the Thanks! |
Related to #657
Inspired by PR above, I made PR without breaking backward compatibility. In addition, I made support for passing output_hidden_states as attribute for VisionTransformer and TextTransformer classes.
Example: