This repository was archived by the owner on Aug 6, 2025. It is now read-only.
eval_knn.py: 5 changes (3 additions, 2 deletions)
@@ -209,8 +209,9 @@ def __getitem__(self, idx):
     parser.add_argument('--num_workers', default=10, type=int, help='Number of data loading workers per GPU.')
     parser.add_argument("--dist_url", default="env://", type=str, help="""url used to set up
         distributed training; see https://pytorch.org/docs/stable/distributed.html""")
-    parser.add_argument("--local_rank", default=0, type=int, help="Please ignore and do not set this argument.")
+    parser.add_argument("--local-rank", default=0, type=int, help="Please ignore and do not set this argument.")
     parser.add_argument('--data_path', default='/path/to/imagenet/', type=str)
+    parser.add_argument('--num_classes', default=1000, type=int, help='Number of classes in dataset.')
     args = parser.parse_args()

     utils.init_distributed_mode(args)
@@ -237,6 +238,6 @@ def __getitem__(self, idx):
print("Features are ready!\nStart the k-NN classification.")
for k in args.nb_knn:
top1, top5 = knn_classifier(train_features, train_labels,
test_features, test_labels, k, args.temperature)
test_features, test_labels, k, args.temperature, num_classes=args.num_classes)
print(f"{k}-NN classifier result: Top1: {top1}, Top5: {top5}")
dist.barrier()
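The new --num_classes flag lets the k-NN evaluation run on datasets other than ImageNet-1k, and since it defaults to 1000 the previous behavior is unchanged unless the flag is set. Below is a minimal sketch of how a temperature-weighted k-NN vote uses that value to size its one-hot class scatter; it illustrates the mechanism and is not the repository's exact knn_classifier implementation (the knn_vote name is hypothetical).

# Minimal sketch of a temperature-weighted k-NN vote that needs num_classes
# to size its one-hot scatter; the repository's knn_classifier is assumed to
# use the argument the same way, but the exact code may differ.
import torch

def knn_vote(train_features, train_labels, test_features, k, T, num_classes):
    # features are expected L2-normalized: (num_train, D) and (num_test, D)
    similarity = test_features @ train_features.t()        # (num_test, num_train)
    sim_k, idx_k = similarity.topk(k, dim=1)               # top-k neighbors per query
    neighbor_labels = train_labels[idx_k]                  # (num_test, k)
    one_hot = torch.zeros(test_features.size(0), k, num_classes)
    one_hot.scatter_(2, neighbor_labels.unsqueeze(-1), 1)  # one-hot over classes
    weights = (sim_k / T).exp().unsqueeze(-1)              # temperature-scaled weights
    scores = (one_hot * weights).sum(1)                    # (num_test, num_classes)
    return scores.argmax(dim=1)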
main_dino.py: 2 changes (1 addition, 1 deletion)
@@ -125,7 +125,7 @@ def get_args_parser():
     parser.add_argument('--num_workers', default=10, type=int, help='Number of data loading workers per GPU.')
     parser.add_argument("--dist_url", default="env://", type=str, help="""url used to set up
         distributed training; see https://pytorch.org/docs/stable/distributed.html""")
-    parser.add_argument("--local_rank", default=0, type=int, help="Please ignore and do not set this argument.")
+    parser.add_argument("--local-rank", default=0, type=int, help="Please ignore and do not set this argument.")
     return parser


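The only change here is the hyphenated --local-rank spelling, presumably to accept the form passed by newer distributed launchers. Because argparse converts hyphens in long option names to underscores when building the attribute name, downstream code that reads args.local_rank keeps working; a quick self-contained check:

# argparse maps "--local-rank" to the attribute "local_rank", so the rename
# does not affect code that reads args.local_rank.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--local-rank", default=0, type=int)
args = parser.parse_args(["--local-rank", "3"])
print(args.local_rank)  # -> 3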
vision_transformer.py: 38 changes (36 additions, 2 deletions)
@@ -20,9 +20,25 @@

 import torch
 import torch.nn as nn

+import logging
+import os
+import warnings
 from utils import trunc_normal_

+XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
+try:
+    if XFORMERS_ENABLED:
+        from xformers.ops import memory_efficient_attention, unbind
+
+        XFORMERS_AVAILABLE = True
+        warnings.warn("xFormers is available (Attention)")
+    else:
+        warnings.warn("xFormers is disabled (Attention)")
+        raise ImportError
+except ImportError:
+    XFORMERS_AVAILABLE = False
+    warnings.warn("xFormers is not available (Attention)")


 def drop_path(x, drop_prob: float = 0., training: bool = False):
     if drop_prob == 0. or not training:
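The guard above decides once, at import time, whether the xFormers path is used; setting the XFORMERS_DISABLED environment variable before the module is imported forces the plain PyTorch attention instead. A small sketch of that toggle (the vits alias is just an example name):

# Force the standard Attention math instead of the xFormers kernel (assumes this
# runs before vision_transformer is imported, since the check happens at import time).
import os

os.environ["XFORMERS_DISABLED"] = "1"  # any value disables the xFormers branch

import vision_transformer as vits  # MemEffAttention now falls back to super().forward(x)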
@@ -90,14 +106,32 @@ def forward(self, x):
         x = self.proj(x)
         x = self.proj_drop(x)
         return x, attn

+class MemEffAttention(Attention):
+    def forward(self, x, attn_bias=None):
+        if not XFORMERS_AVAILABLE:
+            if attn_bias is not None:
+                raise AssertionError("xFormers is required for using nested tensors")
+            return super().forward(x)
+
+        B, N, C = x.shape
+        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
+
+        q, k, v = unbind(qkv, 2)
+
+        attn = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
+        x = attn.reshape([B, N, C])
+
+        x = self.proj(x)
+        x = self.proj_drop(x)
+        return x, attn

 class Block(nn.Module):
     def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                  drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
         super().__init__()
         self.norm1 = norm_layer(dim)
-        self.attn = Attention(
+        self.attn = MemEffAttention(  # Attention -> MemEffAttention
             dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
         self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
         self.norm2 = norm_layer(dim)
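MemEffAttention passes q, k, v to xformers.ops.memory_efficient_attention laid out as (B, N, heads, head_dim), and note that its second return value is the attention output rather than an attention map, unlike the base Attention class. The sketch below, which assumes a CUDA build of xFormers is installed, checks that the memory-efficient kernel matches an explicit softmax attention reference up to numerical tolerance; the shapes and tolerance are illustrative.

# Sanity check (assumes xFormers with CUDA support is installed): the
# memory-efficient kernel should agree with explicit softmax attention.
import torch
from xformers.ops import memory_efficient_attention

B, N, H, D = 2, 197, 6, 64  # e.g. ViT-S/16 at 224x224: 196 patch tokens + [CLS]
q, k, v = (torch.randn(B, N, H, D, device="cuda", dtype=torch.float16) for _ in range(3))

out_xf = memory_efficient_attention(q, k, v)  # (B, N, H, D)

# Reference: softmax(QK^T / sqrt(D)) V with heads moved next to the batch dim.
qh, kh, vh = (t.permute(0, 2, 1, 3).float() for t in (q, k, v))  # (B, H, N, D)
attn = (qh @ kh.transpose(-2, -1)) * D ** -0.5
out_ref = (attn.softmax(dim=-1) @ vh).permute(0, 2, 1, 3)

print(torch.allclose(out_xf.float(), out_ref, atol=1e-2))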