Commit 2ad75e8

Fix issue w/ MAP attention mask and no patch_valid
1 parent: d7d3538

1 file changed: +4 -1 lines changed

timm/models/vision_transformer_flex.py

Lines changed: 4 additions & 1 deletion
@@ -362,7 +362,7 @@ def create_attention_mask(
         symmetric: bool = True,
         q_len: Optional[int] = None,
         dtype: torch.dtype = torch.float32,
-) -> torch.Tensor:
+) -> Optional[torch.Tensor]:
     """Creates an attention mask from patch validity information.
 
     Supports two modes controlled by `symmetric`:
@@ -392,6 +392,9 @@ def create_attention_mask(
         Shape is [B, 1, seq_len, seq_len] if symmetric=True,
         or [B, 1, q_len, kv_len] if symmetric=False.
     """
+    if patch_valid is None:
+        return None
+
     patch_valid = patch_valid.bool()  # Ensure boolean type
     B, N = patch_valid.shape
     kv_len = N  # Initial key/value length is the number of patches
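
For context, below is a minimal runnable sketch of the behavior this commit guards. Only the `patch_valid is None` early return mirrors the diff above; everything else (the name `create_attention_mask_sketch`, the prefix-token handling, and the additive-mask construction) is a simplified assumption for illustration, not the actual timm implementation.

import torch
from typing import Optional


def create_attention_mask_sketch(
        patch_valid: Optional[torch.Tensor],
        num_prefix_tokens: int = 0,
        dtype: torch.dtype = torch.float32,
) -> Optional[torch.Tensor]:
    # The commit's fix: without patch-validity info there is nothing to
    # mask, so return None and let attention (e.g. the MAP pooling head)
    # run unmasked instead of failing on `None.bool()`.
    if patch_valid is None:
        return None

    patch_valid = patch_valid.bool()
    B, N = patch_valid.shape
    # Assumption: prefix tokens (class/register tokens) are always valid.
    if num_prefix_tokens > 0:
        prefix = patch_valid.new_ones(B, num_prefix_tokens)
        patch_valid = torch.cat([prefix, patch_valid], dim=1)
    seq_len = patch_valid.shape[1]
    # Simplified additive float mask: 0 where the key is valid, a large
    # negative value where it is not; shape [B, 1, seq_len, seq_len].
    mask = torch.zeros(B, 1, seq_len, seq_len, dtype=dtype)
    mask.masked_fill_(~patch_valid[:, None, None, :], torch.finfo(dtype).min)
    return mask


# With fixed-size inputs (no padding) a caller has no patch_valid and
# now simply gets None back rather than an AttributeError:
assert create_attention_mask_sketch(None) is None
m = create_attention_mask_sketch(torch.tensor([[True, True, False]]), num_prefix_tokens=1)
print(m.shape)  # torch.Size([1, 1, 4, 4])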
