
Commit

merge: pull in mlfoundations#523 for text locking
Interpause committed May 23, 2024
2 parents 2e8de83 + 6563d9d commit 7acf1fd
Showing 2 changed files with 37 additions and 1 deletion.
5 changes: 4 additions & 1 deletion src/open_clip/model.py
@@ -19,7 +19,7 @@
from .modified_resnet import ModifiedResNet
from .timm_model import TimmModel
from .transformer import LayerNormFp32, LayerNorm, QuickGELU, Attention, VisionTransformer, TextTransformer,\
    text_global_pool
    text_global_pool, lock_text_transformer
from .utils import to_2tuple


@@ -257,6 +257,9 @@ def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
        # lock image tower as per LiT - https://arxiv.org/abs/2111.07991
        self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)

    def lock_text_tower(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
        lock_text_transformer(self, unlocked_layers, freeze_layer_norm)

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        self.visual.set_grad_checkpointing(enable)
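For context (not part of the diff): a minimal sketch of how the new CLIP.lock_text_tower hook can be combined with the existing lock_image_tower for a LiT-style setup. The model name and pretrained tag are only illustrative.

import open_clip

# Any config that uses the native CLIP text tower works here; names are illustrative.
model, _, _ = open_clip.create_model_and_transforms('ViT-B-32', pretrained='laion2b_s34b_b79k')

# Freeze the image tower entirely and the text tower except its last group,
# which with unlocked_layers=1 is the text projection.
model.lock_image_tower(unlocked_groups=0, freeze_bn_stats=True)
model.lock_text_tower(unlocked_layers=1, freeze_layer_norm=True)

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"trainable parameters after locking: {trainable}")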
33 changes: 33 additions & 0 deletions src/open_clip/transformer.py
@@ -625,6 +625,9 @@ def __init__(

        self.init_parameters()

    def lock(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
        lock_text_transformer(self, unlocked_layers, freeze_layer_norm)

    def init_parameters(self):
        nn.init.normal_(self.token_embedding.weight, std=0.02)
        nn.init.normal_(self.positional_embedding, std=0.01)
@@ -708,6 +711,36 @@ def forward(self, text):
        return pooled


def lock_text_transformer(
    transformer: TextTransformer, unlocked_layers: int = 0, freeze_layer_norm: bool = True
):
    groups = [
        [transformer.token_embedding, transformer.positional_embedding],
        *transformer.transformer.resblocks[:-1],
        [transformer.transformer.resblocks[-1], transformer.ln_final],
        transformer.text_projection,
    ]

    def _freeze(modules, freeze_layer_norm: bool = True):
        for module in modules:
            # Nested groups such as `[token_embedding, positional_embedding]`
            if isinstance(module, (list, tuple)):
                _freeze(module, freeze_layer_norm)

            # `CLIP.text_projection` and `CLIP.positional_embedding`
            elif isinstance(module, nn.Parameter):
                module.requires_grad = False

            # All other modules
            elif isinstance(module, nn.Module):
                for n, p in module.named_parameters():
                    p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False

            else:
                raise TypeError(f"Encountered unexpected module type {type(module)} for module {module}")

    if (not unlocked_layers) or (unlocked_layers == 0):  # full freezing
        _freeze(groups, freeze_layer_norm)
    else:
        _freeze(groups[:-unlocked_layers], freeze_layer_norm)


class MultimodalTransformer(Transformer):
    def __init__(
            self,
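A rough, illustrative check of what lock_text_transformer leaves trainable on a standalone TextTransformer; the constructor arguments below mirror the usual open_clip defaults and are assumptions, not part of this change.

from open_clip.transformer import TextTransformer, lock_text_transformer

# Small stand-alone text tower; values mirror the usual open_clip defaults.
text = TextTransformer(context_length=77, vocab_size=49408, width=512, heads=8, layers=12)

# Freeze everything except the last two groups: the final resblock + ln_final
# and the text projection.
lock_text_transformer(text, unlocked_layers=2, freeze_layer_norm=True)

n_trainable = sum(p.numel() for p in text.parameters() if p.requires_grad)
n_total = sum(p.numel() for p in text.parameters())
print(f"{n_trainable}/{n_total} parameters remain trainable")
print("token_embedding frozen:", not text.token_embedding.weight.requires_grad)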
