Skip to content

Commit ed2a619

Browse files
bug(medcat): CU-869cunfx7 Fix supervised training order of operations issue (#408)
* CU-869cunfx7: Add tests for exception catching * CU-869cunfx7: Update test for better flow * CU-869cunfx7: Remove unwarranted deprecation warnings from tokenizers * CU-869cunfx7: Add deprecation warning to pipeline where it belongs * CU-869cunfx7: Remove unused import * CU-869cunfx7: Fix issue with exception raising introduced in release 2.7 / PR 374 --------- Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
1 parent a76f44c commit ed2a619

6 files changed

Lines changed: 85 additions & 54 deletions

File tree

medcat-v2/medcat/pipeline/pipeline.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -44,13 +44,6 @@ def create_entity(self, doc: MutableDocument,
4444
doc, token_start_index, token_end_index, label)
4545

4646
def entity_from_tokens(self, tokens: list[MutableToken]) -> MutableEntity:
47-
warnings.warn(
48-
"The `medcat.pipeline.pipeline.entity_from_tokens` method is"
49-
"depreacated and subject to removal in a future release. Please "
50-
"use `medcat.pipeline.pipeline.entity_from_tokens_in_doc` instead.",
51-
DeprecationWarning,
52-
stacklevel=2
53-
)
5447
return self.tokenizer.entity_from_tokens(tokens)
5548

5649
def entity_from_tokens_in_doc(
@@ -352,6 +345,14 @@ def entity_from_tokens(self, tokens: list[MutableToken]) -> MutableEntity:
352345
Returns:
353346
MutableEntity: The resulting entity.
354347
"""
348+
warnings.warn(
349+
"The `medcat.pipeline.pipeline.Pipeline.entity_from_tokens` method is"
350+
"depreacated is subject to removal in a future release. Please use "
351+
"`medcat.pipeline.pipeline.Pipeline.entity_from_tokens_in_doc` "
352+
"instead.",
353+
DeprecationWarning,
354+
stacklevel=2
355+
)
355356
return self._tokenizer.entity_from_tokens(tokens)
356357

357358
def entity_from_tokens_in_doc(self, tokens: list[MutableToken],

medcat-v2/medcat/tokenizing/regex_impl/tokenizer.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
from typing import cast, Optional, Iterator, overload, Union, Any, Type
33
from collections import defaultdict
44
from bisect import bisect_left, bisect_right
5-
import warnings
65

76
from medcat.tokenizing.tokens import (
87
BaseToken, BaseEntity, BaseDocument,
@@ -343,14 +342,6 @@ def create_entity(self, doc: MutableDocument,
343342
# return Entity(span)
344343

345344
def entity_from_tokens(self, tokens: list[MutableToken]) -> MutableEntity:
346-
warnings.warn(
347-
"The `medcat.tokenizing.tokenizers.Tokenizer.entity_from_tokens` method is"
348-
"depreacated and subject to removal in a future release. Please use "
349-
"`medcat.tokenizing.tokenizers.Tokenizer.entity_from_tokens_in_doc` "
350-
"instead.",
351-
DeprecationWarning,
352-
stacklevel=2
353-
)
354345
if not tokens:
355346
raise ValueError("Need at least one token for an entity")
356347
doc = cast(Token, tokens[0])._doc

medcat-v2/medcat/tokenizing/spacy_impl/tokenizers.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import os
44
import shutil
55
import logging
6-
import warnings
76

87
import spacy
98
from spacy.tokens import Span
@@ -78,14 +77,6 @@ def create_entity(self, doc: MutableDocument,
7877
return Entity(span)
7978

8079
def entity_from_tokens(self, tokens: list[MutableToken]) -> MutableEntity:
81-
warnings.warn(
82-
"The `medcat.tokenizing.tokenizers.Tokenizer.entity_from_tokens` method is"
83-
"depreacated and subject to removal in a future release. Please use "
84-
"`medcat.tokenizing.tokenizers.Tokenizer.entity_from_tokens_in_doc` "
85-
"instead.",
86-
DeprecationWarning,
87-
stacklevel=2
88-
)
8980
if not tokens:
9081
raise ValueError("Need at least one token for an entity")
9182
spacy_tokens = cast(list[Token], tokens)

medcat-v2/medcat/tokenizing/tokenizers.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,18 @@ def create_entity(self, doc: MutableDocument,
3434
pass
3535

3636
def entity_from_tokens(self, tokens: list[MutableToken]) -> MutableEntity:
37-
"""Deprecated: use entity_from_tokens_in_doc instead."""
37+
"""Get an entity from the list of tokens.
38+
39+
This will create a new instance instead of looking for existing entity.
40+
This method should be used only if/when there was no existing entity
41+
within the specified document for the given span of tokens.
42+
43+
Args:
44+
tokens (list[MutableToken]): List of tokens.
45+
46+
Returns:
47+
MutableEntity: The resulting entity.
48+
"""
3849
pass
3950

4051
def entity_from_tokens_in_doc(self, tokens: list[MutableToken],

medcat-v2/medcat/trainer.py

Lines changed: 41 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -405,19 +405,53 @@ def _train_supervised_for_project(self,
405405
devalue_others)
406406

407407
def _prepare_doc_with_anns(
408-
self, doc: MutableDocument,
408+
self, doc: MutableDocument, ann_doc: MedCATTrainerExportDocument,
409409
anns: list[MedCATTrainerExportAnnotation]) -> None:
410410
ents = []
411411
for ann in anns:
412412
tkns = doc.get_tokens(ann['start'], ann['end'])
413-
ents.append(self._pipeline.entity_from_tokens_in_doc(tkns, doc))
413+
try:
414+
ents.append(self._pipeline.entity_from_tokens_in_doc(tkns, doc))
415+
except ValueError as err:
416+
self._warn_on_error(
417+
err, doc.base.text,
418+
(ann['cui'], ann['value'], ann['start'], ann['end']),
419+
(None, ann_doc['id'], ann_doc['name']))
414420
# set NER ents
415421
doc.ner_ents.clear()
416422
doc.ner_ents.extend(ents)
417423
# duplicate for linked as well, but in a a separate list
418424
doc.linked_ents.clear()
419425
doc.linked_ents.extend(ents)
420426

427+
def _warn_on_error(self, ve: BaseException, cur_text: str,
428+
mut_context_start: tuple[str, str, int, int],
429+
mut_context_end: tuple[MutableEntity | None, str, str]):
430+
start, end = mut_context_start[2:]
431+
context_window = 20 # characters
432+
splitter_left, splitter_right = "<", ">"
433+
context_start = max(start - context_window, 0)
434+
context_end = min(end + context_window, len(cur_text) - 1)
435+
context = (cur_text[context_start: start] +
436+
splitter_left +
437+
cur_text[start: end] +
438+
splitter_right +
439+
cur_text[end: context_end])
440+
if context_start > 0:
441+
context = "[...]" + context
442+
if context_end < len(cur_text) - 1:
443+
context += "[...]"
444+
msg_template = (
445+
"Failed to identify '%s' (%s) ([%d:%d]) "
446+
"in '%s' %s within document %s | %s, "
447+
"skipping training for this example")
448+
msg_context = (
449+
*mut_context_start, context, *mut_context_end)
450+
if self.strict_train:
451+
raise ValueError(msg_template % msg_context) from ve
452+
else:
453+
logger.warning(msg_template, *msg_context, exc_info=ve)
454+
# 480+ project
421455
def _train_supervised_for_project2(self,
422456
docs: list[MedCATTrainerExportDocument],
423457
current_document: int,
@@ -433,7 +467,7 @@ def _train_supervised_for_project2(self,
433467
with temp_changed_config(self.config.components.linking,
434468
'train', False):
435469
mut_doc = self.caller(doc['text'])
436-
self._prepare_doc_with_anns(mut_doc, doc['annotations'])
470+
self._prepare_doc_with_anns(mut_doc, doc, doc['annotations'])
437471

438472
# Compatibility with old output where annotations are a list
439473
for ann, mut_entity in zip(doc['annotations'], mut_doc.linked_ents):
@@ -461,31 +495,10 @@ def _train_supervised_for_project2(self,
461495
mut_entity=mut_entity, negative=deleted,
462496
devalue_others=devalue_others)
463497
except (ValueError, KeyError) as ve:
464-
context_window = 20 # characters
465-
splitter_left, splitter_right = "<", ">"
466-
cur_text = doc['text']
467-
context_start = max(start - context_window, 0)
468-
context_end = min(end + context_window, len(cur_text) - 1)
469-
context = (cur_text[context_start: start] +
470-
splitter_left +
471-
cur_text[start: end] +
472-
splitter_right +
473-
cur_text[end: context_end])
474-
if context_start > 0:
475-
context = "[...]" + context
476-
if context_end < len(cur_text) - 1:
477-
context += "[...]"
478-
msg_template = (
479-
"Failed to identify '%s' (%s) ([%d:%d]) "
480-
"in '%s' %s within document %s | %s, "
481-
"skipping training for this example")
482-
msg_context = (
483-
cui, ann['value'], ann['start'], ann['end'],
484-
context, mut_entity, doc['id'], doc['name'])
485-
if self.strict_train:
486-
raise ValueError(msg_template % msg_context) from ve
487-
else:
488-
logger.warning(msg_template, *msg_context, exc_info=ve)
498+
self._warn_on_error(
499+
ve, doc['text'],
500+
(cui, ann['value'], ann['start'], ann['end']),
501+
(mut_entity, doc['id'], doc['name']))
489502
if train_from_false_positives:
490503
fps: list[MutableEntity] = get_false_positives(doc, mut_doc)
491504

medcat-v2/tests/test_trainer.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,30 @@ def test_training_happens_on_linked_ents_on_doc(self):
252252
doc, ent = args.kwargs['mut_doc'], args.kwargs['mut_entity']
253253
self.assertIn(ent, doc.linked_ents)
254254

255+
def test_empty_token_annotation_is_skipped_when_not_strict(self):
256+
self.trainer.strict_train = False
257+
with unittest.mock.patch.object(
258+
self.trainer._pipeline, "entity_from_tokens_in_doc",
259+
side_effect=ValueError("No tokens found in span")), \
260+
unittest.mock.patch.object(
261+
FakeMutDoc, "get_tokens", return_value=[]), \
262+
unittest.mock.patch.object(self.trainer, "add_and_train_concept"):
263+
try:
264+
self.train(self.TRAIN_DATA)
265+
except ValueError as err:
266+
self.fail(f"Unexpected ValueError for empty-token annotation: {err}")
267+
268+
def test_empty_token_annotation_raises_when_strict(self):
269+
self.trainer.strict_train = True
270+
with unittest.mock.patch.object(
271+
self.trainer._pipeline, "entity_from_tokens_in_doc",
272+
side_effect=ValueError("No tokens found in span")), \
273+
unittest.mock.patch.object(
274+
FakeMutDoc, "get_tokens", return_value=[]), \
275+
unittest.mock.patch.object(self.trainer, "add_and_train_concept"):
276+
with self.assertRaises(ValueError):
277+
self.train(self.TRAIN_DATA)
278+
255279

256280
class FromSratchBase(TrainedModelTests):
257281
RNG_SEED = 42

0 commit comments

Comments
 (0)