22 changes: 20 additions & 2 deletions clip_benchmark/models/__init__.py
@@ -2,11 +2,16 @@
import torch
from .open_clip import load_open_clip
from .japanese_clip import load_japanese_clip
from .synthclip import load_synthclip
from .scaling import load_model

# loading function must return (model, transform, tokenizer)
TYPE2FUNC = {
"open_clip": load_open_clip,
"ja_clip": load_japanese_clip
"ja_clip": load_japanese_clip,
"synthclip": load_synthclip,
"scaling": load_model,
"auto": None,
}
MODEL_TYPES = list(TYPE2FUNC.keys())

@@ -19,5 +24,18 @@ def load_clip(
device: Union[str, torch.device] = "cuda"
):
assert model_type in MODEL_TYPES, f"model_type={model_type} is invalid!"
load_func = TYPE2FUNC[model_type]
if model_type != "auto":
load_func = TYPE2FUNC[model_type]
else:
        # It's a hack, but it works! Have a better way? Push a PR 😃. EOM - Victor
if "synthclip" in pretrained:
load_func = TYPE2FUNC["synthclip"]
elif "scaling" in pretrained:
load_func = TYPE2FUNC["scaling"]
elif pretrained in TYPE2FUNC:
load_func = TYPE2FUNC[pretrained]
else:
print(f"{model_type} and {pretrained=} unsupported defaulting to "
"open_clip. The Lord be with you 🙏")
load_func = TYPE2FUNC["open_clip"]
return load_func(model_name=model_name, pretrained=pretrained, cache_dir=cache_dir, device=device)
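A minimal usage sketch of the new "auto" dispatch, assuming load_clip keeps the keyword arguments shown in this diff; the checkpoint path below is hypothetical:

import torch
from clip_benchmark.models import load_clip

# "synthclip" in the checkpoint path routes to load_synthclip, "scaling" to
# scaling.load_model; anything unrecognized falls back to open_clip.
model, transform, tokenizer = load_clip(
    model_type="auto",
    model_name="ViT-B-16",
    pretrained="/checkpoints/synthclip/checkpoint_best.pt",  # hypothetical path
    cache_dir=None,
    device="cuda" if torch.cuda.is_available() else "cpu",
)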
51 changes: 51 additions & 0 deletions clip_benchmark/models/scaling.py
@@ -0,0 +1,51 @@
"""Load Scaling laws for synthetic data

References
- Scaling Laws of Synthetic Images for Model Training ... for Now, Fan et al., CVPR 2024
- https://github.com/google-research/syn-rep-learn/tree/main/Scaling#clip
"""
import numpy as np
import torch
import torch.backends.cudnn as cudnn

from open_clip import create_model_and_transforms, get_tokenizer


def load_model(model_name: str = 'ViT-B-16', pretrained: str = None,
               device: str = "cpu", cudnn_benchmark: bool = True, **kwargs):
    # `model_name` matches the keyword passed in by clip_benchmark.models.load_clip.
    if pretrained is None:
        raise FileNotFoundError('Failing early: no `pretrained` checkpoint path was given!')

    tokenizer = get_tokenizer(model_name)
    model, preprocess_train, preprocess_val = create_model_and_transforms(
        model_name,
        '',
        precision='amp',
        device=device,  # respect the requested device instead of hardcoding 'cuda'
jit=False,
force_quick_gelu=True,
force_custom_text=False,
force_patch_dropout=None,
force_image_size=224,
pretrained_image=False,
image_mean=None,
image_std=None,
aug_cfg={},
output_dict=True,
)
device = "cuda:0" if torch.cuda.is_available() else "cpu"
state_dict = torch.load(pretrained, map_location=device)
logit_scale = np.exp(state_dict['logit_scale'].item())
msg = model.load_state_dict(state_dict, strict=True)
print(msg)
model = model.to(device)
model.eval()
cudnn.benchmark = cudnn_benchmark
return model, preprocess_val, tokenizer


if __name__ == '__main__':
    # Smoke test: `pretrained` (not `ckpt`) is the keyword this loader expects.
    load_model(pretrained='./logs/scaling_syn_real/371M.pt')
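A short usage sketch for the scaling loader; the checkpoint path is hypothetical and mirrors the __main__ example above:

import torch
from clip_benchmark.models.scaling import load_model

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess_val, tokenizer = load_model(
    model_name="ViT-B-16",
    pretrained="./logs/scaling_syn_real/371M.pt",  # hypothetical checkpoint path
    device=device,
)
with torch.no_grad():
    text = tokenizer(["a photo of a dog"]).to(device)
    text_features = model.encode_text(text)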
245 changes: 245 additions & 0 deletions clip_benchmark/models/synthclip.py
@@ -0,0 +1,245 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# Modified from github.com/openai/CLIP
from collections import OrderedDict
from pathlib import Path

import numpy as np
import timm
import torch
import torchvision.transforms as transforms
import open_clip
from torch import nn

# `losses` is the training-time module from the SynthCLIP repo (provides CLIPLoss);
# it is only needed by get_loss() below, not for benchmarking.
# import losses


class LayerNorm(nn.LayerNorm):
"""Subclass torch's LayerNorm to handle fp16."""

def forward(self, x: torch.Tensor):
orig_type = x.dtype
ret = super().forward(x.type(torch.float32))
return ret.type(orig_type)


class QuickGELU(nn.Module):
def forward(self, x: torch.Tensor):
return x * torch.sigmoid(1.702 * x)


class ResidualAttentionBlock(nn.Module):
def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
super().__init__()

self.attn = nn.MultiheadAttention(d_model, n_head)
self.ln_1 = LayerNorm(d_model)
self.mlp = nn.Sequential(
OrderedDict(
[
("c_fc", nn.Linear(d_model, d_model * 4)),
("gelu", QuickGELU()),
("c_proj", nn.Linear(d_model * 4, d_model)),
]
)
)
self.ln_2 = LayerNorm(d_model)
self.attn_mask = attn_mask

def attention(self, x: torch.Tensor):
self.attn_mask = (
self.attn_mask.to(dtype=x.dtype, device=x.device)
if self.attn_mask is not None
else None
)
return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]

def forward(self, x: torch.Tensor):
x = x + self.attention(self.ln_1(x))
x = x + self.mlp(self.ln_2(x))
return x


class Transformer(nn.Module):
def __init__(
self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None
):
super().__init__()
self.width = width
self.layers = layers
self.resblocks = nn.Sequential(
*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]
)

def forward(self, x: torch.Tensor):
return self.resblocks(x)


class CLIP(nn.Module):
def __init__(
self,
embed_dim: int,
# vision
vision_width: int,
vision_model: nn.Module,
# text
context_length: int,
vocab_size: int,
transformer_width: int,
transformer_heads: int,
transformer_layers: int,
**kwargs,
):
super().__init__()

self.context_length = context_length
self.vision_width = vision_width

self.visual = vision_model

self.transformer = Transformer(
width=transformer_width,
layers=transformer_layers,
heads=transformer_heads,
attn_mask=self.build_attention_mask(),
)

self.vocab_size = vocab_size
self.token_embedding = nn.Embedding(vocab_size, transformer_width)
self.positional_embedding = nn.Parameter(
torch.empty(self.context_length, transformer_width)
)
self.ln_final = LayerNorm(transformer_width)

self.image_projection = nn.Parameter(torch.empty(vision_width, embed_dim))
self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

self.initialize_parameters()

def initialize_parameters(self):
nn.init.normal_(self.token_embedding.weight, std=0.02)
nn.init.normal_(self.positional_embedding, std=0.01)

proj_std = (self.transformer.width**-0.5) * (
(2 * self.transformer.layers) ** -0.5
)
attn_std = self.transformer.width**-0.5
fc_std = (2 * self.transformer.width) ** -0.5
for block in self.transformer.resblocks:
nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)

nn.init.normal_(self.image_projection, std=self.vision_width**-0.5)
nn.init.normal_(self.text_projection, std=self.transformer.width**-0.5)

def build_attention_mask(self):
        # Causal attention mask for the text transformer.
        # PyTorch uses an additive attention mask; fill with -inf above the diagonal.
        mask = torch.empty(self.context_length, self.context_length)
        mask.fill_(float("-inf"))
        mask.triu_(1)  # keep -inf strictly above the diagonal; zero on and below it
return mask

def encode_image(self, image):
x = self.visual(image)
x = x @ self.image_projection

return x

def encode_text(self, text):
x = self.token_embedding(text) # [batch_size, n_ctx, d_model]
x = x + self.positional_embedding
x = x.permute(1, 0, 2) # NLD -> LND
x = self.transformer(x)
x = x.permute(1, 0, 2) # LND -> NLD
x = self.ln_final(x)

# x.shape = [batch_size, n_ctx, transformer.width]
# take features from the eot embedding (eot_token is the highest number in each sequence)
x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection

return x

def forward(self, image, text):
image_embed = self.encode_image(image)
text_embed = self.encode_text(text)

return {
"image_embed": image_embed,
"text_embed": text_embed,
"logit_scale": self.logit_scale.exp(),
}


def get_loss(gather_with_grad=False):
    # Training-only helper: requires the SynthCLIP `losses` module (see the commented import above).
    return losses.CLIPLoss(gather_with_grad=gather_with_grad)


def get_metric_names():
return ["loss", "clip_loss", "clip_acc"]


def CLIP_VITB16(checkpoint_path: str = None, cache_dir: str = None, **kwargs):
vision_model = timm.create_model("vit_base_patch16_224", num_classes=0,
checkpoint_path=checkpoint_path,
cache_dir=cache_dir)
model = CLIP(
embed_dim=512,
vision_width=768,
vision_model=vision_model,
context_length=77,
vocab_size=49408,
transformer_width=512,
transformer_heads=8,
transformer_layers=12,
**kwargs,
)

return model


def load_synthclip(model_name, pretrained, cache_dir, device):
    if model_name == "ViT-B-16":
        model = CLIP_VITB16()
    else:
        raise ValueError(f"Unsupported SynthCLIP model_name: {model_name}")
    tokenizer = open_clip.get_tokenizer(model_name)

    if pretrained:
        pretrained = Path(pretrained)
        pretrained = pretrained / "checkpoint_best.pt" if pretrained.is_dir() else pretrained
        state_dict = torch.load(pretrained, map_location="cpu")["state_dict"]
        new_state_dict = OrderedDict()
        for k, v in state_dict.items():
            # Strip the DistributedDataParallel "module." prefix from checkpoint keys.
            name = k.replace("module.", "")
            new_state_dict[name] = v

        load_status = model.load_state_dict(new_state_dict)
        print(f'{__name__}:{load_synthclip.__name__}, {model_name=}, {pretrained=}, {device=}, {load_status=}')

    model.to(device)
    model.eval()
    val_transform = transform_pipeline()
    return model, val_transform, tokenizer


def transform_pipeline():
# Taken from
# https://github.com/hammoudhasan/SynthCLIP/blob/02ef69764d8dc921650bcac4a98bd0f477790787/Training/main.py#L240
normalize = transforms.Normalize(
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
)
transform = transforms.Compose(
[
transforms.Lambda(lambda img: img.convert('RGB')),
transforms.Resize((224, 224)),
transforms.ColorJitter(0.4, 0.4, 0.4),
transforms.ToTensor(),
normalize,
]
)
return transform
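To close the loop, a hedged end-to-end sketch of scoring images with the SynthCLIP loader; the checkpoint directory and image file below are hypothetical:

import torch
import torch.nn.functional as F
from PIL import Image
from clip_benchmark.models.synthclip import load_synthclip

device = "cuda" if torch.cuda.is_available() else "cpu"
model, val_transform, tokenizer = load_synthclip(
    model_name="ViT-B-16",
    pretrained="/checkpoints/synthclip",  # hypothetical; a directory resolves to checkpoint_best.pt
    cache_dir=None,
    device=device,
)
image = val_transform(Image.open("example.jpg")).unsqueeze(0).to(device)
text = tokenizer(["a photo of a dog", "a photo of a cat"]).to(device)
with torch.no_grad():
    out = model(image, text)
    # forward() returns unnormalized embeddings; normalize before computing CLIP logits.
    img = F.normalize(out["image_embed"], dim=-1)
    txt = F.normalize(out["text_embed"], dim=-1)
    logits = out["logit_scale"] * img @ txt.T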