Skip to content

Commit 3664576

Browse files
changed cui embedding method and fixed mention_mask generation
1 parent 7ca33eb commit 3664576

1 file changed

Lines changed: 41 additions & 31 deletions

File tree

medcat-plugins/embedding-linker/src/medcat_embedding_linker/transformer_context_model.py

Lines changed: 41 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -249,35 +249,27 @@ def _build_mention_mask_from_char_spans(
249249
mention_char_spans: list[tuple[int, int]],
250250
device: torch.device,
251251
) -> Tensor:
252-
"""
253-
Convert character-level mention spans into a token-level mask.
254-
255-
Args:
256-
batch_dict: tokenizer output with 'offset_mapping'
257-
mention_char_spans: list of (start_char, end_char) per example
258-
device: torch device
259-
260-
Returns:
261-
mask: [batch_size, seq_len] float Tensor, 1.0 for mention tokens,
262-
0.0 otherwise
263-
"""
264-
offset_mapping = batch_dict["offset_mapping"] # [B, max_token_length, 2]
265-
batch_size, seq_len, _ = offset_mapping.shape
266-
mask = torch.zeros((batch_size, seq_len), dtype=torch.float32, device=device)
267-
268-
for i, (mention_start, mention_end) in enumerate(mention_char_spans):
269-
# For each token in the sequence
270-
for j in range(seq_len):
271-
token_start, token_end = offset_mapping[i, j].tolist()
272-
# Skip padding tokens
273-
if token_end == 0 and token_start == 0:
274-
continue
275-
# Check if token overlaps mention span
276-
if token_end > mention_start and token_start < mention_end:
277-
mask[i, j] = 1.0
278-
252+
"""Convert character-level mention spans into a token-level mask."""
253+
offset_mapping = batch_dict["offset_mapping"] # [B, seq_len, 2]
254+
token_starts = offset_mapping[:, :, 0] # [B, seq_len]
255+
token_ends = offset_mapping[:, :, 1] # [B, seq_len]
256+
257+
spans_tensor = torch.tensor(
258+
mention_char_spans, dtype=torch.long, device=device
259+
) # [B, 2]
260+
mention_starts = spans_tensor[:, 0].unsqueeze(1) # [B, 1]
261+
mention_ends = spans_tensor[:, 1].unsqueeze(1) # [B, 1]
262+
263+
# Tokens with offset (0, 0) are special tokens (CLS, SEP) or padding.
264+
is_special = (token_starts == 0) & (token_ends == 0)
265+
overlaps = (token_ends > mention_starts) & (token_starts < mention_ends)
266+
mask = (overlaps & ~is_special).float()
279267
return mask
280268

269+
def _full_text_spans(self, texts: list[str]) -> list[tuple[int, int]]:
270+
"""Build mention spans that cover each full text entry."""
271+
return [(0, len(text)) for text in texts]
272+
281273
def embed(
282274
self,
283275
to_embed: list[str],
@@ -301,7 +293,7 @@ def embed(
301293
).to(target_device)
302294

303295
mention_mask = None
304-
if mention_spans is not None:
296+
if self.cnf_l.use_mention_attention and mention_spans is not None:
305297
mention_mask = self._build_mention_mask_from_char_spans(
306298
batch_dict,
307299
mention_spans,
@@ -322,7 +314,11 @@ def embed_cuis(self) -> None:
322314
"""
323315
self._refresh_cdb_keys() # ensure _cui_keys is up to date before embedding
324316

325-
cui_names = [self.cdb.get_name(cui) for cui in self._cui_keys]
317+
# cui_names = [self.cdb.get_name(cui) for cui in self._cui_keys]
318+
cui_names = [
319+
max(self.cdb.cui2info[cui].get("names"), key=len)
320+
for cui in self._cui_keys
321+
]
326322
total_batches = math.ceil(len(cui_names) / self.cnf_l.embedding_batch_size)
327323
all_embeddings = []
328324
for names in tqdm(
@@ -332,7 +328,14 @@ def embed_cuis(self) -> None:
332328
):
333329
with torch.no_grad():
334330
names_to_embed = [name.replace(self.separator, " ") for name in names]
335-
embeddings = self.embed(names_to_embed, device=self.device)
331+
mention_spans = None
332+
if self.cnf_l.use_mention_attention:
333+
mention_spans = self._full_text_spans(names_to_embed)
334+
embeddings = self.embed(
335+
names_to_embed,
336+
mention_spans=mention_spans,
337+
device=self.device,
338+
)
336339
all_embeddings.append(embeddings.cpu())
337340

338341
all_embeddings_matrix = torch.cat(all_embeddings, dim=0)
@@ -358,7 +361,14 @@ def embed_names(self) -> None:
358361
names_to_embed = [
359362
name.replace(self.separator, " ") for name in batch_names
360363
]
361-
embeddings = self.embed(names_to_embed, device=self.device)
364+
mention_spans = None
365+
if self.cnf_l.use_mention_attention:
366+
mention_spans = self._full_text_spans(names_to_embed)
367+
embeddings = self.embed(
368+
names_to_embed,
369+
mention_spans=mention_spans,
370+
device=self.device,
371+
)
362372
all_embeddings.append(embeddings.cpu())
363373

364374
all_embeddings_matrix = torch.cat(all_embeddings, dim=0)

0 commit comments

Comments (0)