@@ -252,17 +252,17 @@ def _build_mention_mask_from_char_spans(
252252 """Convert character-level mention spans into a token-level mask."""
253253 offset_mapping = batch_dict ["offset_mapping" ] # [B, seq_len, 2]
254254 token_starts = offset_mapping [:, :, 0 ] # [B, seq_len]
255- token_ends = offset_mapping [:, :, 1 ] # [B, seq_len]
255+ token_ends = offset_mapping [:, :, 1 ] # [B, seq_len]
256256
257257 spans_tensor = torch .tensor (
258258 mention_char_spans , dtype = torch .long , device = device
259259 ) # [B, 2]
260260 mention_starts = spans_tensor [:, 0 ].unsqueeze (1 ) # [B, 1]
261- mention_ends = spans_tensor [:, 1 ].unsqueeze (1 ) # [B, 1]
261+ mention_ends = spans_tensor [:, 1 ].unsqueeze (1 ) # [B, 1]
262262
263263 # Tokens with offset (0, 0) are special tokens (CLS, SEP) or padding.
264264 is_special = (token_starts == 0 ) & (token_ends == 0 )
265- overlaps = (token_ends > mention_starts ) & (token_starts < mention_ends )
265+ overlaps = (token_ends > mention_starts ) & (token_starts < mention_ends )
266266 mask = (overlaps & ~ is_special ).float ()
267267 return mask
268268
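Since the hunk only shows the function body, here is a minimal standalone sketch of the same overlap test, assuming the `offset_mapping` comes from a HuggingFace fast tokenizer called with `return_offsets_mapping=True`. The model name, texts, and character spans below are illustrative, not taken from this repository.

```python
import torch
from transformers import AutoTokenizer

# Illustrative inputs (not from this repo): one mention span per text,
# given as half-open character intervals [start, end).
texts = ["Barack Obama visited Paris.", "The Eiffel Tower is in Paris."]
spans = [(0, 12), (4, 16)]  # "Barack Obama", "Eiffel Tower"

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
batch = tokenizer(texts, padding=True, return_offsets_mapping=True,
                  return_tensors="pt")

offsets = batch["offset_mapping"]           # [B, seq_len, 2]
token_starts = offsets[:, :, 0]             # [B, seq_len]
token_ends = offsets[:, :, 1]               # [B, seq_len]

spans_tensor = torch.tensor(spans, dtype=torch.long)  # [B, 2]
mention_starts = spans_tensor[:, 0].unsqueeze(1)      # [B, 1]
mention_ends = spans_tensor[:, 1].unsqueeze(1)        # [B, 1]

# A token belongs to the mention iff the half-open intervals intersect;
# (0, 0) offsets mark special tokens (CLS/SEP) and padding.
is_special = (token_starts == 0) & (token_ends == 0)
overlaps = (token_ends > mention_starts) & (token_starts < mention_ends)
mask = (overlaps & ~is_special).float()     # [B, seq_len]
```

The strict inequalities implement half-open interval intersection, so a token whose offsets merely touch the mention boundary (e.g. a token ending exactly where the mention starts) is correctly excluded, and the `(0, 0)` check prevents CLS from matching mentions that start at character 0.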