Skip to content

Commit 17eee05

Browse files
committed
Update dependencyvit.py
1 parent a7d3c3b commit 17eee05

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

timm/models/dependencyvit.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -115,11 +115,11 @@ def forward(self, x: torch.Tensor, m: torch.Tensor) -> Tuple[torch.Tensor, torch
115115

116116
#prune_mask = attn.detach().sum(1).sum(-1)
117117
#prune_mask = attn.detach().sum(1).abs().sum(-1)
118-
#prune_mask = attn.detach().abs().sum((1, -1))
118+
prune_mask = attn.detach().abs().sum((1, -1))
119119
#prune_mask = attn.sum(1).sum(-1)
120120
#prune_mask = attn.sum(1).abs().sum(-1)
121121
#prune_mask = attn.abs().sum((1, -1))
122-
prune_mask = m.reshape(B, N)
122+
#prune_mask = m.reshape(B, N)
123123

124124
x = self.proj(x)
125125
x = self.proj_drop(x)

0 commit comments

Comments
 (0)