From 106d158caaa080487b4d39c134d302661a7cda65 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Sun, 17 Dec 2023 20:42:05 -0800 Subject: [PATCH 01/84] Create dependencyvit.py --- timm/models/dependencyvit.py | 214 +++++++++++++++++++++++++++++++++++ 1 file changed, 214 insertions(+) create mode 100644 timm/models/dependencyvit.py diff --git a/timm/models/dependencyvit.py b/timm/models/dependencyvit.py new file mode 100644 index 0000000000..e2083553a9 --- /dev/null +++ b/timm/models/dependencyvit.py @@ -0,0 +1,214 @@ +from typing import Dict, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.jit import Final + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from timm.layers import Mlp +from timm.models.vision_transformer import VisionTransformer +from ._builder import build_model_with_cfg +from ._manipulate import checkpoint_seq, + +__all__ = ['DependencyViT'] + + +# FIXME there is nearly no difference between this and stock attn, allowing sdpa to be used if a workaround can be found +class ReversedAttention(nn.Module): + + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + qk_norm: bool = False, + attn_drop: float = 0., + proj_drop: float = 0., + norm_layer: nn.Module = nn.LayerNorm, + ) -> None: + super().__init__() + assert dim % num_heads == 0, 'dim should be divisible by num_heads' + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim ** -0.5 + self.track_dependency_mask = False + self.dependency_mask = None + + self.head_selector = nn.Linear(dim, num_heads, bias=False) # paper only mentions a weight matrix, assuming no bias + + self.message_controller = Mlp( + in_features = dim, + hidden_features = dim/2, + out_features = 1, + act_layer = nn.GELU, + bias = False, # FIXME is there a bias term? + ) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + # m is cumulative over all layers (m = m_i * m_i-1 * ... 
* m_1) + def forward(self, in_tuple: Tuple[torch.Tensor, Union[int, torch.Tensor]]) -> Tuple[torch.Tensor, torch.Tensor]: + x, m = in_tuple # [B, N, C], [B, 1, 1, N] + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) + q, k, v = qkv.unbind(0) + q, k = self.q_norm(q), self.k_norm(k) + + p = self.head_selector(x).softmax(dim=-1).transpose(-2, -1).reshape(B, self.num_heads, 1, N) + + m = m * self.message_controller(x).sigmoid().reshape(B, 1, 1, N) + + q = q * self.scale + attn = q @ k.transpose(-2, -1) + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn).transpose(-2, -1) + attn = attn * p * m # [B, n_h, N, N] + x = attn @ v + + self.dependency_mask = attn.sum(1) if self.track_dependency_mask else None + + x = x.transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return (x, m) + +class LayerScale(nn.Module): + def __init__( + self, + dim: int, + init_values: float = 1e-5, + inplace: bool = False, + ) -> None: + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x.mul_(self.gamma) if self.inplace else x * self.gamma + + +class DependencyVitBlock(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4., + qkv_bias: bool = False, + qk_norm: bool = False, + proj_drop: float = 0., + attn_drop: float = 0., + init_values: Optional[float] = None, + drop_path: float = 0., + act_layer: nn.Module = nn.GELU, + norm_layer: nn.Module = nn.LayerNorm, + mlp_layer: nn.Module = Mlp, + ) -> None: + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = ReversedAttention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_norm=qk_norm, + attn_drop=attn_drop, + proj_drop=proj_drop, + norm_layer=norm_layer, + ) + self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + self.norm2 = norm_layer(dim) + self.mlp = mlp_layer( + in_features=dim, + hidden_features=int(dim * mlp_ratio), + act_layer=act_layer, + drop=proj_drop, + ) + self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path2 = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + + def forward(self, in_tuple: Tuple[torch.Tensor, Union[int, torch.Tensor]]) -> Tuple[torch.Tensor, torch.Tensor]: + x, m = in_tuple + x, m = self.attn((self.norm1(x), m)) + x = x + self.drop_path1(self.ls1(x)) + x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x)))) + return (x, m) + +class DependencyViT(VisionTransformer): + def __init__(self, *args, **kwargs): + super().__init__( + *args, + **kwargs, + block_fn = DependencyViTBlock, + class_token=False, + global_pool='avg', + qkv_bias=False, + init_values=1e-6, + fc_norm=False, + ) + + def forward_features(self, x: torch.Tensor) -> torch.Tensor: + x = self.patch_embed(x) + x = self._pos_embed(x) + x = self.patch_drop(x) + x = self.norm_pre(x) + if self.grad_checkpointing and not torch.jit.is_scripting(): + x, _ = checkpoint_seq(self.blocks, (x,1)) + else: + x, _ = self.blocks((x, 1)) + x = self.norm(x) + return x + + +def _cfg(url: str = '', **kwargs) -> Dict[str, Any]: + return { + 'url': url, + 'num_classes': 1000, + 'input_size': (3, 224, 224), + 'pool_size': None, + 'crop_pct': 0.9, + 'interpolation': 'bicubic', + 'fixed_input_size': True, + 'mean': IMAGENET_INCEPTION_MEAN, + 'std': IMAGENET_INCEPTION_STD, + 'first_conv': 'patch_embed.proj', + 'classifier': 'head', + **kwargs, + } + + +default_cfgs = { + 'dependencyvit_tiny_patch16_224.untrained': _cfg(url=''), +} + + +default_cfgs = generate_default_cfgs(default_cfgs) + + + +def _create_dependencyvit(variant: str, pretrained: bool = False, **kwargs) -> DependencyViT: + if kwargs.get('features_only', None): + raise RuntimeError('features_only not implemented for Vision Transformer models.') + + _filter_fn = checkpoint_filter_fn + + return build_model_with_cfg( + DependencyViT, + variant, + pretrained, + pretrained_filter_fn=_filter_fn, + pretrained_strict=strict, + **kwargs, + ) + + +def dependencyvit_tiny_patch16_224(pretrained: bool = False, **kwargs) -> DependencyViT: + model_args = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3) + model = _create_dependencyvit('dependencyvit_tiny_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return model From 6f783fbd48d6cd7c753b31a4a8de301ae003c6dc Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Sun, 17 Dec 2023 21:09:39 -0800 Subject: [PATCH 02/84] Update dependencyvit.py --- timm/models/dependencyvit.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/timm/models/dependencyvit.py b/timm/models/dependencyvit.py index e2083553a9..61b69b9d14 100644 --- a/timm/models/dependencyvit.py +++ b/timm/models/dependencyvit.py @@ -196,14 +196,10 @@ def _create_dependencyvit(variant: str, pretrained: bool = False, **kwargs) -> D if kwargs.get('features_only', None): raise RuntimeError('features_only not implemented for Vision Transformer models.') - _filter_fn = checkpoint_filter_fn - return build_model_with_cfg( DependencyViT, variant, pretrained, - pretrained_filter_fn=_filter_fn, - pretrained_strict=strict, **kwargs, ) From 8acb798013751e012564a0dcc5a873c618bb40dc Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Sun, 17 Dec 2023 21:11:30 -0800 Subject: [PATCH 03/84] Update dependencyvit.py --- timm/models/dependencyvit.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timm/models/dependencyvit.py b/timm/models/dependencyvit.py index 61b69b9d14..15628b53c2 100644 --- a/timm/models/dependencyvit.py +++ b/timm/models/dependencyvit.py @@ -9,7 +9,7 @@ from timm.layers import Mlp from timm.models.vision_transformer import VisionTransformer from ._builder import 
build_model_with_cfg -from ._manipulate import checkpoint_seq, +from ._manipulate import checkpoint_seq __all__ = ['DependencyViT'] From 48b3a40c1d1bcb094230625b82db43f47f376d11 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Sun, 17 Dec 2023 21:12:57 -0800 Subject: [PATCH 04/84] Update dependencyvit.py --- timm/models/dependencyvit.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timm/models/dependencyvit.py b/timm/models/dependencyvit.py index 15628b53c2..31e244d09a 100644 --- a/timm/models/dependencyvit.py +++ b/timm/models/dependencyvit.py @@ -1,4 +1,4 @@ -from typing import Dict, Optional, Tuple, Union +from typing import Any, Dict, Optional, Tuple, Union import torch import torch.nn as nn From 95e47be6368fd5df58fc6314076fc8a756a11357 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Sun, 17 Dec 2023 21:14:45 -0800 Subject: [PATCH 05/84] Update dependencyvit.py --- timm/models/dependencyvit.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timm/models/dependencyvit.py b/timm/models/dependencyvit.py index 31e244d09a..609b9d3859 100644 --- a/timm/models/dependencyvit.py +++ b/timm/models/dependencyvit.py @@ -5,7 +5,7 @@ import torch.nn.functional as F from torch.jit import Final -from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD from timm.layers import Mlp from timm.models.vision_transformer import VisionTransformer from ._builder import build_model_with_cfg From 812c7281e3e675e25972b18cabd6ea86dc3ef84b Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Sun, 17 Dec 2023 21:16:28 -0800 Subject: [PATCH 06/84] Update dependencyvit.py --- timm/models/dependencyvit.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/timm/models/dependencyvit.py b/timm/models/dependencyvit.py index 609b9d3859..274393afdb 100644 --- a/timm/models/dependencyvit.py +++ b/timm/models/dependencyvit.py @@ -10,6 +10,7 @@ from timm.models.vision_transformer import VisionTransformer from ._builder import build_model_with_cfg from ._manipulate import checkpoint_seq +from ._registry import generate_default_cfgs, register_model __all__ = ['DependencyViT'] @@ -203,7 +204,7 @@ def _create_dependencyvit(variant: str, pretrained: bool = False, **kwargs) -> D **kwargs, ) - +@register_model def dependencyvit_tiny_patch16_224(pretrained: bool = False, **kwargs) -> DependencyViT: model_args = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3) model = _create_dependencyvit('dependencyvit_tiny_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs)) From 2b5ba76a7d0c11b847f61e00af6ba5a395ddd600 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Sun, 17 Dec 2023 21:18:36 -0800 Subject: [PATCH 07/84] Update dependencyvit.py --- timm/models/dependencyvit.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timm/models/dependencyvit.py b/timm/models/dependencyvit.py index 274393afdb..3dec9ba3d8 100644 --- a/timm/models/dependencyvit.py +++ b/timm/models/dependencyvit.py @@ -94,7 +94,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x.mul_(self.gamma) if self.inplace else x * self.gamma -class DependencyVitBlock(nn.Module): +class DependencyViTBlock(nn.Module): def __init__( self, dim: int, From d2e3ba81b250e06af39c48b4f13c991e73e5489c Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Sun, 17 Dec 2023 21:26:51 -0800 Subject: [PATCH 08/84] Update dependencyvit.py --- timm/models/dependencyvit.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/timm/models/dependencyvit.py b/timm/models/dependencyvit.py index 3dec9ba3d8..f6dbae8995 100644 --- a/timm/models/dependencyvit.py +++ b/timm/models/dependencyvit.py @@ -113,7 +113,7 @@ def __init__( super().__init__() self.norm1 = norm_layer(dim) self.attn = ReversedAttention( - dim, + dim=dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_norm=qk_norm, From c7039fb11e88952f207040ac49a63ab6678d5210 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Sun, 17 Dec 2023 21:32:14 -0800 Subject: [PATCH 09/84] Update dependencyvit.py --- timm/models/dependencyvit.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/timm/models/dependencyvit.py b/timm/models/dependencyvit.py index f6dbae8995..9d8b9f3055 100644 --- a/timm/models/dependencyvit.py +++ b/timm/models/dependencyvit.py @@ -38,9 +38,10 @@ def __init__( self.head_selector = nn.Linear(dim, num_heads, bias=False) # paper only mentions a weight matrix, assuming no bias + print(dim) self.message_controller = Mlp( in_features = dim, - hidden_features = dim/2, + hidden_features = int(dim/2), out_features = 1, act_layer = nn.GELU, bias = False, # FIXME is there a bias term? From a53a3fc218bde3e3ed592803cc3772ce44792274 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Sun, 17 Dec 2023 22:00:53 -0800 Subject: [PATCH 10/84] Update dependencyvit.py --- timm/models/dependencyvit.py | 1 - 1 file changed, 1 deletion(-) diff --git a/timm/models/dependencyvit.py b/timm/models/dependencyvit.py index 9d8b9f3055..aae37ec3b2 100644 --- a/timm/models/dependencyvit.py +++ b/timm/models/dependencyvit.py @@ -38,7 +38,6 @@ def __init__( self.head_selector = nn.Linear(dim, num_heads, bias=False) # paper only mentions a weight matrix, assuming no bias - print(dim) self.message_controller = Mlp( in_features = dim, hidden_features = int(dim/2), From 5d9d863ed1106b8cc15e6e8588bce0aef2d43485 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Mon, 18 Dec 2023 00:07:15 -0800 Subject: [PATCH 11/84] Update dependencyvit.py --- timm/models/dependencyvit.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/timm/models/dependencyvit.py b/timm/models/dependencyvit.py index aae37ec3b2..04b193e5c5 100644 --- a/timm/models/dependencyvit.py +++ b/timm/models/dependencyvit.py @@ -54,7 +54,7 @@ def __init__( self.proj_drop = nn.Dropout(proj_drop) # m is cumulative over all layers (m = m_i * m_i-1 * ... * m_1) - def forward(self, in_tuple: Tuple[torch.Tensor, Union[int, torch.Tensor]]) -> Tuple[torch.Tensor, torch.Tensor]: + def forward(self, in_tuple: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]: x, m = in_tuple # [B, N, C], [B, 1, 1, N] B, N, C = x.shape qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) @@ -134,7 +134,7 @@ def __init__( self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() self.drop_path2 = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() - def forward(self, in_tuple: Tuple[torch.Tensor, Union[int, torch.Tensor]]) -> Tuple[torch.Tensor, torch.Tensor]: + def forward(self, in_tuple: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]: x, m = in_tuple x, m = self.attn((self.norm1(x), m)) x = x + self.drop_path1(self.ls1(x)) @@ -159,10 +159,11 @@ def forward_features(self, x: torch.Tensor) -> torch.Tensor: x = self._pos_embed(x) x = self.patch_drop(x) x = self.norm_pre(x) + m = torch.Tensor(1).to(x) if self.grad_checkpointing and not torch.jit.is_scripting(): - x, _ = checkpoint_seq(self.blocks, (x,1)) + x, _ = checkpoint_seq(self.blocks, (x, m)) else: - x, _ = self.blocks((x, 1)) + x, _ = self.blocks((x, m)) x = self.norm(x) return x From 7749432d5934cb31e66eff529e71d06bc7ae30df Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Mon, 18 Dec 2023 00:11:41 -0800 Subject: [PATCH 12/84] Update __init__.py --- timm/models/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/timm/models/__init__.py b/timm/models/__init__.py index c5b1984f20..59e812b038 100644 --- a/timm/models/__init__.py +++ b/timm/models/__init__.py @@ -11,6 +11,7 @@ from .davit import * from .deit import * from .densenet import * +from .dependencyvit import * from .dla import * from .dpn import * from .edgenext import * From 779a1cdb77fbe121eadc536d4179f12c9d7d4b9e Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Mon, 18 Dec 2023 00:15:05 -0800 Subject: [PATCH 13/84] Update dependencyvit.py --- timm/models/dependencyvit.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/timm/models/dependencyvit.py b/timm/models/dependencyvit.py index 04b193e5c5..48e485f030 100644 --- a/timm/models/dependencyvit.py +++ b/timm/models/dependencyvit.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Optional, Tuple, Union +from typing import Any, Dict, Optional, Tuple import torch import torch.nn as nn @@ -34,7 +34,7 @@ def __init__( self.head_dim = dim // num_heads self.scale = self.head_dim ** -0.5 self.track_dependency_mask = False - self.dependency_mask = None + self.dependency_mask: Optional[Tensor] = None self.head_selector = nn.Linear(dim, num_heads, bias=False) # paper only mentions a weight matrix, assuming no bias From 282e7a3194b237e7d02750a0a8698c691c423bd8 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Mon, 18 Dec 2023 00:17:23 -0800 Subject: [PATCH 14/84] Update dependencyvit.py --- timm/models/dependencyvit.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/timm/models/dependencyvit.py b/timm/models/dependencyvit.py index 48e485f030..3b4ef31864 100644 --- a/timm/models/dependencyvit.py +++ b/timm/models/dependencyvit.py @@ -17,6 +17,7 @@ # FIXME there is nearly no difference between this and stock attn, allowing sdpa to be used if a workaround can be found class ReversedAttention(nn.Module): + dependency_mask: Optional[torch.Tensor] def __init__( self, @@ -34,7 +35,7 @@ def __init__( self.head_dim = dim // num_heads self.scale = self.head_dim ** -0.5 self.track_dependency_mask = False - self.dependency_mask: Optional[Tensor] = None + self.dependency_mask = None self.head_selector = nn.Linear(dim, num_heads, bias=False) # paper only mentions a weight matrix, assuming no bias From 96267e154a3e82c9ab65379839936a562f756a1a Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Mon, 18 Dec 2023 00:35:57 -0800 Subject: [PATCH 15/84] Update dependencyvit.py --- timm/models/dependencyvit.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/timm/models/dependencyvit.py 
b/timm/models/dependencyvit.py index 3b4ef31864..a78e7d819f 100644 --- a/timm/models/dependencyvit.py +++ b/timm/models/dependencyvit.py @@ -142,6 +142,10 @@ def forward(self, in_tuple: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Te x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x)))) return (x, m) +# FIXME lite model variants +# FIXME toggle and retrieve dependency masks +# FIXME verify against reference impl + class DependencyViT(VisionTransformer): def __init__(self, *args, **kwargs): super().__init__( @@ -208,6 +212,12 @@ def _create_dependencyvit(variant: str, pretrained: bool = False, **kwargs) -> D @register_model def dependencyvit_tiny_patch16_224(pretrained: bool = False, **kwargs) -> DependencyViT: - model_args = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3) + model_args = dict(patch_size=16, embed_dim=192, depth=12, num_heads=12) model = _create_dependencyvit('dependencyvit_tiny_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs)) return model + +@register_model +def dependencyvit_small_patch16_224(pretrained: bool = False, **kwargs) -> DependencyViT: + model_args = dict(patch_size=16, embed_dim=384, depth=12, num_heads=12) + model = _create_dependencyvit('dependencyvit_tiny_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return model \ No newline at end of file From 245a842085616702ac71cd4ab7fcfd3a1693d92d Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Mon, 18 Dec 2023 01:13:52 -0800 Subject: [PATCH 16/84] Update dependencyvit.py --- timm/models/dependencyvit.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/timm/models/dependencyvit.py b/timm/models/dependencyvit.py index a78e7d819f..fafbfb35e5 100644 --- a/timm/models/dependencyvit.py +++ b/timm/models/dependencyvit.py @@ -1,3 +1,15 @@ +""" DependencyViT (FIXME WIP) + +From-scratch implementation of DependencyViT in PyTorch + +'Visual Dependency Transformers: Dependency Tree Emerges from Reversed Attention' + - https://arxiv.org/abs/2304.03282 + +ReversedAttention implementation derived from timm's Vision Transformer implementation + +Implementation for timm by / Copyright 2023, Fredo Guan +""" + from typing import Any, Dict, Optional, Tuple import torch @@ -36,6 +48,7 @@ def __init__( self.scale = self.head_dim ** -0.5 self.track_dependency_mask = False self.dependency_mask = None + self.head_selector_temperature = 0.1 # appendix D.1 self.head_selector = nn.Linear(dim, num_heads, bias=False) # paper only mentions a weight matrix, assuming no bias @@ -62,7 +75,8 @@ def forward(self, in_tuple: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Te q, k, v = qkv.unbind(0) q, k = self.q_norm(q), self.k_norm(k) - p = self.head_selector(x).softmax(dim=-1).transpose(-2, -1).reshape(B, self.num_heads, 1, N) + p = (self.head_selector(x) / self.head_selector_temperature).softmax(dim=-1) + p = p.transpose(-2, -1).reshape(B, self.num_heads, 1, N) m = m * self.message_controller(x).sigmoid().reshape(B, 1, 1, N) From 8766e362f99071c336310fdaf1f97b73d0f7cf88 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Mon, 18 Dec 2023 02:05:17 -0800 Subject: [PATCH 17/84] Update dependencyvit.py --- timm/models/dependencyvit.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timm/models/dependencyvit.py b/timm/models/dependencyvit.py index fafbfb35e5..7c62b2ebff 100644 --- a/timm/models/dependencyvit.py +++ b/timm/models/dependencyvit.py @@ -18,7 +18,7 @@ from torch.jit import Final from timm.data import IMAGENET_INCEPTION_MEAN, 
IMAGENET_INCEPTION_STD -from timm.layers import Mlp +from timm.layers import DropPath, Mlp from timm.models.vision_transformer import VisionTransformer from ._builder import build_model_with_cfg from ._manipulate import checkpoint_seq From 2dc84d44dc6942ed37c40e95d65a1252511883d2 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Mon, 18 Dec 2023 03:03:37 -0800 Subject: [PATCH 18/84] Update dependencyvit.py --- timm/models/dependencyvit.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/timm/models/dependencyvit.py b/timm/models/dependencyvit.py index 7c62b2ebff..7ede7f84d3 100644 --- a/timm/models/dependencyvit.py +++ b/timm/models/dependencyvit.py @@ -180,10 +180,11 @@ def forward_features(self, x: torch.Tensor) -> torch.Tensor: x = self.norm_pre(x) m = torch.Tensor(1).to(x) if self.grad_checkpointing and not torch.jit.is_scripting(): - x, _ = checkpoint_seq(self.blocks, (x, m)) + x, m = checkpoint_seq(self.blocks, (x, m)) else: - x, _ = self.blocks((x, m)) + x, m = self.blocks((x, m)) x = self.norm(x) + x = x * m.transpose(1, 3).squeeze(-1) return x From ef126b66f34c307fa6fa63955e2add54a26fc245 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Mon, 18 Dec 2023 03:07:15 -0800 Subject: [PATCH 19/84] Update dependencyvit.py --- timm/models/dependencyvit.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timm/models/dependencyvit.py b/timm/models/dependencyvit.py index 7ede7f84d3..4115eb3f81 100644 --- a/timm/models/dependencyvit.py +++ b/timm/models/dependencyvit.py @@ -184,7 +184,7 @@ def forward_features(self, x: torch.Tensor) -> torch.Tensor: else: x, m = self.blocks((x, m)) x = self.norm(x) - x = x * m.transpose(1, 3).squeeze(-1) + #x = x * m.transpose(1, 3).squeeze(-1) return x From 87590485bffc07d3c2ee268e4572b3d13e8ef9dd Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Mon, 18 Dec 2023 03:10:02 -0800 Subject: [PATCH 20/84] Update dependencyvit.py --- timm/models/dependencyvit.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/timm/models/dependencyvit.py b/timm/models/dependencyvit.py index 4115eb3f81..b22ac39425 100644 --- a/timm/models/dependencyvit.py +++ b/timm/models/dependencyvit.py @@ -75,7 +75,7 @@ def forward(self, in_tuple: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Te q, k, v = qkv.unbind(0) q, k = self.q_norm(q), self.k_norm(k) - p = (self.head_selector(x) / self.head_selector_temperature).softmax(dim=-1) + p = (self.head_selector(x) * self.head_selector_temperature).softmax(dim=-1) p = p.transpose(-2, -1).reshape(B, self.num_heads, 1, N) m = m * self.message_controller(x).sigmoid().reshape(B, 1, 1, N) @@ -184,7 +184,7 @@ def forward_features(self, x: torch.Tensor) -> torch.Tensor: else: x, m = self.blocks((x, m)) x = self.norm(x) - #x = x * m.transpose(1, 3).squeeze(-1) + x = x * m.transpose(1, 3).squeeze(-1) return x From f6db3bac9adba84e8ae98794bbca421dd65546ee Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Mon, 18 Dec 2023 03:11:51 -0800 Subject: [PATCH 21/84] Update dependencyvit.py --- timm/models/dependencyvit.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timm/models/dependencyvit.py b/timm/models/dependencyvit.py index b22ac39425..c8164e5343 100644 --- a/timm/models/dependencyvit.py +++ b/timm/models/dependencyvit.py @@ -184,7 +184,7 @@ def forward_features(self, x: torch.Tensor) -> torch.Tensor: else: x, m = self.blocks((x, m)) x = self.norm(x) - x = x * m.transpose(1, 3).squeeze(-1) + #x = x * m.transpose(1, 3).squeeze(-1) # FIXME before or after norm return x From 
3f046fef896cfe12370669ecb815c703a49a3106 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Mon, 18 Dec 2023 03:15:20 -0800 Subject: [PATCH 22/84] Update dependencyvit.py --- timm/models/dependencyvit.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timm/models/dependencyvit.py b/timm/models/dependencyvit.py index c8164e5343..8d79bb2894 100644 --- a/timm/models/dependencyvit.py +++ b/timm/models/dependencyvit.py @@ -75,7 +75,7 @@ def forward(self, in_tuple: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Te q, k, v = qkv.unbind(0) q, k = self.q_norm(q), self.k_norm(k) - p = (self.head_selector(x) * self.head_selector_temperature).softmax(dim=-1) + p = (self.head_selector(x) / self.head_selector_temperature).softmax(dim=-1) p = p.transpose(-2, -1).reshape(B, self.num_heads, 1, N) m = m * self.message_controller(x).sigmoid().reshape(B, 1, 1, N) From 5a5f6bbf9caf15f8fc3c70a98aa80b28b652169b Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Mon, 18 Dec 2023 03:17:04 -0800 Subject: [PATCH 23/84] Update dependencyvit.py --- timm/models/dependencyvit.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timm/models/dependencyvit.py b/timm/models/dependencyvit.py index 8d79bb2894..354ffb7ee6 100644 --- a/timm/models/dependencyvit.py +++ b/timm/models/dependencyvit.py @@ -184,7 +184,7 @@ def forward_features(self, x: torch.Tensor) -> torch.Tensor: else: x, m = self.blocks((x, m)) x = self.norm(x) - #x = x * m.transpose(1, 3).squeeze(-1) # FIXME before or after norm + x = x * m.transpose(1, 3).squeeze(-1) # FIXME before or after norm return x From c9d1666644140baeb58f5e6be35475ab93e2658d Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Mon, 18 Dec 2023 03:20:11 -0800 Subject: [PATCH 24/84] Update dependencyvit.py --- timm/models/dependencyvit.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timm/models/dependencyvit.py b/timm/models/dependencyvit.py index 354ffb7ee6..8d79bb2894 100644 --- a/timm/models/dependencyvit.py +++ b/timm/models/dependencyvit.py @@ -184,7 +184,7 @@ def forward_features(self, x: torch.Tensor) -> torch.Tensor: else: x, m = self.blocks((x, m)) x = self.norm(x) - x = x * m.transpose(1, 3).squeeze(-1) # FIXME before or after norm + #x = x * m.transpose(1, 3).squeeze(-1) # FIXME before or after norm return x From e3bb964947e41122a09e510820563548315aed4d Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Mon, 18 Dec 2023 03:22:52 -0800 Subject: [PATCH 25/84] Update dependencyvit.py --- timm/models/dependencyvit.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/timm/models/dependencyvit.py b/timm/models/dependencyvit.py index 8d79bb2894..dc6d3b6f37 100644 --- a/timm/models/dependencyvit.py +++ b/timm/models/dependencyvit.py @@ -183,8 +183,10 @@ def forward_features(self, x: torch.Tensor) -> torch.Tensor: x, m = checkpoint_seq(self.blocks, (x, m)) else: x, m = self.blocks((x, m)) + + x = x * m.transpose(1, 3).squeeze(-1) # FIXME before or after norm + x = self.norm(x) - #x = x * m.transpose(1, 3).squeeze(-1) # FIXME before or after norm return x From b564b30622888c7f836a9bcbd59294cf2bd0ba67 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Mon, 18 Dec 2023 03:25:16 -0800 Subject: [PATCH 26/84] Update dependencyvit.py --- timm/models/dependencyvit.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/timm/models/dependencyvit.py b/timm/models/dependencyvit.py index dc6d3b6f37..50c042b186 100644 --- a/timm/models/dependencyvit.py +++ b/timm/models/dependencyvit.py @@ -48,7 +48,7 @@ def __init__( 
self.scale = self.head_dim ** -0.5 self.track_dependency_mask = False self.dependency_mask = None - self.head_selector_temperature = 0.1 # appendix D.1 + self.head_selector_temperature = 1 # appendix D.1 self.head_selector = nn.Linear(dim, num_heads, bias=False) # paper only mentions a weight matrix, assuming no bias @@ -184,7 +184,7 @@ def forward_features(self, x: torch.Tensor) -> torch.Tensor: else: x, m = self.blocks((x, m)) - x = x * m.transpose(1, 3).squeeze(-1) # FIXME before or after norm + #x = x * m.transpose(1, 3).squeeze(-1) # FIXME before or after norm x = self.norm(x) return x From a3442bff9c8f208c1336ac5dc4cc552dff0c4158 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Mon, 18 Dec 2023 03:27:35 -0800 Subject: [PATCH 27/84] Update dependencyvit.py --- timm/models/dependencyvit.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/timm/models/dependencyvit.py b/timm/models/dependencyvit.py index 50c042b186..a5ace99efc 100644 --- a/timm/models/dependencyvit.py +++ b/timm/models/dependencyvit.py @@ -184,9 +184,10 @@ def forward_features(self, x: torch.Tensor) -> torch.Tensor: else: x, m = self.blocks((x, m)) - #x = x * m.transpose(1, 3).squeeze(-1) # FIXME before or after norm x = self.norm(x) + x = x * m.transpose(1, 3).squeeze(-1) # FIXME before or after norm + return x From 15df11e3da26e6bc9a410054b2ab7588668136b4 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Mon, 18 Dec 2023 03:30:21 -0800 Subject: [PATCH 28/84] Update dependencyvit.py --- timm/models/dependencyvit.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/timm/models/dependencyvit.py b/timm/models/dependencyvit.py index a5ace99efc..606fa55794 100644 --- a/timm/models/dependencyvit.py +++ b/timm/models/dependencyvit.py @@ -184,10 +184,9 @@ def forward_features(self, x: torch.Tensor) -> torch.Tensor: else: x, m = self.blocks((x, m)) - - x = self.norm(x) x = x * m.transpose(1, 3).squeeze(-1) # FIXME before or after norm + x = self.norm(x) return x From b1eeaaffe8f9012f03d64effa2a3d10604fb4030 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Mon, 18 Dec 2023 03:36:01 -0800 Subject: [PATCH 29/84] Update dependencyvit.py --- timm/models/dependencyvit.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/timm/models/dependencyvit.py b/timm/models/dependencyvit.py index 606fa55794..53103426a9 100644 --- a/timm/models/dependencyvit.py +++ b/timm/models/dependencyvit.py @@ -48,7 +48,7 @@ def __init__( self.scale = self.head_dim ** -0.5 self.track_dependency_mask = False self.dependency_mask = None - self.head_selector_temperature = 1 # appendix D.1 + self.head_selector_temperature = 0.1 # appendix D.1 self.head_selector = nn.Linear(dim, num_heads, bias=False) # paper only mentions a weight matrix, assuming no bias @@ -184,7 +184,7 @@ def forward_features(self, x: torch.Tensor) -> torch.Tensor: else: x, m = self.blocks((x, m)) - x = x * m.transpose(1, 3).squeeze(-1) # FIXME before or after norm + #x = x * m.transpose(1, 3).squeeze(-1) # FIXME before or after norm x = self.norm(x) return x From 328290ceca2e8d097956ac6548217551e4c732ba Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Mon, 18 Dec 2023 03:37:17 -0800 Subject: [PATCH 30/84] Update dependencyvit.py --- timm/models/dependencyvit.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timm/models/dependencyvit.py b/timm/models/dependencyvit.py index 53103426a9..17af62193d 100644 --- a/timm/models/dependencyvit.py +++ b/timm/models/dependencyvit.py @@ -48,7 +48,7 @@ def __init__( self.scale = 
self.head_dim ** -0.5 self.track_dependency_mask = False self.dependency_mask = None - self.head_selector_temperature = 0.1 # appendix D.1 + self.head_selector_temperature = 10.0 # appendix D.1 self.head_selector = nn.Linear(dim, num_heads, bias=False) # paper only mentions a weight matrix, assuming no bias From 6e19187eae8cde4cae80353851c80d9787eef162 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Mon, 18 Dec 2023 03:38:42 -0800 Subject: [PATCH 31/84] Update dependencyvit.py --- timm/models/dependencyvit.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timm/models/dependencyvit.py b/timm/models/dependencyvit.py index 17af62193d..b673d68e67 100644 --- a/timm/models/dependencyvit.py +++ b/timm/models/dependencyvit.py @@ -48,7 +48,7 @@ def __init__( self.scale = self.head_dim ** -0.5 self.track_dependency_mask = False self.dependency_mask = None - self.head_selector_temperature = 10.0 # appendix D.1 + self.head_selector_temperature = 1.0 # appendix D.1, causes nan when 0.1, 0 when 10.0 self.head_selector = nn.Linear(dim, num_heads, bias=False) # paper only mentions a weight matrix, assuming no bias From 7a6d02463b39b1831f5a2e20ced9700c04a0a063 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Mon, 18 Dec 2023 03:42:20 -0800 Subject: [PATCH 32/84] Update dependencyvit.py --- timm/models/dependencyvit.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/timm/models/dependencyvit.py b/timm/models/dependencyvit.py index b673d68e67..7db4cc0ab6 100644 --- a/timm/models/dependencyvit.py +++ b/timm/models/dependencyvit.py @@ -48,7 +48,7 @@ def __init__( self.scale = self.head_dim ** -0.5 self.track_dependency_mask = False self.dependency_mask = None - self.head_selector_temperature = 1.0 # appendix D.1, causes nan when 0.1, 0 when 10.0 + self.head_selector_temperature = 10.0 # appendix D.1, causes nan when 0.1, 0 when 10.0 self.head_selector = nn.Linear(dim, num_heads, bias=False) # paper only mentions a weight matrix, assuming no bias @@ -187,6 +187,7 @@ def forward_features(self, x: torch.Tensor) -> torch.Tensor: #x = x * m.transpose(1, 3).squeeze(-1) # FIXME before or after norm x = self.norm(x) + x = x * m.transpose(1, 3).squeeze(-1) return x From 4d5b00c32d0ef80ae122e893a9a5c8582e16ddcb Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Mon, 18 Dec 2023 03:51:50 -0800 Subject: [PATCH 33/84] Update dependencyvit.py --- timm/models/dependencyvit.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/timm/models/dependencyvit.py b/timm/models/dependencyvit.py index 7db4cc0ab6..a244071c96 100644 --- a/timm/models/dependencyvit.py +++ b/timm/models/dependencyvit.py @@ -48,7 +48,7 @@ def __init__( self.scale = self.head_dim ** -0.5 self.track_dependency_mask = False self.dependency_mask = None - self.head_selector_temperature = 10.0 # appendix D.1, causes nan when 0.1, 0 when 10.0 + self.head_selector_temperature = 1.0 # appendix D.1, causes nan when 0.1, 0 when 10.0 self.head_selector = nn.Linear(dim, num_heads, bias=False) # paper only mentions a weight matrix, assuming no bias @@ -187,7 +187,7 @@ def forward_features(self, x: torch.Tensor) -> torch.Tensor: #x = x * m.transpose(1, 3).squeeze(-1) # FIXME before or after norm x = self.norm(x) - x = x * m.transpose(1, 3).squeeze(-1) + #x = x * m.transpose(1, 3).squeeze(-1) return x From 4ebd1ce7ba600d77765a8b0ec6d3277ec7826bb7 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Mon, 18 Dec 2023 04:16:36 -0800 Subject: [PATCH 34/84] Update dependencyvit.py --- timm/models/dependencyvit.py | 4 ++-- 1 file 
changed, 2 insertions(+), 2 deletions(-) diff --git a/timm/models/dependencyvit.py b/timm/models/dependencyvit.py index a244071c96..d17fa4466b 100644 --- a/timm/models/dependencyvit.py +++ b/timm/models/dependencyvit.py @@ -48,7 +48,7 @@ def __init__( self.scale = self.head_dim ** -0.5 self.track_dependency_mask = False self.dependency_mask = None - self.head_selector_temperature = 1.0 # appendix D.1, causes nan when 0.1, 0 when 10.0 + self.head_selector_temperature = 0.1 # appendix D.1, causes nan when 0.1, 0 when 10.0 self.head_selector = nn.Linear(dim, num_heads, bias=False) # paper only mentions a weight matrix, assuming no bias @@ -78,7 +78,7 @@ def forward(self, in_tuple: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Te p = (self.head_selector(x) / self.head_selector_temperature).softmax(dim=-1) p = p.transpose(-2, -1).reshape(B, self.num_heads, 1, N) - m = m * self.message_controller(x).sigmoid().reshape(B, 1, 1, N) + m = self.message_controller(x).sigmoid().reshape(B, 1, 1, N) q = q * self.scale attn = q @ k.transpose(-2, -1) From 1c449c0d054e747ac5d276293f06238ca9fc358e Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Mon, 18 Dec 2023 10:54:09 -0800 Subject: [PATCH 35/84] Update dependencyvit.py --- timm/models/dependencyvit.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/timm/models/dependencyvit.py b/timm/models/dependencyvit.py index d17fa4466b..07d46bf794 100644 --- a/timm/models/dependencyvit.py +++ b/timm/models/dependencyvit.py @@ -50,14 +50,14 @@ def __init__( self.dependency_mask = None self.head_selector_temperature = 0.1 # appendix D.1, causes nan when 0.1, 0 when 10.0 - self.head_selector = nn.Linear(dim, num_heads, bias=False) # paper only mentions a weight matrix, assuming no bias + self.head_selector = nn.Linear(dim, num_heads) self.message_controller = Mlp( in_features = dim, hidden_features = int(dim/2), out_features = 1, act_layer = nn.GELU, - bias = False, # FIXME is there a bias term? + bias = True, # FIXME is there a bias term? 
) self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) From 8d588bb711aba6cd5bbab1482bf4389501cd1609 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Mon, 18 Dec 2023 13:13:16 -0800 Subject: [PATCH 36/84] Update dependencyvit.py --- timm/models/dependencyvit.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/timm/models/dependencyvit.py b/timm/models/dependencyvit.py index 07d46bf794..bef635f84b 100644 --- a/timm/models/dependencyvit.py +++ b/timm/models/dependencyvit.py @@ -48,7 +48,7 @@ def __init__( self.scale = self.head_dim ** -0.5 self.track_dependency_mask = False self.dependency_mask = None - self.head_selector_temperature = 0.1 # appendix D.1, causes nan when 0.1, 0 when 10.0 + self.head_selector_temperature = 1.0 # appendix D.1, causes nan when 0.1, 0 when 10.0 self.head_selector = nn.Linear(dim, num_heads) @@ -151,8 +151,8 @@ def __init__( def forward(self, in_tuple: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]: x, m = in_tuple - x, m = self.attn((self.norm1(x), m)) - x = x + self.drop_path1(self.ls1(x)) + x_new, m = self.attn((self.norm1(x), m)) + x = x + self.drop_path1(self.ls1(x_new)) x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x)))) return (x, m) From a44139082a16b654480afee99844fa9ec8d2e5e5 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Mon, 18 Dec 2023 17:15:44 -0800 Subject: [PATCH 37/84] Update dependencyvit.py --- timm/models/dependencyvit.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/timm/models/dependencyvit.py b/timm/models/dependencyvit.py index bef635f84b..b58e781000 100644 --- a/timm/models/dependencyvit.py +++ b/timm/models/dependencyvit.py @@ -48,7 +48,7 @@ def __init__( self.scale = self.head_dim ** -0.5 self.track_dependency_mask = False self.dependency_mask = None - self.head_selector_temperature = 1.0 # appendix D.1, causes nan when 0.1, 0 when 10.0 + self.head_selector_temperature = 0.1 # appendix D.1, causes nan when 0.1, 0 when 10.0 self.head_selector = nn.Linear(dim, num_heads) @@ -78,7 +78,7 @@ def forward(self, in_tuple: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Te p = (self.head_selector(x) / self.head_selector_temperature).softmax(dim=-1) p = p.transpose(-2, -1).reshape(B, self.num_heads, 1, N) - m = self.message_controller(x).sigmoid().reshape(B, 1, 1, N) + m = m * self.message_controller(x).sigmoid().reshape(B, 1, 1, N) q = q * self.scale attn = q @ k.transpose(-2, -1) @@ -187,7 +187,7 @@ def forward_features(self, x: torch.Tensor) -> torch.Tensor: #x = x * m.transpose(1, 3).squeeze(-1) # FIXME before or after norm x = self.norm(x) - #x = x * m.transpose(1, 3).squeeze(-1) + x = x * m.transpose(1, 3).squeeze(-1) return x From 72faca138de2cbd34b17ac9507ea888127a6e4e6 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Mon, 18 Dec 2023 19:29:59 -0800 Subject: [PATCH 38/84] Update dependencyvit.py --- timm/models/dependencyvit.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/timm/models/dependencyvit.py b/timm/models/dependencyvit.py index b58e781000..8b94068286 100644 --- a/timm/models/dependencyvit.py +++ b/timm/models/dependencyvit.py @@ -78,7 +78,7 @@ def forward(self, in_tuple: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Te p = (self.head_selector(x) / self.head_selector_temperature).softmax(dim=-1) p = p.transpose(-2, -1).reshape(B, self.num_heads, 1, N) - m = m * self.message_controller(x).sigmoid().reshape(B, 1, 1, N) + m = self.message_controller(x).sigmoid().reshape(B, 1, 1, N) q = q * self.scale attn = q @ k.transpose(-2, 
-1) @@ -187,7 +187,7 @@ def forward_features(self, x: torch.Tensor) -> torch.Tensor: #x = x * m.transpose(1, 3).squeeze(-1) # FIXME before or after norm x = self.norm(x) - x = x * m.transpose(1, 3).squeeze(-1) + #x = x * m.transpose(1, 3).squeeze(-1) return x From 14743313c49e5fec7b7ba56caaaf4d89399a0ec9 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Wed, 20 Dec 2023 05:32:57 -0800 Subject: [PATCH 39/84] Update dependencyvit.py --- timm/models/dependencyvit.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timm/models/dependencyvit.py b/timm/models/dependencyvit.py index 8b94068286..bef635f84b 100644 --- a/timm/models/dependencyvit.py +++ b/timm/models/dependencyvit.py @@ -48,7 +48,7 @@ def __init__( self.scale = self.head_dim ** -0.5 self.track_dependency_mask = False self.dependency_mask = None - self.head_selector_temperature = 0.1 # appendix D.1, causes nan when 0.1, 0 when 10.0 + self.head_selector_temperature = 1.0 # appendix D.1, causes nan when 0.1, 0 when 10.0 self.head_selector = nn.Linear(dim, num_heads) From 30c370e27818ea721a1e02844b993314fc0f0087 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Sat, 23 Dec 2023 02:33:53 -0800 Subject: [PATCH 40/84] Update dependencyvit.py --- timm/models/dependencyvit.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/timm/models/dependencyvit.py b/timm/models/dependencyvit.py index bef635f84b..66e2e5c244 100644 --- a/timm/models/dependencyvit.py +++ b/timm/models/dependencyvit.py @@ -48,16 +48,16 @@ def __init__( self.scale = self.head_dim ** -0.5 self.track_dependency_mask = False self.dependency_mask = None - self.head_selector_temperature = 1.0 # appendix D.1, causes nan when 0.1, 0 when 10.0 + self.head_selector_temperature = 0.1 # appendix D.1, causes nan when 0.1, 0 when 10.0 - self.head_selector = nn.Linear(dim, num_heads) + self.head_selector = nn.Linear(dim, num_heads, bias=False) self.message_controller = Mlp( in_features = dim, hidden_features = int(dim/2), out_features = 1, act_layer = nn.GELU, - bias = True, # FIXME is there a bias term? + bias = False, # FIXME is there a bias term? 
) self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) @@ -169,7 +169,7 @@ def __init__(self, *args, **kwargs): class_token=False, global_pool='avg', qkv_bias=False, - init_values=1e-6, + init_values=None, fc_norm=False, ) From 1e8beb1c905584720056ba78bc2f61d680f1ab2c Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Sat, 23 Dec 2023 03:15:32 -0800 Subject: [PATCH 41/84] Update dependencyvit.py --- timm/models/dependencyvit.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timm/models/dependencyvit.py b/timm/models/dependencyvit.py index 66e2e5c244..35ff6ebb33 100644 --- a/timm/models/dependencyvit.py +++ b/timm/models/dependencyvit.py @@ -83,7 +83,7 @@ def forward(self, in_tuple: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Te q = q * self.scale attn = q @ k.transpose(-2, -1) attn = attn.softmax(dim=-1) - attn = self.attn_drop(attn).transpose(-2, -1) + attn = self.attn_drop(attn).transpose(-2, -1) # this transpose prevents use of sdpa attn = attn * p * m # [B, n_h, N, N] x = attn @ v From 795878efc6f463bc7e8904ab8210619248ff0795 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Sat, 23 Dec 2023 04:03:48 -0800 Subject: [PATCH 42/84] Update dependencyvit.py --- timm/models/dependencyvit.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/timm/models/dependencyvit.py b/timm/models/dependencyvit.py index 35ff6ebb33..edc19d2f4b 100644 --- a/timm/models/dependencyvit.py +++ b/timm/models/dependencyvit.py @@ -78,7 +78,7 @@ def forward(self, in_tuple: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Te p = (self.head_selector(x) / self.head_selector_temperature).softmax(dim=-1) p = p.transpose(-2, -1).reshape(B, self.num_heads, 1, N) - m = self.message_controller(x).sigmoid().reshape(B, 1, 1, N) + m = self.message_controller(x).sigmoid().reshape(B, 1, 1, N) * m q = q * self.scale attn = q @ k.transpose(-2, -1) @@ -169,7 +169,7 @@ def __init__(self, *args, **kwargs): class_token=False, global_pool='avg', qkv_bias=False, - init_values=None, + init_values=1e-6, fc_norm=False, ) @@ -187,7 +187,7 @@ def forward_features(self, x: torch.Tensor) -> torch.Tensor: #x = x * m.transpose(1, 3).squeeze(-1) # FIXME before or after norm x = self.norm(x) - #x = x * m.transpose(1, 3).squeeze(-1) + x = x * m.transpose(1, 3).squeeze(-1) return x From a067d19d49a5b7a632074a8fa7691ab1a712976a Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Sat, 23 Dec 2023 04:06:20 -0800 Subject: [PATCH 43/84] Update dependencyvit.py --- timm/models/dependencyvit.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timm/models/dependencyvit.py b/timm/models/dependencyvit.py index edc19d2f4b..393a9a4b2e 100644 --- a/timm/models/dependencyvit.py +++ b/timm/models/dependencyvit.py @@ -78,7 +78,7 @@ def forward(self, in_tuple: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Te p = (self.head_selector(x) / self.head_selector_temperature).softmax(dim=-1) p = p.transpose(-2, -1).reshape(B, self.num_heads, 1, N) - m = self.message_controller(x).sigmoid().reshape(B, 1, 1, N) * m + m = self.message_controller(x).sigmoid().reshape(B, 1, 1, N)# * m q = q * self.scale attn = q @ k.transpose(-2, -1) From f179ef76df6de0342104b108e95bbedf37e56608 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Sat, 23 Dec 2023 04:35:16 -0800 Subject: [PATCH 44/84] Update dependencyvit.py --- timm/models/dependencyvit.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timm/models/dependencyvit.py b/timm/models/dependencyvit.py index 393a9a4b2e..edc19d2f4b 100644 --- 
a/timm/models/dependencyvit.py +++ b/timm/models/dependencyvit.py @@ -78,7 +78,7 @@ def forward(self, in_tuple: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Te p = (self.head_selector(x) / self.head_selector_temperature).softmax(dim=-1) p = p.transpose(-2, -1).reshape(B, self.num_heads, 1, N) - m = self.message_controller(x).sigmoid().reshape(B, 1, 1, N)# * m + m = self.message_controller(x).sigmoid().reshape(B, 1, 1, N) * m q = q * self.scale attn = q @ k.transpose(-2, -1) From b662b52a5bbd020b4cb9f1ebd0467253e86ccd1e Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Sat, 23 Dec 2023 04:35:33 -0800 Subject: [PATCH 45/84] Update dependencyvit.py --- timm/models/dependencyvit.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timm/models/dependencyvit.py b/timm/models/dependencyvit.py index edc19d2f4b..5bf00d5647 100644 --- a/timm/models/dependencyvit.py +++ b/timm/models/dependencyvit.py @@ -169,7 +169,7 @@ def __init__(self, *args, **kwargs): class_token=False, global_pool='avg', qkv_bias=False, - init_values=1e-6, + init_values=None, fc_norm=False, ) From 6a373acdfd2c8db1cfbd2589e8515bae569a1b63 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Sat, 23 Dec 2023 04:39:07 -0800 Subject: [PATCH 46/84] Update dependencyvit.py --- timm/models/dependencyvit.py | 1 + 1 file changed, 1 insertion(+) diff --git a/timm/models/dependencyvit.py b/timm/models/dependencyvit.py index 5bf00d5647..82977a6556 100644 --- a/timm/models/dependencyvit.py +++ b/timm/models/dependencyvit.py @@ -154,6 +154,7 @@ def forward(self, in_tuple: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Te x_new, m = self.attn((self.norm1(x), m)) x = x + self.drop_path1(self.ls1(x_new)) x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x)))) + print((x, m)) return (x, m) # FIXME lite model variants From 08c0e392a72e02db7e43fedb1e94d85c1068f20b Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Sat, 23 Dec 2023 04:50:27 -0800 Subject: [PATCH 47/84] Update dependencyvit.py --- timm/models/dependencyvit.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/timm/models/dependencyvit.py b/timm/models/dependencyvit.py index 82977a6556..65d216dee4 100644 --- a/timm/models/dependencyvit.py +++ b/timm/models/dependencyvit.py @@ -79,6 +79,7 @@ def forward(self, in_tuple: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Te p = p.transpose(-2, -1).reshape(B, self.num_heads, 1, N) m = self.message_controller(x).sigmoid().reshape(B, 1, 1, N) * m + print(m) q = q * self.scale attn = q @ k.transpose(-2, -1) @@ -154,7 +155,7 @@ def forward(self, in_tuple: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Te x_new, m = self.attn((self.norm1(x), m)) x = x + self.drop_path1(self.ls1(x_new)) x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x)))) - print((x, m)) + #print((x, m)) return (x, m) # FIXME lite model variants From b5134bfcd681a346aa904102e0662c2f05ec0dd7 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Sat, 23 Dec 2023 04:53:52 -0800 Subject: [PATCH 48/84] Update dependencyvit.py --- timm/models/dependencyvit.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timm/models/dependencyvit.py b/timm/models/dependencyvit.py index 65d216dee4..d39fa2ee50 100644 --- a/timm/models/dependencyvit.py +++ b/timm/models/dependencyvit.py @@ -78,7 +78,7 @@ def forward(self, in_tuple: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Te p = (self.head_selector(x) / self.head_selector_temperature).softmax(dim=-1) p = p.transpose(-2, -1).reshape(B, self.num_heads, 1, N) - m = 
self.message_controller(x).sigmoid().reshape(B, 1, 1, N) * m + m = self.message_controller(x).sigmoid().reshape(B, 1, 1, N)# * m print(m) q = q * self.scale From 41ddbded65597f108da2fc5b713ffccf94171bc3 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Sat, 23 Dec 2023 04:56:09 -0800 Subject: [PATCH 49/84] Update dependencyvit.py --- timm/models/dependencyvit.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timm/models/dependencyvit.py b/timm/models/dependencyvit.py index d39fa2ee50..84f4398e22 100644 --- a/timm/models/dependencyvit.py +++ b/timm/models/dependencyvit.py @@ -77,7 +77,7 @@ def forward(self, in_tuple: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Te p = (self.head_selector(x) / self.head_selector_temperature).softmax(dim=-1) p = p.transpose(-2, -1).reshape(B, self.num_heads, 1, N) - + print(m) m = self.message_controller(x).sigmoid().reshape(B, 1, 1, N)# * m print(m) From 5ae9513a92e56e981097bb9a34165a21cd18a2a1 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Sat, 23 Dec 2023 05:00:43 -0800 Subject: [PATCH 50/84] Update dependencyvit.py --- timm/models/dependencyvit.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timm/models/dependencyvit.py b/timm/models/dependencyvit.py index 84f4398e22..933da49c3b 100644 --- a/timm/models/dependencyvit.py +++ b/timm/models/dependencyvit.py @@ -78,7 +78,7 @@ def forward(self, in_tuple: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Te p = (self.head_selector(x) / self.head_selector_temperature).softmax(dim=-1) p = p.transpose(-2, -1).reshape(B, self.num_heads, 1, N) print(m) - m = self.message_controller(x).sigmoid().reshape(B, 1, 1, N)# * m + m = m * self.message_controller(x).sigmoid().reshape(B, 1, 1, N)# * m print(m) q = q * self.scale From 5970607b6a816d705907162f98bd9631a32dfa33 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Sat, 23 Dec 2023 05:04:32 -0800 Subject: [PATCH 51/84] Update dependencyvit.py --- timm/models/dependencyvit.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/timm/models/dependencyvit.py b/timm/models/dependencyvit.py index 933da49c3b..b022d8ef89 100644 --- a/timm/models/dependencyvit.py +++ b/timm/models/dependencyvit.py @@ -180,7 +180,8 @@ def forward_features(self, x: torch.Tensor) -> torch.Tensor: x = self._pos_embed(x) x = self.patch_drop(x) x = self.norm_pre(x) - m = torch.Tensor(1).to(x) + B, N, _ = x.shape + m = torch.ones(B, 1, 1, N).to(x) if self.grad_checkpointing and not torch.jit.is_scripting(): x, m = checkpoint_seq(self.blocks, (x, m)) else: From fb7848d54d26024f7abefc7e0b542bd32d028d17 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Sat, 23 Dec 2023 05:08:27 -0800 Subject: [PATCH 52/84] Update dependencyvit.py --- timm/models/dependencyvit.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/timm/models/dependencyvit.py b/timm/models/dependencyvit.py index b022d8ef89..a80938e055 100644 --- a/timm/models/dependencyvit.py +++ b/timm/models/dependencyvit.py @@ -77,9 +77,8 @@ def forward(self, in_tuple: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Te p = (self.head_selector(x) / self.head_selector_temperature).softmax(dim=-1) p = p.transpose(-2, -1).reshape(B, self.num_heads, 1, N) - print(m) - m = m * self.message_controller(x).sigmoid().reshape(B, 1, 1, N)# * m - print(m) + + m = m * self.message_controller(x).sigmoid().reshape(B, 1, 1, N) q = q * self.scale attn = q @ k.transpose(-2, -1) @@ -155,7 +154,6 @@ def forward(self, in_tuple: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Te x_new, m = 
self.attn((self.norm1(x), m)) x = x + self.drop_path1(self.ls1(x_new)) x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x)))) - #print((x, m)) return (x, m) # FIXME lite model variants From 68c8aaa14ea2e4d04901114436a6d66db5e98b67 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Sat, 23 Dec 2023 05:08:54 -0800 Subject: [PATCH 53/84] Update dependencyvit.py --- timm/models/dependencyvit.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timm/models/dependencyvit.py b/timm/models/dependencyvit.py index a80938e055..711c2f1cb4 100644 --- a/timm/models/dependencyvit.py +++ b/timm/models/dependencyvit.py @@ -169,7 +169,7 @@ def __init__(self, *args, **kwargs): class_token=False, global_pool='avg', qkv_bias=False, - init_values=None, + init_values=1e-6, fc_norm=False, ) From 1c0b10c59fde73e5eafa9efd1302d8b130cf79d7 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Sat, 23 Dec 2023 08:20:23 -0800 Subject: [PATCH 54/84] Update dependencyvit.py --- timm/models/dependencyvit.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/timm/models/dependencyvit.py b/timm/models/dependencyvit.py index 711c2f1cb4..d38f7bbdba 100644 --- a/timm/models/dependencyvit.py +++ b/timm/models/dependencyvit.py @@ -77,7 +77,7 @@ def forward(self, in_tuple: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Te p = (self.head_selector(x) / self.head_selector_temperature).softmax(dim=-1) p = p.transpose(-2, -1).reshape(B, self.num_heads, 1, N) - + print(m) m = m * self.message_controller(x).sigmoid().reshape(B, 1, 1, N) q = q * self.scale @@ -179,7 +179,8 @@ def forward_features(self, x: torch.Tensor) -> torch.Tensor: x = self.patch_drop(x) x = self.norm_pre(x) B, N, _ = x.shape - m = torch.ones(B, 1, 1, N).to(x) + #m = torch.ones(B, 1, 1, N).to(x) + m = torch.Tensor([1]).to(x) if self.grad_checkpointing and not torch.jit.is_scripting(): x, m = checkpoint_seq(self.blocks, (x, m)) else: From 46b204f281e3837a70dce54d0eb921e11ac44369 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Sat, 23 Dec 2023 08:23:18 -0800 Subject: [PATCH 55/84] Update dependencyvit.py --- timm/models/dependencyvit.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timm/models/dependencyvit.py b/timm/models/dependencyvit.py index d38f7bbdba..7533fe2168 100644 --- a/timm/models/dependencyvit.py +++ b/timm/models/dependencyvit.py @@ -77,7 +77,7 @@ def forward(self, in_tuple: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Te p = (self.head_selector(x) / self.head_selector_temperature).softmax(dim=-1) p = p.transpose(-2, -1).reshape(B, self.num_heads, 1, N) - print(m) + m = m * self.message_controller(x).sigmoid().reshape(B, 1, 1, N) q = q * self.scale From 79bf1b254033815909bc32af97db3ac30e68d429 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Sun, 24 Dec 2023 09:13:15 -0800 Subject: [PATCH 56/84] Update dependencyvit.py --- timm/models/dependencyvit.py | 87 ++++++++++++++++++++++++++++++++---- 1 file changed, 79 insertions(+), 8 deletions(-) diff --git a/timm/models/dependencyvit.py b/timm/models/dependencyvit.py index 7533fe2168..00b7f85c2d 100644 --- a/timm/models/dependencyvit.py +++ b/timm/models/dependencyvit.py @@ -10,7 +10,8 @@ Implementation for timm by / Copyright 2023, Fredo Guan """ -from typing import Any, Dict, Optional, Tuple +import math +from typing import Any, Dict, Optional, Tuple, Union import torch import torch.nn as nn @@ -26,8 +27,22 @@ __all__ = ['DependencyViT'] +class TokenPruner(nn.Module): + def __init__( + self, + prune_ratio: float, + prune_index: int, + ): + 
super().__init__() + self.pct_kept_tokens = (1 - prune_index * prune_ratio) / (1 - (prune_index - 1) * prune_ratio) + + def forward(self, x: torch.Tensor, scores: torch.Tensor): # [B, N, C], [B, N] + _, N, C = x.shape + topk_indices = scores.topk(math.floor(self.pct_kept_tokens * N), sorted=False) # [B, N'] + topk_indices = topk_indices.unsqueeze(-1).expand(-1, -1, C) # [B, N', C] + return x.gather(1, topk_indices) + -# FIXME there is nearly no difference between this and stock attn, allowing sdpa to be used if a workaround can be found class ReversedAttention(nn.Module): dependency_mask: Optional[torch.Tensor] @@ -48,9 +63,9 @@ def __init__( self.scale = self.head_dim ** -0.5 self.track_dependency_mask = False self.dependency_mask = None - self.head_selector_temperature = 0.1 # appendix D.1, causes nan when 0.1, 0 when 10.0 + self.head_selector_temperature = 0.1 # appendix D.1 - self.head_selector = nn.Linear(dim, num_heads, bias=False) + self.head_selector = nn.Linear(dim, num_heads, bias=False) # FIXME is there a bias term? self.message_controller = Mlp( in_features = dim, @@ -59,7 +74,9 @@ def __init__( act_layer = nn.GELU, bias = False, # FIXME is there a bias term? ) - + + self.token_pruner = None + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() @@ -86,8 +103,17 @@ def forward(self, in_tuple: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Te attn = self.attn_drop(attn).transpose(-2, -1) # this transpose prevents use of sdpa attn = attn * p * m # [B, n_h, N, N] x = attn @ v - - self.dependency_mask = attn.sum(1) if self.track_dependency_mask else None + + # FIXME messy way to handle + if self.track_dependency_mask or not isinstance(self.token_pruner, nn.Identity()): + dependency_mask = attn.detach().sum(1) # [B, N, N] + self.dependency_mask = dependency_mask if self.track_dependency_mask else None + #FIXME how to prune + x = self.token_pruner(x, dependency_mask.sum(-1)) if self.token_pruner else x # dependency mask weights(sum) + #x = self.token_pruner(x, dependency_mask.abs().sum(-1)) if self.token_pruner else x # dependency mask weights(abs-sum) + #x = self.token_pruner(x, attn.detach().abs().sum(1).abs().sum(-1)) if self.token_pruner else x # attn weights(abs-sum-abs-sum) + #x = self.token_pruner(x, m.reshape(B, N)) if self.token_pruner else x # m + x = x.transpose(1, 2).reshape(B, N, C) x = self.proj(x) @@ -161,7 +187,13 @@ def forward(self, in_tuple: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Te # FIXME verify against reference impl class DependencyViT(VisionTransformer): - def __init__(self, *args, **kwargs): + def __init__( + self, + prune_layers: Optional[Union[List[int], Tuple[int]]] = None, + prune_ratio: Optional[float] = None, + *args, + **kwargs + ): super().__init__( *args, **kwargs, @@ -172,6 +204,19 @@ def __init__(self, *args, **kwargs): init_values=1e-6, fc_norm=False, ) + + if prune_layers is not None: + self.prune_layers = sorted(list(dict.fromkeys(prune_layers))) + self.prune_ratio = prune_ratio + + # FIXME reword these assertions + assert max(self.prune_layers) <= len(self.blocks), "1 or more pruned layer indices are greater than model depth" + assert self.prune_ratio * len(self.prune_layers) < 1, "prune_ratio too big, ensure len(prune_layers) * prune_ratio is less than 1" + + self.prune_layers = [x-1 for x in self.prune_layers] # convert counting numbers to nn.Sequential indicess + for prune_index, layer in 
enumerate(prune_layers, 1): + self.blocks[layer].attn.token_pruner = TokenPruner(self.prune_ratio, prune_index) + def forward_features(self, x: torch.Tensor) -> torch.Tensor: x = self.patch_embed(x) @@ -191,6 +236,23 @@ def forward_features(self, x: torch.Tensor) -> torch.Tensor: x = self.norm(x) x = x * m.transpose(1, 3).squeeze(-1) return x + + def track_dependency_mask(self, track: bool = True): + for block in self.blocks: + if block.attn.track_dependency_mask is not track: + block.attn.dependency_mask = None + block.attn.track_dependency_mask = track + + def get_dependency_mask(self, layers: Optional[Union[List[int], Tuple[int]]] = None): + # L' * [B, N, N] + # L' * [B, N', N'] + result = [] + layers = range(len(self.blocks)) if not layers + for layer in layers: + result.append(self.blocks[layer].attn.dependency_mask) + return result + + def _cfg(url: str = '', **kwargs) -> Dict[str, Any]: @@ -212,6 +274,9 @@ def _cfg(url: str = '', **kwargs) -> Dict[str, Any]: default_cfgs = { 'dependencyvit_tiny_patch16_224.untrained': _cfg(url=''), + 'dependencyvit_small_patch16_224.untrained': _cfg(url=''), + + 'dependencyvit_lite_tiny_patch16_224.untrained': _cfg(url=''), } @@ -240,4 +305,10 @@ def dependencyvit_tiny_patch16_224(pretrained: bool = False, **kwargs) -> Depend def dependencyvit_small_patch16_224(pretrained: bool = False, **kwargs) -> DependencyViT: model_args = dict(patch_size=16, embed_dim=384, depth=12, num_heads=12) model = _create_dependencyvit('dependencyvit_tiny_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + +@register_model +def dependencyvit_lite_tiny_patch16_224(pretrained: bool = False, **kwargs) -> DependencyViT: + model_args = dict(patch_size=16, embed_dim=192, depth=12, num_heads=12, prune_layers=[2, 5, 8, 11], prune_ratio=0.16) + model = _create_dependencyvit('dependencyvit_tiny_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs)) return model \ No newline at end of file From fddd4c840591119515b3458f8274dc88e798fc90 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Sun, 24 Dec 2023 09:15:20 -0800 Subject: [PATCH 57/84] Update dependencyvit.py --- timm/models/dependencyvit.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timm/models/dependencyvit.py b/timm/models/dependencyvit.py index 00b7f85c2d..2b00b39a4f 100644 --- a/timm/models/dependencyvit.py +++ b/timm/models/dependencyvit.py @@ -247,7 +247,7 @@ def get_dependency_mask(self, layers: Optional[Union[List[int], Tuple[int]]] = N # L' * [B, N, N] # L' * [B, N', N'] result = [] - layers = range(len(self.blocks)) if not layers + layers = layers if layers else range(len(self.blocks)) for layer in layers: result.append(self.blocks[layer].attn.dependency_mask) return result From 0effbcebb5f2795ae17ad7fa0fa4261ad2cd03e2 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Sun, 24 Dec 2023 09:16:44 -0800 Subject: [PATCH 58/84] Update dependencyvit.py --- timm/models/dependencyvit.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timm/models/dependencyvit.py b/timm/models/dependencyvit.py index 2b00b39a4f..cfd6149f67 100644 --- a/timm/models/dependencyvit.py +++ b/timm/models/dependencyvit.py @@ -11,7 +11,7 @@ """ import math -from typing import Any, Dict, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union import torch import torch.nn as nn From e988a7923af2242e5d09f8e3c81d3fed35b2cca4 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Sun, 24 Dec 2023 09:22:47 -0800 Subject: [PATCH 59/84] Update dependencyvit.py 
--- timm/models/dependencyvit.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timm/models/dependencyvit.py b/timm/models/dependencyvit.py index cfd6149f67..afd6ec74a2 100644 --- a/timm/models/dependencyvit.py +++ b/timm/models/dependencyvit.py @@ -105,7 +105,7 @@ def forward(self, in_tuple: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Te x = attn @ v # FIXME messy way to handle - if self.track_dependency_mask or not isinstance(self.token_pruner, nn.Identity()): + if self.track_dependency_mask or self.token_pruner: dependency_mask = attn.detach().sum(1) # [B, N, N] self.dependency_mask = dependency_mask if self.track_dependency_mask else None #FIXME how to prune From 9ccf009b390cad7ab9340bd55d62b620f5c1cb59 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Sun, 24 Dec 2023 09:25:53 -0800 Subject: [PATCH 60/84] Update dependencyvit.py --- timm/models/dependencyvit.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/timm/models/dependencyvit.py b/timm/models/dependencyvit.py index afd6ec74a2..992e155633 100644 --- a/timm/models/dependencyvit.py +++ b/timm/models/dependencyvit.py @@ -103,6 +103,8 @@ def forward(self, in_tuple: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Te attn = self.attn_drop(attn).transpose(-2, -1) # this transpose prevents use of sdpa attn = attn * p * m # [B, n_h, N, N] x = attn @ v + x = x.transpose(1, 2).reshape(B, N, C) + # FIXME messy way to handle if self.track_dependency_mask or self.token_pruner: @@ -115,7 +117,7 @@ def forward(self, in_tuple: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Te #x = self.token_pruner(x, m.reshape(B, N)) if self.token_pruner else x # m - x = x.transpose(1, 2).reshape(B, N, C) + x = self.proj(x) x = self.proj_drop(x) return (x, m) From 007cd95548fc3fab313d704282ca1d9689f2b1b7 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Sun, 24 Dec 2023 09:27:45 -0800 Subject: [PATCH 61/84] Update dependencyvit.py --- timm/models/dependencyvit.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timm/models/dependencyvit.py b/timm/models/dependencyvit.py index 992e155633..6985d131b6 100644 --- a/timm/models/dependencyvit.py +++ b/timm/models/dependencyvit.py @@ -38,7 +38,7 @@ def __init__( def forward(self, x: torch.Tensor, scores: torch.Tensor): # [B, N, C], [B, N] _, N, C = x.shape - topk_indices = scores.topk(math.floor(self.pct_kept_tokens * N), sorted=False) # [B, N'] + topk_indices = scores.topk(math.floor(self.pct_kept_tokens * N), sorted=False)[1] # [B, N'] topk_indices = topk_indices.unsqueeze(-1).expand(-1, -1, C) # [B, N', C] return x.gather(1, topk_indices) From 5f3b70b96f813eac0d5de40bd680d4c2e4470393 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Sun, 24 Dec 2023 09:46:16 -0800 Subject: [PATCH 62/84] Update dependencyvit.py --- timm/models/dependencyvit.py | 22 +++++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/timm/models/dependencyvit.py b/timm/models/dependencyvit.py index 6985d131b6..f23c211be6 100644 --- a/timm/models/dependencyvit.py +++ b/timm/models/dependencyvit.py @@ -75,7 +75,7 @@ def __init__( bias = False, # FIXME is there a bias term? 
) - self.token_pruner = None + #self.token_pruner = None self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() @@ -105,7 +105,7 @@ def forward(self, in_tuple: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Te x = attn @ v x = x.transpose(1, 2).reshape(B, N, C) - + ''' # FIXME messy way to handle if self.track_dependency_mask or self.token_pruner: dependency_mask = attn.detach().sum(1) # [B, N, N] @@ -115,12 +115,17 @@ def forward(self, in_tuple: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Te #x = self.token_pruner(x, dependency_mask.abs().sum(-1)) if self.token_pruner else x # dependency mask weights(abs-sum) #x = self.token_pruner(x, attn.detach().abs().sum(1).abs().sum(-1)) if self.token_pruner else x # attn weights(abs-sum-abs-sum) #x = self.token_pruner(x, m.reshape(B, N)) if self.token_pruner else x # m - - + ''' + self.dependency_mask = attn.detach().sum(1) if self.track_dependency_mask else None # [B, N, N] + + prune_mask = attn.detach().sum(1).sum(-1) + #prune_mask = attn.detach().sum(1).abs().sum(-1) + #prune_mask = attn.detach().abs().sum(1).sum(-1) + #prune_mask = m.reshape(B, N) x = self.proj(x) x = self.proj_drop(x) - return (x, m) + return (x, m, prune_mask) class LayerScale(nn.Module): def __init__( @@ -166,6 +171,8 @@ def __init__( ) self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + self.token_pruner = None self.norm2 = norm_layer(dim) self.mlp = mlp_layer( @@ -179,8 +186,9 @@ def __init__( def forward(self, in_tuple: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]: x, m = in_tuple - x_new, m = self.attn((self.norm1(x), m)) + x_new, m, prune_mask = self.attn((self.norm1(x), m)) x = x + self.drop_path1(self.ls1(x_new)) + x = self.token_pruner(x, prune_mask) if self.token_pruner else x x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x)))) return (x, m) @@ -217,7 +225,7 @@ def __init__( self.prune_layers = [x-1 for x in self.prune_layers] # convert counting numbers to nn.Sequential indicess for prune_index, layer in enumerate(prune_layers, 1): - self.blocks[layer].attn.token_pruner = TokenPruner(self.prune_ratio, prune_index) + self.blocks[layer].token_pruner = TokenPruner(self.prune_ratio, prune_index) def forward_features(self, x: torch.Tensor) -> torch.Tensor: From 25a501d1c016d63925e5205ba1a5ccfb7641302b Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Sun, 24 Dec 2023 09:56:34 -0800 Subject: [PATCH 63/84] Update dependencyvit.py --- timm/models/dependencyvit.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/timm/models/dependencyvit.py b/timm/models/dependencyvit.py index f23c211be6..7067478802 100644 --- a/timm/models/dependencyvit.py +++ b/timm/models/dependencyvit.py @@ -36,11 +36,12 @@ def __init__( super().__init__() self.pct_kept_tokens = (1 - prune_index * prune_ratio) / (1 - (prune_index - 1) * prune_ratio) - def forward(self, x: torch.Tensor, scores: torch.Tensor): # [B, N, C], [B, N] - _, N, C = x.shape + def forward(self, x: torch.Tensor, m: torch.Tensor, scores: torch.Tensor): # [B, N, C], [B, 1, 1, N], [B, N] + B, N, C = x.shape topk_indices = scores.topk(math.floor(self.pct_kept_tokens * N), sorted=False)[1] # [B, N'] - topk_indices = topk_indices.unsqueeze(-1).expand(-1, -1, C) # [B, N', C] - return x.gather(1, topk_indices) + x = x.gather(1, topk_indices.unsqueeze(-1).expand(-1, -1, C)) # [B, N', C] + m = 
m.gather(3, topk_indices.unsqueeze(1).unsqueeze(1)) # [B, 1, 1, N'] + return (x, m) class ReversedAttention(nn.Module): @@ -188,7 +189,7 @@ def forward(self, in_tuple: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Te x, m = in_tuple x_new, m, prune_mask = self.attn((self.norm1(x), m)) x = x + self.drop_path1(self.ls1(x_new)) - x = self.token_pruner(x, prune_mask) if self.token_pruner else x + x, m = self.token_pruner(x, m, prune_mask) if self.token_pruner else (x, m) x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x)))) return (x, m) From 4e28e341fe52c57ced1ab6b55aa6d4a88f031f17 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Sun, 24 Dec 2023 20:32:38 -0800 Subject: [PATCH 64/84] Update dependencyvit.py --- timm/models/dependencyvit.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timm/models/dependencyvit.py b/timm/models/dependencyvit.py index 7067478802..9ffd14a218 100644 --- a/timm/models/dependencyvit.py +++ b/timm/models/dependencyvit.py @@ -121,7 +121,7 @@ def forward(self, in_tuple: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Te prune_mask = attn.detach().sum(1).sum(-1) #prune_mask = attn.detach().sum(1).abs().sum(-1) - #prune_mask = attn.detach().abs().sum(1).sum(-1) + #prune_mask = attn.detach().abs().sum((1, -1)) #prune_mask = m.reshape(B, N) x = self.proj(x) From 130e3bb23e798561dc176c1834971be8c9bcc50c Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Mon, 25 Dec 2023 04:36:07 -0800 Subject: [PATCH 65/84] Update dependencyvit.py --- timm/models/dependencyvit.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timm/models/dependencyvit.py b/timm/models/dependencyvit.py index 9ffd14a218..6b9b74417c 100644 --- a/timm/models/dependencyvit.py +++ b/timm/models/dependencyvit.py @@ -225,7 +225,7 @@ def __init__( assert self.prune_ratio * len(self.prune_layers) < 1, "prune_ratio too big, ensure len(prune_layers) * prune_ratio is less than 1" self.prune_layers = [x-1 for x in self.prune_layers] # convert counting numbers to nn.Sequential indicess - for prune_index, layer in enumerate(prune_layers, 1): + for prune_index, layer in enumerate(self.prune_layers, 1): self.blocks[layer].token_pruner = TokenPruner(self.prune_ratio, prune_index) From 90a34974e192718fb5cd5f345123e49d83f32d9a Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Mon, 25 Dec 2023 05:25:21 -0800 Subject: [PATCH 66/84] Update dependencyvit.py --- timm/models/dependencyvit.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/timm/models/dependencyvit.py b/timm/models/dependencyvit.py index 6b9b74417c..65226287ac 100644 --- a/timm/models/dependencyvit.py +++ b/timm/models/dependencyvit.py @@ -122,6 +122,9 @@ def forward(self, in_tuple: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Te prune_mask = attn.detach().sum(1).sum(-1) #prune_mask = attn.detach().sum(1).abs().sum(-1) #prune_mask = attn.detach().abs().sum((1, -1)) + #prune_mask = attn.sum(1).sum(-1) + #prune_mask = attn.sum(1).abs().sum(-1) + #prune_mask = attn.abs().sum((1, -1)) #prune_mask = m.reshape(B, N) x = self.proj(x) From 8212e96fd14300ee7d6c4764ca78853f28ea9bc0 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Mon, 25 Dec 2023 07:04:09 -0800 Subject: [PATCH 67/84] Update dependencyvit.py --- timm/models/dependencyvit.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/timm/models/dependencyvit.py b/timm/models/dependencyvit.py index 65226287ac..395ae9580a 100644 --- a/timm/models/dependencyvit.py +++ b/timm/models/dependencyvit.py @@ -119,12 +119,12 @@ def forward(self, in_tuple: 
Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Te ''' self.dependency_mask = attn.detach().sum(1) if self.track_dependency_mask else None # [B, N, N] - prune_mask = attn.detach().sum(1).sum(-1) + #prune_mask = attn.detach().sum(1).sum(-1) #prune_mask = attn.detach().sum(1).abs().sum(-1) #prune_mask = attn.detach().abs().sum((1, -1)) #prune_mask = attn.sum(1).sum(-1) #prune_mask = attn.sum(1).abs().sum(-1) - #prune_mask = attn.abs().sum((1, -1)) + prune_mask = attn.abs().sum((1, -1)) #prune_mask = m.reshape(B, N) x = self.proj(x) From 73dcda935e444e29e925dd67f178f1e944fa9819 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Mon, 25 Dec 2023 14:01:47 -0800 Subject: [PATCH 68/84] Update dependencyvit.py --- timm/models/dependencyvit.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/timm/models/dependencyvit.py b/timm/models/dependencyvit.py index 395ae9580a..95d6b1dd91 100644 --- a/timm/models/dependencyvit.py +++ b/timm/models/dependencyvit.py @@ -124,8 +124,8 @@ def forward(self, in_tuple: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Te #prune_mask = attn.detach().abs().sum((1, -1)) #prune_mask = attn.sum(1).sum(-1) #prune_mask = attn.sum(1).abs().sum(-1) - prune_mask = attn.abs().sum((1, -1)) - #prune_mask = m.reshape(B, N) + #prune_mask = attn.abs().sum((1, -1)) + prune_mask = m.reshape(B, N) x = self.proj(x) x = self.proj_drop(x) From 1107c6961f370c869abfe1d14eab2a1352d4527b Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Mon, 25 Dec 2023 14:05:13 -0800 Subject: [PATCH 69/84] Update test_models.py --- tests/test_models.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_models.py b/tests/test_models.py index 0b7303c548..481f741827 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -62,6 +62,7 @@ 'convit_*', 'levit*', 'visformer*', 'deit*', 'xcit_*', 'crossvit_*', 'beit*', 'poolformer_*', 'volo_*', 'sequencer2d_*', 'mvitv2*', 'gcvit*', 'efficientformer*', 'sam_hiera*', 'eva_*', 'flexivit*', 'eva02*', 'samvit_*', 'efficientvit_m*', 'tiny_vit_*', 'hiera_*', 'vitamin*', 'test_vit*', + 'dependencyvit_*', ] NUM_NON_STD = len(NON_STD_FILTERS) From 3b1604f3549f99dba1f7aa30d2bcd30bba7f10ed Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Mon, 25 Dec 2023 15:30:37 -0800 Subject: [PATCH 70/84] Update dependencyvit.py --- timm/models/dependencyvit.py | 37 +++++++++++++++--------------------- 1 file changed, 15 insertions(+), 22 deletions(-) diff --git a/timm/models/dependencyvit.py b/timm/models/dependencyvit.py index 95d6b1dd91..31caea6204 100644 --- a/timm/models/dependencyvit.py +++ b/timm/models/dependencyvit.py @@ -1,4 +1,4 @@ -""" DependencyViT (FIXME WIP) +""" DependencyViT From-scratch implementation of DependencyViT in PyTorch @@ -106,19 +106,13 @@ def forward(self, in_tuple: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Te x = attn @ v x = x.transpose(1, 2).reshape(B, N, C) - ''' - # FIXME messy way to handle - if self.track_dependency_mask or self.token_pruner: - dependency_mask = attn.detach().sum(1) # [B, N, N] - self.dependency_mask = dependency_mask if self.track_dependency_mask else None - #FIXME how to prune - x = self.token_pruner(x, dependency_mask.sum(-1)) if self.token_pruner else x # dependency mask weights(sum) - #x = self.token_pruner(x, dependency_mask.abs().sum(-1)) if self.token_pruner else x # dependency mask weights(abs-sum) - #x = self.token_pruner(x, attn.detach().abs().sum(1).abs().sum(-1)) if self.token_pruner else x # attn weights(abs-sum-abs-sum) - #x = self.token_pruner(x, m.reshape(B, N)) if 
self.token_pruner else x # m - ''' + + #FIXME absolute value? self.dependency_mask = attn.detach().sum(1) if self.track_dependency_mask else None # [B, N, N] + #FIXME which pruning mask? + + # [B, N] #prune_mask = attn.detach().sum(1).sum(-1) #prune_mask = attn.detach().sum(1).abs().sum(-1) #prune_mask = attn.detach().abs().sum((1, -1)) @@ -196,9 +190,9 @@ def forward(self, in_tuple: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Te x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x)))) return (x, m) -# FIXME lite model variants -# FIXME toggle and retrieve dependency masks + # FIXME verify against reference impl +# FIXME train weights that meet or exceed results from paper class DependencyViT(VisionTransformer): def __init__( @@ -207,15 +201,15 @@ def __init__( prune_ratio: Optional[float] = None, *args, **kwargs - ): + ): -> None: super().__init__( - *args, + *args, **kwargs, - block_fn = DependencyViTBlock, + block_fn = DependencyViTBlock, class_token=False, - global_pool='avg', - qkv_bias=False, - init_values=1e-6, + global_pool='avg', + qkv_bias=False, + init_values=1e-6, fc_norm=False, ) @@ -223,8 +217,7 @@ def __init__( self.prune_layers = sorted(list(dict.fromkeys(prune_layers))) self.prune_ratio = prune_ratio - # FIXME reword these assertions - assert max(self.prune_layers) <= len(self.blocks), "1 or more pruned layer indices are greater than model depth" + assert max(self.prune_layers) <= len(self.blocks), "1 or more pruned layer indices exceed model depth" assert self.prune_ratio * len(self.prune_layers) < 1, "prune_ratio too big, ensure len(prune_layers) * prune_ratio is less than 1" self.prune_layers = [x-1 for x in self.prune_layers] # convert counting numbers to nn.Sequential indicess From 2dd2ec52496130112d1e0436aaa1198660b0d5a1 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Mon, 25 Dec 2023 16:02:51 -0800 Subject: [PATCH 71/84] fix syntax, type and shape annotations --- timm/models/dependencyvit.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/timm/models/dependencyvit.py b/timm/models/dependencyvit.py index 31caea6204..03eedefcd5 100644 --- a/timm/models/dependencyvit.py +++ b/timm/models/dependencyvit.py @@ -32,11 +32,12 @@ def __init__( self, prune_ratio: float, prune_index: int, - ): + ) -> None: super().__init__() self.pct_kept_tokens = (1 - prune_index * prune_ratio) / (1 - (prune_index - 1) * prune_ratio) - def forward(self, x: torch.Tensor, m: torch.Tensor, scores: torch.Tensor): # [B, N, C], [B, 1, 1, N], [B, N] + # [B, N, C], [B, 1, 1, N], [B, N] -> [B, N', C], [B, 1, 1, N'] + def forward(self, x: torch.Tensor, m: torch.Tensor, scores: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: B, N, C = x.shape topk_indices = scores.topk(math.floor(self.pct_kept_tokens * N), sorted=False)[1] # [B, N'] x = x.gather(1, topk_indices.unsqueeze(-1).expand(-1, -1, C)) # [B, N', C] @@ -86,8 +87,8 @@ def __init__( self.proj_drop = nn.Dropout(proj_drop) # m is cumulative over all layers (m = m_i * m_i-1 * ... 
* m_1) - def forward(self, in_tuple: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]: - x, m = in_tuple # [B, N, C], [B, 1, 1, N] + # [B, N, C], [B, 1, 1, N] -> [B, N, C], [B, 1, 1, N], [B, N] + def forward(self, x: torch.Tensor, m: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: B, N, C = x.shape qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) q, k, v = qkv.unbind(0) @@ -112,7 +113,6 @@ def forward(self, in_tuple: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Te #FIXME which pruning mask? - # [B, N] #prune_mask = attn.detach().sum(1).sum(-1) #prune_mask = attn.detach().sum(1).abs().sum(-1) #prune_mask = attn.detach().abs().sum((1, -1)) @@ -184,7 +184,7 @@ def __init__( def forward(self, in_tuple: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]: x, m = in_tuple - x_new, m, prune_mask = self.attn((self.norm1(x), m)) + x_new, m, prune_mask = self.attn(self.norm1(x), m) x = x + self.drop_path1(self.ls1(x_new)) x, m = self.token_pruner(x, m, prune_mask) if self.token_pruner else (x, m) x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x)))) @@ -201,7 +201,7 @@ def __init__( prune_ratio: Optional[float] = None, *args, **kwargs - ): -> None: + ) -> None: super().__init__( *args, **kwargs, @@ -244,13 +244,13 @@ def forward_features(self, x: torch.Tensor) -> torch.Tensor: x = x * m.transpose(1, 3).squeeze(-1) return x - def track_dependency_mask(self, track: bool = True): + def track_dependency_mask(self, track: bool = True) -> None: for block in self.blocks: if block.attn.track_dependency_mask is not track: block.attn.dependency_mask = None block.attn.track_dependency_mask = track - def get_dependency_mask(self, layers: Optional[Union[List[int], Tuple[int]]] = None): + def get_dependency_mask(self, layers: Optional[Union[List[int], Tuple[int]]] = None) -> List[torch.Tensor]: # L' * [B, N, N] # L' * [B, N', N'] result = [] From 10c2c5616016b1f293418101e37723ab3a52fedf Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Mon, 25 Dec 2023 21:26:00 -0800 Subject: [PATCH 72/84] Update test_models.py --- tests/test_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_models.py b/tests/test_models.py index 481f741827..cde0fc0386 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -79,7 +79,7 @@ EXCLUDE_FILTERS = ['*enormous*'] NON_STD_EXCLUDE_FILTERS = ['*gigantic*', '*enormous*'] -EXCLUDE_JIT_FILTERS = ['hiera_*'] +EXCLUDE_JIT_FILTERS = ['hiera_*', 'dependencyvit_*'] TARGET_FWD_SIZE = MAX_FWD_SIZE = 384 TARGET_BWD_SIZE = 128 From a7d3c3b9573f07b2fbc5a3224ce56a502377820d Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Mon, 25 Dec 2023 23:38:45 -0800 Subject: [PATCH 73/84] Update dependencyvit.py --- timm/models/dependencyvit.py | 1 - 1 file changed, 1 deletion(-) diff --git a/timm/models/dependencyvit.py b/timm/models/dependencyvit.py index 03eedefcd5..5bd7e40fde 100644 --- a/timm/models/dependencyvit.py +++ b/timm/models/dependencyvit.py @@ -231,7 +231,6 @@ def forward_features(self, x: torch.Tensor) -> torch.Tensor: x = self.patch_drop(x) x = self.norm_pre(x) B, N, _ = x.shape - #m = torch.ones(B, 1, 1, N).to(x) m = torch.Tensor([1]).to(x) if self.grad_checkpointing and not torch.jit.is_scripting(): x, m = checkpoint_seq(self.blocks, (x, m)) From 17eee05927ac8c46bf9f9f037a2cc9734e207fd0 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Tue, 26 Dec 2023 15:27:00 -0800 Subject: [PATCH 74/84] Update dependencyvit.py --- timm/models/dependencyvit.py | 4 
++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/timm/models/dependencyvit.py b/timm/models/dependencyvit.py index 5bd7e40fde..773566136d 100644 --- a/timm/models/dependencyvit.py +++ b/timm/models/dependencyvit.py @@ -115,11 +115,11 @@ def forward(self, x: torch.Tensor, m: torch.Tensor) -> Tuple[torch.Tensor, torch #prune_mask = attn.detach().sum(1).sum(-1) #prune_mask = attn.detach().sum(1).abs().sum(-1) - #prune_mask = attn.detach().abs().sum((1, -1)) + prune_mask = attn.detach().abs().sum((1, -1)) #prune_mask = attn.sum(1).sum(-1) #prune_mask = attn.sum(1).abs().sum(-1) #prune_mask = attn.abs().sum((1, -1)) - prune_mask = m.reshape(B, N) + #prune_mask = m.reshape(B, N) x = self.proj(x) x = self.proj_drop(x) From 3c0713186342fb77a53b89f94366e04b89b62473 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Wed, 27 Dec 2023 03:57:39 -0800 Subject: [PATCH 75/84] Update dependencyvit.py --- timm/models/dependencyvit.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/timm/models/dependencyvit.py b/timm/models/dependencyvit.py index 773566136d..45c1d3d24d 100644 --- a/timm/models/dependencyvit.py +++ b/timm/models/dependencyvit.py @@ -115,11 +115,12 @@ def forward(self, x: torch.Tensor, m: torch.Tensor) -> Tuple[torch.Tensor, torch #prune_mask = attn.detach().sum(1).sum(-1) #prune_mask = attn.detach().sum(1).abs().sum(-1) - prune_mask = attn.detach().abs().sum((1, -1)) + #prune_mask = attn.detach().abs().sum((1, -1)) #prune_mask = attn.sum(1).sum(-1) #prune_mask = attn.sum(1).abs().sum(-1) #prune_mask = attn.abs().sum((1, -1)) #prune_mask = m.reshape(B, N) + purne_mask = m.detach().reshape(B, N) x = self.proj(x) x = self.proj_drop(x) From f32418cefb174943704f823018fcc3e1f8891ded Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Wed, 27 Dec 2023 04:42:08 -0800 Subject: [PATCH 76/84] Update dependencyvit.py --- timm/models/dependencyvit.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timm/models/dependencyvit.py b/timm/models/dependencyvit.py index 45c1d3d24d..92fd8612a0 100644 --- a/timm/models/dependencyvit.py +++ b/timm/models/dependencyvit.py @@ -120,7 +120,7 @@ def forward(self, x: torch.Tensor, m: torch.Tensor) -> Tuple[torch.Tensor, torch #prune_mask = attn.sum(1).abs().sum(-1) #prune_mask = attn.abs().sum((1, -1)) #prune_mask = m.reshape(B, N) - purne_mask = m.detach().reshape(B, N) + prune_mask = m.detach().reshape(B, N) x = self.proj(x) x = self.proj_drop(x) From b7b5073f49013dc167f7d5f20c3025f287470695 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Sat, 7 Dec 2024 11:25:57 -0700 Subject: [PATCH 77/84] CPE/PEG --- timm/models/dependencyvit.py | 51 +++++++++++++++++++++++++++++++++--- 1 file changed, 47 insertions(+), 4 deletions(-) diff --git a/timm/models/dependencyvit.py b/timm/models/dependencyvit.py index 92fd8612a0..a6711fc851 100644 --- a/timm/models/dependencyvit.py +++ b/timm/models/dependencyvit.py @@ -7,6 +7,9 @@ ReversedAttention implementation derived from timm's Vision Transformer implementation +Some guesswork for the architecture, presume final models were based on CPVT-GAP with 1 PEG module + - https://arxiv.org/abs/2102.10882 + Implementation for timm by / Copyright 2023, Fredo Guan """ @@ -67,14 +70,14 @@ def __init__( self.dependency_mask = None self.head_selector_temperature = 0.1 # appendix D.1 - self.head_selector = nn.Linear(dim, num_heads, bias=False) # FIXME is there a bias term? 
+ self.head_selector = nn.Linear(dim, num_heads, bias=False) # TODO ablate bias self.message_controller = Mlp( in_features = dim, hidden_features = int(dim/2), out_features = 1, act_layer = nn.GELU, - bias = False, # FIXME is there a bias term? + bias = False, # TODO ablate bias ) #self.token_pruner = None @@ -139,6 +142,18 @@ def __init__( def forward(self, x: torch.Tensor) -> torch.Tensor: return x.mul_(self.gamma) if self.inplace else x * self.gamma + +class PositionalEncodingGenerator(nn.Module): + def __init__( + self, + dim: int, + ) -> None: + super().__init__() + self.proj = nn.Conv2d(dim, dim, 3, groups=dim) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x + self.proj(x) + class DependencyViTBlock(nn.Module): @@ -156,8 +171,14 @@ def __init__( act_layer: nn.Module = nn.GELU, norm_layer: nn.Module = nn.LayerNorm, mlp_layer: nn.Module = Mlp, + nchw_in: bool = False, + use_peg: bool = False, ) -> None: super().__init__() + + self.nchw_in = nchw_in + self.use_peg = use_peg + self.norm1 = norm_layer(dim) self.attn = ReversedAttention( dim=dim, @@ -182,13 +203,22 @@ def __init__( ) self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + self.peg = nn.Identity() + def forward(self, in_tuple: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]: x, m = in_tuple + if self.nchw_in: + B, C, H, W = x.shape + x = x.flatten(2).transpose(1, 2) # NCHW -> NLC x_new, m, prune_mask = self.attn(self.norm1(x), m) x = x + self.drop_path1(self.ls1(x_new)) x, m = self.token_pruner(x, m, prune_mask) if self.token_pruner else (x, m) x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x)))) + if self.use_peg: + # NCHW out + x = self.peg(x.reshape(B, C, H, W)) return (x, m) @@ -200,6 +230,7 @@ def __init__( self, prune_layers: Optional[Union[List[int], Tuple[int]]] = None, prune_ratio: Optional[float] = None, + cpe_depth: int = 1, *args, **kwargs ) -> None: @@ -212,6 +243,7 @@ def __init__( qkv_bias=False, init_values=1e-6, fc_norm=False, + pos_embed='none', ) if prune_layers is not None: @@ -221,17 +253,26 @@ def __init__( assert max(self.prune_layers) <= len(self.blocks), "1 or more pruned layer indices exceed model depth" assert self.prune_ratio * len(self.prune_layers) < 1, "prune_ratio too big, ensure len(prune_layers) * prune_ratio is less than 1" - self.prune_layers = [x-1 for x in self.prune_layers] # convert counting numbers to nn.Sequential indicess + self.prune_layers = [x-1 for x in self.prune_layers] # convert counting numbers to nn.Sequential indices for prune_index, layer in enumerate(self.prune_layers, 1): self.blocks[layer].token_pruner = TokenPruner(self.prune_ratio, prune_index) + self.blocks[0].nchw_in = True + for layer_index in range(cpe_depth): + self.blocks[layer_index].use_peg = True + self.blocks[layer_index + 1].nchw_in = True + self.blocks[layer_index].peg = PositionalEncodingGenerator(self.embed_dim) + + + def forward_features(self, x: torch.Tensor) -> torch.Tensor: + B, _, H, W = x.shape x = self.patch_embed(x) x = self._pos_embed(x) x = self.patch_drop(x) x = self.norm_pre(x) - B, N, _ = x.shape + x = x.reshape(B, -1, *self.patch_embed.dynamic_feat_size((H, W))) # [B, N, C] -> [B, C, H, W] m = torch.Tensor([1]).to(x) if self.grad_checkpointing and not torch.jit.is_scripting(): x, m = checkpoint_seq(self.blocks, (x, m)) @@ -244,6 +285,7 @@ def forward_features(self, x: torch.Tensor) -> torch.Tensor: x = x * m.transpose(1, 
3).squeeze(-1) return x + # TODO transition to new method that ViT uses? def track_dependency_mask(self, track: bool = True) -> None: for block in self.blocks: if block.attn.track_dependency_mask is not track: @@ -314,6 +356,7 @@ def dependencyvit_small_patch16_224(pretrained: bool = False, **kwargs) -> Depen model = _create_dependencyvit('dependencyvit_tiny_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs)) return model +# TODO test how this works, presume pretrain unpruned model, enable pruning during inference @register_model def dependencyvit_lite_tiny_patch16_224(pretrained: bool = False, **kwargs) -> DependencyViT: model_args = dict(patch_size=16, embed_dim=192, depth=12, num_heads=12, prune_layers=[2, 5, 8, 11], prune_ratio=0.16) From d0d96db97304dc98aaf3364d66de3c1c0493a72d Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Sat, 7 Dec 2024 11:31:26 -0700 Subject: [PATCH 78/84] Update dependencyvit.py --- timm/models/dependencyvit.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timm/models/dependencyvit.py b/timm/models/dependencyvit.py index a6711fc851..ee6b6d3aa4 100644 --- a/timm/models/dependencyvit.py +++ b/timm/models/dependencyvit.py @@ -149,7 +149,7 @@ def __init__( dim: int, ) -> None: super().__init__() - self.proj = nn.Conv2d(dim, dim, 3, groups=dim) + self.proj = nn.Conv2d(dim, dim, 3, padding=1, groups=dim) def forward(self, x: torch.Tensor) -> torch.Tensor: return x + self.proj(x) From 4c1f6565ddc2c2e2ba6bf71001c66421a16e1e40 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Sun, 8 Dec 2024 06:53:57 -0700 Subject: [PATCH 79/84] wrong dim order --- timm/models/dependencyvit.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timm/models/dependencyvit.py b/timm/models/dependencyvit.py index ee6b6d3aa4..b35f992498 100644 --- a/timm/models/dependencyvit.py +++ b/timm/models/dependencyvit.py @@ -218,7 +218,7 @@ def forward(self, in_tuple: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Te x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x)))) if self.use_peg: # NCHW out - x = self.peg(x.reshape(B, C, H, W)) + x = self.peg(x.transpose(1,2).reshape(B, C, H, W)) return (x, m) From 09e61fdb8404d4aed512f10849e08f97bbc6ea5d Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Sun, 8 Dec 2024 06:54:05 -0700 Subject: [PATCH 80/84] Update dependencyvit.py --- timm/models/dependencyvit.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timm/models/dependencyvit.py b/timm/models/dependencyvit.py index b35f992498..8619e4d97e 100644 --- a/timm/models/dependencyvit.py +++ b/timm/models/dependencyvit.py @@ -218,7 +218,7 @@ def forward(self, in_tuple: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Te x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x)))) if self.use_peg: # NCHW out - x = self.peg(x.transpose(1,2).reshape(B, C, H, W)) + x = self.peg(x.transpose(1, 2).reshape(B, C, H, W)) return (x, m) From a5b0f0d03d4dfeb3f256360c7eb44b4fd42f2460 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Sun, 8 Dec 2024 09:31:25 -0700 Subject: [PATCH 81/84] Update dependencyvit.py --- timm/models/dependencyvit.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timm/models/dependencyvit.py b/timm/models/dependencyvit.py index 8619e4d97e..4fa8b7ca73 100644 --- a/timm/models/dependencyvit.py +++ b/timm/models/dependencyvit.py @@ -272,7 +272,7 @@ def forward_features(self, x: torch.Tensor) -> torch.Tensor: x = self._pos_embed(x) x = self.patch_drop(x) x = self.norm_pre(x) - x = x.reshape(B, -1, 
*self.patch_embed.dynamic_feat_size((H, W))) # [B, N, C] -> [B, C, H, W] + x = x.transpose(1, 2).reshape(B, -1, *self.patch_embed.dynamic_feat_size((H, W))) # [B, N, C] -> [B, C, H, W] m = torch.Tensor([1]).to(x) if self.grad_checkpointing and not torch.jit.is_scripting(): x, m = checkpoint_seq(self.blocks, (x, m)) From 8c26f3829ac8ded5f105d48c68a521d1c5149369 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Sat, 14 Dec 2024 21:15:48 -0700 Subject: [PATCH 82/84] Update dependencyvit.py --- timm/models/dependencyvit.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/timm/models/dependencyvit.py b/timm/models/dependencyvit.py index 4fa8b7ca73..6734466f9b 100644 --- a/timm/models/dependencyvit.py +++ b/timm/models/dependencyvit.py @@ -350,6 +350,13 @@ def dependencyvit_tiny_patch16_224(pretrained: bool = False, **kwargs) -> Depend model = _create_dependencyvit('dependencyvit_tiny_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs)) return model +@register_model +def dependencyvit_tiny_cpe5_patch16_224(pretrained: bool = False, **kwargs) -> DependencyViT: + model_args = dict(patch_size=16, embed_dim=192, depth=12, num_heads=12, cpe_depth=5) + model = _create_dependencyvit('dependencyvit_tiny_cpe5_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + @register_model def dependencyvit_small_patch16_224(pretrained: bool = False, **kwargs) -> DependencyViT: model_args = dict(patch_size=16, embed_dim=384, depth=12, num_heads=12) From 1aa5ddf6274cc73d1146d4a86095100a1f176a15 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Tue, 31 Dec 2024 13:34:54 -0800 Subject: [PATCH 83/84] Update dependencyvit.py --- timm/models/dependencyvit.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/timm/models/dependencyvit.py b/timm/models/dependencyvit.py index 6734466f9b..4f297fde52 100644 --- a/timm/models/dependencyvit.py +++ b/timm/models/dependencyvit.py @@ -231,6 +231,7 @@ def __init__( prune_layers: Optional[Union[List[int], Tuple[int]]] = None, prune_ratio: Optional[float] = None, cpe_depth: int = 1, + pos_embed: str = 'none', *args, **kwargs ) -> None: @@ -349,6 +350,12 @@ def dependencyvit_tiny_patch16_224(pretrained: bool = False, **kwargs) -> Depend model_args = dict(patch_size=16, embed_dim=192, depth=12, num_heads=12) model = _create_dependencyvit('dependencyvit_tiny_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs)) return model + +@register_model +def dependencyvit_tiny_cpe1_lpe_patch16_224(pretrained: bool = False, **kwargs) -> DependencyViT: + model_args = dict(patch_size=16, embed_dim=192, depth=12, num_heads=12, pos_embed='learn') + model = _create_dependencyvit('dependencyvit_tiny_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs)) + return model @register_model def dependencyvit_tiny_cpe5_patch16_224(pretrained: bool = False, **kwargs) -> DependencyViT: From f63d2e8071666d631909dc5bda553b13daf9dbcc Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Tue, 31 Dec 2024 13:36:39 -0800 Subject: [PATCH 84/84] Update dependencyvit.py --- timm/models/dependencyvit.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timm/models/dependencyvit.py b/timm/models/dependencyvit.py index 4f297fde52..059b78e1d2 100644 --- a/timm/models/dependencyvit.py +++ b/timm/models/dependencyvit.py @@ -244,7 +244,7 @@ def __init__( qkv_bias=False, init_values=1e-6, fc_norm=False, - pos_embed='none', + pos_embed=pos_embed, ) if prune_layers is not None:
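
For quick reference when reading the ReversedAttention diffs above, the core weighting can be exercised in isolation. The sketch below is not part of the patch series; it only mirrors the shapes and operations commented in ReversedAttention.forward (softmax, transpose, then elementwise scaling by the head-selector p and the cumulative message controller m), using small made-up dimensions.

    import torch

    B, n_h, N, d = 2, 3, 5, 8                      # batch, heads, tokens, head dim
    q, k, v = (torch.randn(B, n_h, N, d) for _ in range(3))
    p = torch.rand(B, n_h, 1, N)                   # head selector (softmax over heads in the model)
    m = torch.rand(B, 1, 1, N)                     # cumulative message controller, values in (0, 1)

    attn = (q * d ** -0.5) @ k.transpose(-2, -1)   # [B, n_h, N, N]
    attn = attn.softmax(dim=-1).transpose(-2, -1)  # the transpose that prevents use of sdpa
    attn = attn * p * m                            # [B, n_h, N, N], as in the diffs above
    out = (attn @ v).transpose(1, 2).reshape(B, N, n_h * d)
    print(out.shape)                               # torch.Size([2, 5, 24])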
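
The lite variant's pruning schedule follows directly from TokenPruner.pct_kept_tokens: at the i-th pruning stage the kept fraction is (1 - i*r) / (1 - (i-1)*r), so the stages telescope to roughly (1 - i*r) of the original token count. A small sketch of the resulting token counts, assuming the 14x14 = 196 patch tokens of a 224x224 input at patch_size=16 and the prune_ratio=0.16, prune_layers=[2, 5, 8, 11] configuration registered above:

    import math

    def pruned_token_counts(num_tokens, prune_ratio, num_stages):
        # mirrors TokenPruner: floor(pct_kept_tokens * N) tokens survive each stage
        counts = [num_tokens]
        for i in range(1, num_stages + 1):
            pct_kept = (1 - i * prune_ratio) / (1 - (i - 1) * prune_ratio)
            counts.append(math.floor(pct_kept * counts[-1]))
        return counts

    print(pruned_token_counts(196, 0.16, 4))  # [196, 164, 132, 100, 69]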
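
And a minimal usage sketch for the dependency-mask helpers added in this series. It assumes a timm build that includes this branch and that the forward path behaves as written in the final state of the file; none of these model names or methods exist in upstream timm yet.

    import torch
    import timm

    model = timm.create_model('dependencyvit_tiny_patch16_224', pretrained=False).eval()
    model.track_dependency_mask(True)             # each block buffers attn.detach().sum(1)

    with torch.no_grad():
        _ = model(torch.randn(1, 3, 224, 224))

    masks = model.get_dependency_mask()           # list of [B, N, N] tensors, one per block
    print(len(masks), masks[0].shape)
    model.track_dependency_mask(False)            # toggling off clears the buffered masks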