@@ -249,35 +249,27 @@ def _build_mention_mask_from_char_spans(
         mention_char_spans: list[tuple[int, int]],
         device: torch.device,
     ) -> Tensor:
-        """
-        Convert character-level mention spans into a token-level mask.
-
-        Args:
-            batch_dict: tokenizer output with 'offset_mapping'
-            mention_char_spans: list of (start_char, end_char) per example
-            device: torch device
-
-        Returns:
-            mask: [batch_size, seq_len] float Tensor, 1.0 for mention tokens,
-                0.0 otherwise
-        """
-        offset_mapping = batch_dict["offset_mapping"]  # [B, max_token_length, 2]
-        batch_size, seq_len, _ = offset_mapping.shape
-        mask = torch.zeros((batch_size, seq_len), dtype=torch.float32, device=device)
-
-        for i, (mention_start, mention_end) in enumerate(mention_char_spans):
-            # For each token in the sequence
-            for j in range(seq_len):
-                token_start, token_end = offset_mapping[i, j].tolist()
-                # Skip padding tokens
-                if token_end == 0 and token_start == 0:
-                    continue
-                # Check if token overlaps mention span
-                if token_end > mention_start and token_start < mention_end:
-                    mask[i, j] = 1.0
-
+        """Convert character-level mention spans into a token-level mask."""
+        offset_mapping = batch_dict["offset_mapping"]  # [B, seq_len, 2]
+        token_starts = offset_mapping[:, :, 0]  # [B, seq_len]
+        token_ends = offset_mapping[:, :, 1]  # [B, seq_len]
+
+        spans_tensor = torch.tensor(
+            mention_char_spans, dtype=torch.long, device=device
+        )  # [B, 2]
+        mention_starts = spans_tensor[:, 0].unsqueeze(1)  # [B, 1]
+        mention_ends = spans_tensor[:, 1].unsqueeze(1)  # [B, 1]
+
+        # Tokens with offset (0, 0) are special tokens (CLS, SEP) or padding.
+        is_special = (token_starts == 0) & (token_ends == 0)
+        overlaps = (token_ends > mention_starts) & (token_starts < mention_ends)
+        mask = (overlaps & ~is_special).float()
         return mask
 
+    def _full_text_spans(self, texts: list[str]) -> list[tuple[int, int]]:
+        """Build mention spans that cover each full text entry."""
+        return [(0, len(text)) for text in texts]
+
     def embed(
         self,
         to_embed: list[str],
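For reference, here is a minimal, self-contained sketch of the broadcasting trick the new mask builder relies on, using toy offsets rather than real tokenizer output:

```python
import torch

# Toy offset_mapping for one example: [CLS] "hello" "world" [SEP] + padding.
# Shape [B=1, seq_len=5, 2]; special/pad tokens carry the (0, 0) sentinel.
offset_mapping = torch.tensor([[[0, 0], [0, 5], [6, 11], [0, 0], [0, 0]]])
token_starts, token_ends = offset_mapping[..., 0], offset_mapping[..., 1]

# One mention span per example, here covering "world" (chars 6-11).
spans = torch.tensor([[6, 11]])
mention_starts, mention_ends = spans[:, :1], spans[:, 1:]

# [B, seq_len] comparisons against [B, 1] spans broadcast per token.
is_special = (token_starts == 0) & (token_ends == 0)
overlaps = (token_ends > mention_starts) & (token_starts < mention_ends)
mask = (overlaps & ~is_special).float()
print(mask)  # tensor([[0., 0., 1., 0., 0.]])
```

This replaces the removed O(B * seq_len) Python loop with a handful of tensor ops, which is the point of the rewrite.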
@@ -301,7 +293,7 @@ def embed(
         ).to(target_device)
 
         mention_mask = None
-        if mention_spans is not None:
+        if self.cnf_l.use_mention_attention and mention_spans is not None:
             mention_mask = self._build_mention_mask_from_char_spans(
                 batch_dict,
                 mention_spans,
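A usage sketch under the new gating; the `linker` instance and example text are hypothetical, while the `mention_spans` and `device` keywords follow the patched signature:

```python
# Hypothetical usage: embed a sentence, attending only to the mention tokens.
text = "Patient presents with myocardial infarction today."
mention = "myocardial infarction"
start = text.find(mention)
spans = [(start, start + len(mention))]

# With cnf_l.use_mention_attention = True, the span becomes a token mask;
# with the flag off, mention_spans is ignored and pooling is unchanged.
emb = linker.embed([text], mention_spans=spans, device=linker.device)
```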
@@ -322,7 +314,11 @@ def embed_cuis(self) -> None:
         """
         self._refresh_cdb_keys()  # ensure _cui_keys is up to date before embedding
 
-        cui_names = [self.cdb.get_name(cui) for cui in self._cui_keys]
+        # cui_names = [self.cdb.get_name(cui) for cui in self._cui_keys]
+        cui_names = [
+            max(self.cdb.cui2info[cui].get("names"), key=len)
+            for cui in self._cui_keys
+        ]
         total_batches = math.ceil(len(cui_names) / self.cnf_l.embedding_batch_size)
         all_embeddings = []
         for names in tqdm(
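The switch from `cdb.get_name` to the longest entry in `cui2info[cui]["names"]` embeds the most descriptive synonym for each concept. A toy illustration, with an invented CUI and name set:

```python
# Invented example data; real entries come from self.cdb.cui2info.
cui2info = {"C0027051": {"names": {"mi", "heart attack", "myocardial infarction"}}}

longest = max(cui2info["C0027051"]["names"], key=len)
print(longest)  # 'myocardial infarction'
```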
@@ -332,7 +328,14 @@ def embed_cuis(self) -> None:
         ):
             with torch.no_grad():
                 names_to_embed = [name.replace(self.separator, " ") for name in names]
-                embeddings = self.embed(names_to_embed, device=self.device)
+                mention_spans = None
+                if self.cnf_l.use_mention_attention:
+                    mention_spans = self._full_text_spans(names_to_embed)
+                embeddings = self.embed(
+                    names_to_embed,
+                    mention_spans=mention_spans,
+                    device=self.device,
+                )
                 all_embeddings.append(embeddings.cpu())
 
         all_embeddings_matrix = torch.cat(all_embeddings, dim=0)
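A quick sanity check that full-text spans select every content token when paired with a fast tokenizer's `offset_mapping`; this assumes the `transformers` library, and the model name is only an example:

```python
import torch
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("bert-base-uncased")  # any *fast* tokenizer
texts = ["myocardial infarction"]
batch = tok(texts, return_offsets_mapping=True, padding=True, return_tensors="pt")

spans = torch.tensor([(0, len(t)) for t in texts])  # _full_text_spans equivalent
starts, ends = batch["offset_mapping"][..., 0], batch["offset_mapping"][..., 1]
is_special = (starts == 0) & (ends == 0)
mask = ((ends > spans[:, :1]) & (starts < spans[:, 1:]) & ~is_special).float()

# Every non-special token should be selected.
assert torch.equal(mask.bool(), ~is_special)
```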
@@ -358,7 +361,14 @@ def embed_names(self) -> None:
             names_to_embed = [
                 name.replace(self.separator, " ") for name in batch_names
             ]
-            embeddings = self.embed(names_to_embed, device=self.device)
+            mention_spans = None
+            if self.cnf_l.use_mention_attention:
+                mention_spans = self._full_text_spans(names_to_embed)
+            embeddings = self.embed(
+                names_to_embed,
+                mention_spans=mention_spans,
+                device=self.device,
+            )
             all_embeddings.append(embeddings.cpu())
 
         all_embeddings_matrix = torch.cat(all_embeddings, dim=0)
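The patch only builds the mask; how it is consumed is not shown in this hunk. One plausible downstream use (an assumption, not code from the commit) is masked mean pooling over the mention tokens:

```python
import torch

def mention_pool(hidden: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Assumed downstream use: average hidden states over mention tokens.

    hidden: [B, seq_len, dim] encoder outputs; mask: [B, seq_len] from
    _build_mention_mask_from_char_spans. Clamping the denominator keeps
    the division safe if a span matched no tokens.
    """
    weights = mask.unsqueeze(-1)                  # [B, seq_len, 1]
    summed = (hidden * weights).sum(dim=1)        # [B, dim]
    counts = weights.sum(dim=1).clamp(min=1e-9)   # [B, 1]
    return summed / counts
```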