Skip to content

Commit

Permalink
Merge pull request #173 from ViCCo-Group/align_model
Browse files Browse the repository at this point in the history
Kakaobrain Align model
  • Loading branch information
LukasMut authored May 16, 2024
2 parents 49c8f73 + 65349b5 commit 2072f28
Show file tree
Hide file tree
Showing 6 changed files with 50 additions and 2 deletions.
22 changes: 21 additions & 1 deletion docs/AvailableModels.md
Original file line number Diff line number Diff line change
Expand Up @@ -390,4 +390,24 @@ extractor = get_extractor(
pretrained=True,
model_parameters=model_parameters
)
```
```

### ALIGN model

We provide Kakaobrain's reproduction of the original [ALIGN model](https://proceedings.mlr.press/v139/jia21b.html), loaded from the [Hugging Face Hub](https://huggingface.co/kakaobrain/align-base).

```python
# Example: build a thingsvision extractor for Kakaobrain's ALIGN reproduction.
import torch
from thingsvision import get_extractor

# 'Kakaobrain_Align' is provided as a custom model, hence source='custom'.
model_name = 'Kakaobrain_Align'
source = 'custom'
# Use the GPU when one is available; otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

extractor = get_extractor(
model_name=model_name,
source=source,
device=device,
pretrained=True,
)
```
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,5 @@ torchtyping
tqdm
dreamsim==0.1.3
git+https://github.com/openai/CLIP.git
git+https://github.com/serre-lab/Harmonization.git
git+https://github.com/serre-lab/Harmonization.git
transformers==4.40.1
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
"torchtyping",
"tqdm",
"CLIP",
"transformers==4.40.1"
# 'CLIP @ git+ssh://[email protected]/openai/[email protected]#egg=CLIP' # TODO: see issue #111
]

Expand Down
6 changes: 6 additions & 0 deletions tests/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,12 @@
"batch_size": 1,
"num_samples": 1
},
"kakobrain_align_model": {
"model_name": "Kakaobrain_Align",
"modules": ["pooler"],
"pretrained": True,
"source": "custom",
},
}

ALIGNED_MODELS = {
Expand Down
1 change: 1 addition & 0 deletions thingsvision/custom_models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@
from .resnet50_ecoset import Resnet50_ecoset
from .vgg16_ecoset import VGG16_ecoset
from .sam import SegmentAnything
from .align import Kakaobrain_Align
19 changes: 19 additions & 0 deletions thingsvision/custom_models/align.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from typing import Any
from .custom import Custom
from transformers import AlignModel, AutoProcessor


class Kakaobrain_Align(Custom):
    """Custom-model wrapper exposing the vision tower of Kakaobrain's ALIGN.

    Both the model weights and the matching processor are fetched from the
    ``kakaobrain/align-base`` checkpoint via ``transformers``.
    """

    def __init__(self, device, parameters) -> None:
        """Initialise the wrapper for ``device`` and select the PyTorch backend.

        ``parameters`` is accepted for signature compatibility but is not
        read by this model.
        """
        super().__init__(device)
        self.backend = "pt"

    def create_model(self) -> Any:
        """Load the checkpoint and return ``(vision_model, preprocess_fn)``.

        Returns:
            A 2-tuple of the ``vision_model`` sub-module of the loaded
            ``AlignModel`` and a callable that turns image input into the
            processor's ``pixel_values``.
        """
        checkpoint = "kakaobrain/align-base"
        align = AlignModel.from_pretrained(checkpoint)
        image_processor = AutoProcessor.from_pretrained(checkpoint)

        def preprocess_fn(images):
            # Ask the processor for PyTorch tensors and keep only the pixel data.
            encoded = image_processor(images=images, return_tensors="pt")
            pixel_values = encoded["pixel_values"]
            # squeeze(0) removes the leading dimension only when it has size 1
            # (presumably the single-image case -- confirm against the caller).
            return pixel_values.squeeze(0)

        return align.vision_model, preprocess_fn

0 comments on commit 2072f28

Please sign in to comment.