diff --git a/torchtune/modules/vision_transformer.py b/torchtune/modules/vision_transformer.py index 6f261514b6..b203c3f59d 100644 --- a/torchtune/modules/vision_transformer.py +++ b/torchtune/modules/vision_transformer.py @@ -381,6 +381,9 @@ def forward( h = x.reshape(bsz, n_imgs, n_tiles, n_tokens, embed_dim) hidden_states.append(h) x = transformer_layer(x) + if layer_idx in self.out_indices: + h = x.reshape(bsz, n_imgs, n_tiles, n_tokens, embed_dim) + hidden_states.append(h) # norm x = self.ln_post(x)