-
Notifications
You must be signed in to change notification settings - Fork 335
Add ViViT(Video Vision Transformer) to KerasCV #2335
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
aditya02shah
wants to merge
14
commits into
keras-team:master
Choose a base branch
from
aditya02shah:vivit
base: master
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
14 commits
Select commit
Hold shift + click to select a range
aaae396
Initialised video-classification/vivit
aditya02shah a1da121
Initialised ViViT model and add dependent layers
aditya02shah 3ccd176
Updated __init__.py
aditya02shah 9a39aa6
Added model construction and call tests
aditya02shah 30e1c8e
Updated imports
aditya02shah 82f06c3
Added tests
aditya02shah 2612a04
Added docs and some minor adjustments
aditya02shah 8099869
Updated Documentation and Default Parameters
aditya02shah 13f4829
Updated comments
aditya02shah 0b6043f
Updating parameters and build method
aditya02shah f6b2f7c
Merge branch 'keras-team:master' into vivit
aditya02shah 9536d35
Updated build.sh
aditya02shah 9f174d4
Merge branch 'keras-team:master' into vivit
aditya02shah 36541bb
Updated Build Method
aditya02shah File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
# Copyright 2024 The KerasCV Authors | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from keras_cv.models.video_classification.vivit import ViViT |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,197 @@ | ||
# Copyright 2024 The KerasCV Authors | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from keras_cv.api_export import keras_cv_export | ||
from keras_cv.backend import keras | ||
from keras_cv.models.task import Task | ||
from keras_cv.models.video_classification.vivit_layers import PositionalEncoder | ||
from keras_cv.models.video_classification.vivit_layers import TubeletEmbedding | ||
|
||
|
||
@keras_cv_export( | ||
[ | ||
"keras_cv.models.ViViT", | ||
"keras_cv.models.video_classification.ViViT", | ||
] | ||
) | ||
class ViViT(Task): | ||
"""A Keras model implementing a Video Vision Transformer | ||
for video classification. | ||
|
||
References: | ||
- [ViViT: A Video Vision Transformer](https://arxiv.org/abs/2103.15691) | ||
(ICCV 2021) | ||
|
||
Args: | ||
inp_shape: tuple, the shape of the input video frames. | ||
num_classes: int, the number of classes for video classification. | ||
transformer_layers: int, the number of transformer layers in the model. | ||
Defaults to 8. | ||
patch_size: tuple , contains the size of the | ||
spatio-temporal patches for each dimension | ||
Defaults to (8,8,8) | ||
num_heads: int, the number of heads for multi-head | ||
self-attention mechanism. Defaults to 8. | ||
projection_dim: int, number of dimensions in the projection space. | ||
Defaults to 128. | ||
layer_norm_eps: float, epsilon value for layer normalization. | ||
Defaults to 1e-6. | ||
|
||
|
||
Examples: | ||
```python | ||
import keras_cv | ||
|
||
INPUT_SHAPE = (32, 32, 32, 1) | ||
NUM_CLASSES = 11 | ||
PATCH_SIZE = (8, 8, 8) | ||
LAYER_NORM_EPS = 1e-6 | ||
PROJECTION_DIM = 128 | ||
NUM_HEADS = 8 | ||
NUM_LAYERS = 8 | ||
|
||
frames = np.random.uniform(size=(5, 32, 32, 32, 1)) | ||
labels = np.ones(shape=(5)) | ||
|
||
# Instantiate Model | ||
model = ViViT( | ||
projection_dim=PROJECTION_DIM, | ||
patch_size=PATCH_SIZE, | ||
inp_shape=INPUT_SHAPE, | ||
transformer_layers=NUM_LAYERS, | ||
num_heads=NUM_HEADS, | ||
layer_norm_eps=LAYER_NORM_EPS, | ||
num_classes=NUM_CLASSES, | ||
) | ||
|
||
# Compile model | ||
model.compile( | ||
optimizer="adam", | ||
loss="sparse_categorical_crossentropy", | ||
metrics=[ | ||
keras.metrics.SparseCategoricalAccuracy(name="accuracy"), | ||
], | ||
) | ||
|
||
# Build Model | ||
model.build(INPUT_SHAPE) | ||
|
||
# Train Model | ||
model.fit(frames, labels, epochs=3) | ||
|
||
``` | ||
""" | ||
|
||
def __init__( | ||
self, | ||
inp_shape, | ||
num_classes, | ||
projection_dim=128, | ||
patch_size=(8, 8, 8), | ||
transformer_layers=8, | ||
num_heads=8, | ||
layer_norm_eps=1e-6, | ||
**kwargs, | ||
): | ||
super().__init__(**kwargs) | ||
|
||
self.projection_dim = projection_dim | ||
self.patch_size = patch_size | ||
self.tubelet_embedder = TubeletEmbedding( | ||
embed_dim=self.projection_dim, patch_size=self.patch_size | ||
) | ||
|
||
self.positional_encoder = PositionalEncoder( | ||
embed_dim=self.projection_dim | ||
) | ||
self.layer_norm = keras.layers.LayerNormalization( | ||
epsilon=layer_norm_eps | ||
) | ||
self.attention_output = keras.layers.MultiHeadAttention( | ||
num_heads=num_heads, | ||
key_dim=projection_dim // num_heads, | ||
dropout=0.1, | ||
) | ||
self.dense_1 = keras.layers.Dense( | ||
units=projection_dim * 4, activation=keras.ops.gelu | ||
) | ||
|
||
self.dense_2 = keras.layers.Dense( | ||
units=projection_dim, activation=keras.ops.gelu | ||
) | ||
self.add = keras.layers.Add() | ||
self.pooling = keras.layers.GlobalAvgPool1D() | ||
self.dense_output = keras.layers.Dense( | ||
units=num_classes, activation="softmax" | ||
) | ||
|
||
self.inp_shape = inp_shape | ||
self.num_heads = num_heads | ||
self.num_classes = num_classes | ||
self.projection_dim = projection_dim | ||
self.patch_size = patch_size | ||
self.transformer_layers = transformer_layers | ||
|
||
def build(self, input_shape): | ||
super().build(input_shape) | ||
self.tubelet_embedder.build(input_shape) | ||
flattened_patch_shape = self.tubelet_embedder.compute_output_shape( | ||
input_shape | ||
) | ||
self.positional_encoder.build(flattened_patch_shape) | ||
self.layer_norm.build([None, None, self.projection_dim]) | ||
self.attention_output.build( | ||
query_shape=[None, None, self.projection_dim], | ||
value_shape=[None, None, self.projection_dim], | ||
) | ||
self.add.build( | ||
[ | ||
(None, None, self.projection_dim), | ||
(None, None, self.projection_dim), | ||
] | ||
) | ||
|
||
self.dense_1.build([None, None, self.projection_dim]) | ||
self.dense_2.build([None, None, self.projection_dim * 4]) | ||
self.pooling.build([None, None, self.projection_dim]) | ||
self.dense_output.build([None, self.projection_dim]) | ||
|
||
def call(self, x): | ||
patches = self.tubelet_embedder(x) | ||
encoded_patches = self.positional_encoder(patches) | ||
for _ in range(self.transformer_layers): | ||
x1 = self.layer_norm(encoded_patches) | ||
attention_output = self.attention_output(x1, x1) | ||
x2 = self.add([attention_output, encoded_patches]) | ||
x3 = self.layer_norm(x2) | ||
x4 = self.dense_1(x3) | ||
x5 = self.dense_2(x4) | ||
encoded_patches = self.add([x5, x2]) | ||
representation = self.layer_norm(encoded_patches) | ||
pooled_representation = self.pooling(representation) | ||
outputs = self.dense_output(pooled_representation) | ||
return outputs | ||
|
||
divyashreepathihalli marked this conversation as resolved.
Show resolved
Hide resolved
|
||
def get_config(self): | ||
config = super().get_config() | ||
config.update( | ||
{ | ||
"num_heads": self.num_heads, | ||
"inp_shape": self.inp_shape, | ||
"num_classes": self.num_classes, | ||
"projection_dim": self.projection_dim, | ||
"patch_size": self.patch_size, | ||
} | ||
) | ||
return config |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,129 @@ | ||
# Copyright 2024 The KerasCV Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from keras_cv.api_export import keras_cv_export | ||
from keras_cv.backend import keras | ||
from keras_cv.backend import ops | ||
|
||
|
||
@keras_cv_export( | ||
"keras_cv.layers.TubeletEmebedding", | ||
package="keras_cv.layers", | ||
) | ||
class TubeletEmbedding(keras.layers.Layer): | ||
""" | ||
A Keras layer for spatio-temporal tube embedding applied to input sequences | ||
retrieved from video frames. | ||
|
||
References: | ||
- [ViViT: A Video Vision Transformer](https://arxiv.org/abs/2103.15691) | ||
(ICCV 2021) | ||
|
||
Args: | ||
embed_dim: int, number of dimensions in the embedding space. | ||
Defaults to 128. | ||
patch_size: tuple , size of the spatio-temporal patch. | ||
Specifies the size for each dimension. | ||
Defaults to (8,8,8). | ||
|
||
""" | ||
|
||
def __init__(self, embed_dim=128, patch_size=(8, 8, 8), **kwargs): | ||
super().__init__(**kwargs) | ||
self.embed_dim = embed_dim | ||
self.patch_size = patch_size | ||
self.projection = keras.layers.Conv3D( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. define all layers in init and build them here like |
||
filters=self.embed_dim, | ||
kernel_size=self.patch_size, | ||
strides=self.patch_size, | ||
data_format="channels_last", | ||
padding="VALID", | ||
) | ||
self.flatten = keras.layers.Reshape(target_shape=(-1, self.embed_dim)) | ||
|
||
def build(self, input_shape): | ||
super().build(input_shape) | ||
self.projection.build( | ||
( | ||
None, | ||
input_shape[0], | ||
input_shape[1], | ||
input_shape[2], | ||
input_shape[3], | ||
) | ||
) | ||
projected_patch_shape = self.projection.compute_output_shape( | ||
( | ||
None, | ||
input_shape[0], | ||
input_shape[1], | ||
input_shape[2], | ||
input_shape[3], | ||
) | ||
) | ||
self.flatten.build(projected_patch_shape) | ||
|
||
def compute_output_shape(self, input_shape): | ||
projected_patch_shape = self.projection.compute_output_shape( | ||
( | ||
None, | ||
input_shape[0], | ||
input_shape[1], | ||
input_shape[2], | ||
input_shape[3], | ||
) | ||
) | ||
return self.flatten.compute_output_shape(projected_patch_shape) | ||
|
||
def call(self, videos): | ||
projected_patches = self.projection(videos) | ||
flattened_patches = self.flatten(projected_patches) | ||
return flattened_patches | ||
|
||
|
||
@keras_cv_export( | ||
"keras_cv.layers.PositionalEncoder", | ||
package="keras_cv.layers", | ||
) | ||
class PositionalEncoder(keras.layers.Layer): | ||
""" | ||
A Keras layer for adding positional information to the encoded video tokens. | ||
|
||
References: | ||
- [ViViT: A Video Vision Transformer](https://arxiv.org/abs/2103.15691) | ||
(ICCV 2021) | ||
|
||
Args: | ||
embed_dim: int, number of dimensions in the embedding space. | ||
Defaults to 128. | ||
|
||
""" | ||
|
||
def __init__(self, embed_dim=128, **kwargs): | ||
super().__init__(**kwargs) | ||
self.embed_dim = embed_dim | ||
|
||
def build(self, input_shape): | ||
super().build(input_shape) | ||
_, num_tokens, _ = input_shape | ||
divyashreepathihalli marked this conversation as resolved.
Show resolved
Hide resolved
|
||
self.position_embedding = keras.layers.Embedding( | ||
input_dim=num_tokens, output_dim=self.embed_dim | ||
) | ||
self.position_embedding.build(input_shape) | ||
self.positions = ops.arange(start=0, stop=num_tokens, step=1) | ||
|
||
def call(self, encoded_tokens): | ||
encoded_positions = self.position_embedding(self.positions) | ||
encoded_tokens = encoded_tokens + encoded_positions | ||
return encoded_tokens |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.