Skip to content

Commit a76f44c

Browse files
authored
perf(medcat-v2): optimize hot path allocations and lookups (#401)
* perf: share PerDocumentTokenCache across entities during training

  Previously a new PerDocumentTokenCache was created per entity inside the
  training loop, discarding cached token validity checks. For a document with
  N entities and M tokens this caused N×M validity checks instead of M. Now
  the cache is created once per document and shared.

* perf: use dict lookup for CUI index in TwoStepLinker disambiguation

  Replace O(n) list.index() call per CUI candidate with O(1) dict lookup.
  The cui_to_idx dict is built once before the loop.

* perf: use bisect for O(log n) token lookup in get_tokens

  Both regex and spacy Document.get_tokens() previously scanned all tokens
  linearly to find those within a character range. With bisect on the
  pre-built char_indices array, lookup is O(log n) instead of O(n). For a
  1000-token document with 50 entities this reduces comparisons from
  ~50,000 to ~500.

* perf: use mp.get_context instead of global set_start_method

  Replace mp.set_start_method("spawn", force=True) which mutates
  process-wide state on every batch run with mp.get_context("spawn") passed
  to ProcessPoolExecutor. This avoids silently overriding the start method
  for other libraries (e.g. PyTorch DataLoaders).
1 parent 8a630ba commit a76f44c

5 files changed

Lines changed: 34 additions & 21 deletions

File tree

medcat-v2/medcat/cat.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -482,14 +482,16 @@ def _multiprocess(
482482
saver: Optional[BatchAnnotationSaver],
483483
) -> Iterator[tuple[str, Union[dict, Entities, OnlyCUIEntities]]]:
484484
external_processes = n_process - 1
485+
mp_context = None
485486
if self.FORCE_SPAWN_MP:
486487
import multiprocessing as mp
487488
logger.info(
488-
"Forcing multiprocessing start method to 'spawn' "
489+
"Using 'spawn' multiprocessing context "
489490
"due to known compatibility issues with 'fork' and "
490491
"libraries using threads or native extensions.")
491-
mp.set_start_method("spawn", force=True)
492-
with ProcessPoolExecutor(max_workers=external_processes) as executor:
492+
mp_context = mp.get_context("spawn")
493+
with ProcessPoolExecutor(max_workers=external_processes,
494+
mp_context=mp_context) as executor:
493495
while True:
494496
try:
495497
yield from self._mp_one_batch_per_process(

medcat-v2/medcat/components/linking/context_based_linker.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -110,10 +110,11 @@ def _process_entity_train(self, doc: MutableDocument,
110110
def _train_on_doc(self, doc: MutableDocument,
111111
ner_ents: list[MutableEntity]
112112
) -> Iterator[MutableEntity]:
113-
# Run training
113+
# Run training — share cache across all entities in the document
114+
per_doc_valid_token_cache = PerDocumentTokenCache()
114115
for entity in ner_ents:
115116
yield from self._process_entity_train(
116-
doc, entity, PerDocumentTokenCache())
117+
doc, entity, per_doc_valid_token_cache)
117118

118119
def _process_entity_nt_w_name(
119120
self, doc: MutableDocument,

medcat-v2/medcat/components/linking/two_step_context_based_linker.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -132,10 +132,11 @@ def _do_training(self,
132132
per_doc_valid_token_cache=per_doc_valid_token_cache)
133133

134134
def _train_for_tuis(self, doc: MutableDocument) -> None:
135-
# Run training
135+
# Run training — share cache across all entities in the document
136+
per_doc_valid_token_cache = PerDocumentTokenCache()
136137
for entity in doc.ner_ents:
137138
self._process_entity_train_tuis(
138-
doc, entity, PerDocumentTokenCache())
139+
doc, entity, per_doc_valid_token_cache)
139140

140141
def _check_similarity(self, cui: str, context_similarity: float) -> bool:
141142
th_type = self.config.components.linking.similarity_threshold_type
@@ -284,10 +285,11 @@ def _preprocess_disamb(self, ent: MutableEntity, name: str,
284285
return
285286
per_cui_type_sims = pew[ent]
286287
cnf_2step = self.two_step_config
288+
cui_to_idx = {c: i for i, c in enumerate(cuis)}
287289
for cui, type_sim in per_cui_type_sims.items():
288-
if cui not in cuis:
290+
if cui not in cui_to_idx:
289291
continue
290-
cui_index = cuis.index(cui)
292+
cui_index = cui_to_idx[cui]
291293
cui_sim = similarities[cui_index]
292294
ts_coef = sigmoid(
293295
cnf_2step.alpha_sharpness * (

medcat-v2/medcat/tokenizing/regex_impl/tokenizer.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import re
22
from typing import cast, Optional, Iterator, overload, Union, Any, Type
33
from collections import defaultdict
4+
from bisect import bisect_left, bisect_right
45
import warnings
56

67
from medcat.tokenizing.tokens import (
@@ -224,6 +225,7 @@ def __init__(self, text: str, tokens: Optional[list[MutableToken]] = None
224225
) -> None:
225226
self.text = text
226227
self._tokens = tokens or []
228+
self._char_indices: list[int] = []
227229
self.ner_ents: list[MutableEntity] = []
228230
self.linked_ents: list[MutableEntity] = []
229231

@@ -256,12 +258,12 @@ def __len__(self) -> int:
256258

257259
def get_tokens(self, start_index: int, end_index: int
258260
) -> list[MutableToken]:
259-
tkns = []
260-
for tkn in self:
261-
if (tkn.base.char_index >= start_index and
262-
tkn.base.char_index <= end_index):
263-
tkns.append(tkn)
264-
return tkns
261+
if self._char_indices:
262+
lo = bisect_left(self._char_indices, start_index)
263+
hi = bisect_right(self._char_indices, end_index)
264+
return self._tokens[lo:hi]
265+
return [tkn for tkn in self
266+
if start_index <= tkn.base.char_index <= end_index]
265267

266268
def __iter__(self) -> Iterator[MutableToken]:
267269
yield from self._tokens
@@ -387,6 +389,7 @@ def __call__(self, text: str) -> MutableDocument:
387389
doc._tokens.append(Token(doc, token, token_w_ws,
388390
start_index, tkn_index,
389391
False, False))
392+
doc._char_indices.append(start_index)
390393
return doc
391394

392395
@classmethod

medcat-v2/medcat/tokenizing/spacy_impl/tokens.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from typing import Iterator, Union, Optional, overload, cast, Any
2+
from bisect import bisect_left, bisect_right
23
import logging
34

45
from spacy.tokens import Token as SpacyToken
@@ -196,6 +197,7 @@ class Document:
196197

197198
def __init__(self, delegate: SpacyDoc) -> None:
198199
self._delegate = delegate
200+
self._char_indices: Optional[list[int]] = None
199201
self.ner_ents: list[MutableEntity] = []
200202
self.linked_ents: list[MutableEntity] = []
201203

@@ -225,14 +227,17 @@ def __getitem__(self, index: Union[int, slice]
225227
def __len__(self) -> int:
226228
return len(self._delegate)
227229

230+
def _ensure_char_indices(self) -> list[int]:
231+
if self._char_indices is None:
232+
self._char_indices = [tkn.idx for tkn in self._delegate]
233+
return self._char_indices
234+
228235
def get_tokens(self, start_index: int, end_index: int
229236
) -> list[MutableToken]:
230-
tkns = []
231-
for tkn in self:
232-
if (tkn.base.char_index >= start_index and
233-
tkn.base.char_index <= end_index):
234-
tkns.append(tkn)
235-
return tkns
237+
char_indices = self._ensure_char_indices()
238+
lo = bisect_left(char_indices, start_index)
239+
hi = bisect_right(char_indices, end_index)
240+
return [Token(self._delegate[i]) for i in range(lo, hi)]
236241

237242
def set_addon_data(self, path: str, val: Any) -> None:
238243
if not self._delegate.has_extension(path):

0 commit comments

Comments
 (0)