Merged
32 changes: 15 additions & 17 deletions medcat-trainer/webapp/api/api/metrics.py
@@ -245,7 +245,7 @@ def rename_meta_anns(self, meta_anns2rename=dict(), meta_ann_values2rename=dict(
self.annotations = self._annotations()
return

def _eval_model(self, model: nn.Module, data: List, config: ConfigMetaCAT, tokenizer: TokenizerWrapperBase) -> Dict:
def _eval_model(self, model: nn.Module, data: List, config: ConfigMetaCAT) -> Dict:
device = torch.device(config.general.device) # Create a torch device
batch_size_eval = config.general.batch_size_eval
pad_id = config.model.padding_idx
@@ -267,7 +267,7 @@ def _eval_model(self, model: nn.Module, data: List, config: ConfigMetaCAT, token

with torch.no_grad():
for i in range(num_batches):
x, cpos, attention_mask, y = create_batch_piped_data(data,
x, cpos, _, y = create_batch_piped_data(data,
i*batch_size_eval,
(i+1)*batch_size_eval,
device=device,
@@ -283,36 +283,34 @@ def _eval_model(self, model: nn.Module, data: List, config: ConfigMetaCAT, token
return predictions

def _eval(self, metacat_model, mct_export):
# TODO: Should be moved into
g_config = metacat_model.config.general
t_config = metacat_model.config.train
t_config['test_size'] = 0
t_config['shuffle_data'] = False
t_config['prerequisites'] = {}
t_config['cui_filter'] = {}

# Prepare the data
assert metacat_model.tokenizer is not None
data = prepare_from_json(mct_export, g_config['cntx_left'], g_config['cntx_right'], metacat_model.tokenizer,
cui_filter=t_config['cui_filter'],
replace_center=g_config['replace_center'], prerequisites=t_config['prerequisites'],
lowercase=g_config['lowercase'])
data = prepare_from_json(
mct_export,
g_config.cntx_left,
g_config.cntx_right,
metacat_model.mc.tokenizer,
cui_filter={},
replace_center=g_config.replace_center,
prerequisites={},
lowercase=g_config.lowercase
)

# Check is the name there
category_name = g_config['category_name']
category_name = g_config.category_name
if category_name not in data:
warnings.warn(f"The meta_model {category_name} does not exist in this MedCATtrainer export.", UserWarning)
return {category_name: f"{category_name} does not exist"}

data = data[category_name]

# We already have everything, just get the data
category_value2id = g_config['category_value2id']
category_value2id = g_config.category_value2id
data, _, _ = encode_category_values(data, existing_category_value2id=category_value2id)
logger.info(_)
# Run evaluation
assert metacat_model.tokenizer is not None
result = self._eval_model(metacat_model.model, data, config=metacat_model.config, tokenizer=metacat_model.tokenizer)
result = self._eval_model(metacat_model.mc.model, data, config=metacat_model.mc.config)

return {'predictions': result, 'meta_values': _}

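Note on the change above: the v1 dict-style config access (g_config['cntx_left']) becomes v2 attribute access (g_config.cntx_left), and the tokenizer and model are now reached through metacat_model.mc. If both model generations ever need to be served from the same code path, a small accessor shim keeps the call sites uniform; this is an illustrative sketch, not part of the PR and not MedCAT API:

def cfg_get(config_section, key, default=None):
    """Read a value from either a dict-style (MedCAT v1) or an
    attribute-style (MedCAT v2) config section. Illustrative helper only."""
    if isinstance(config_section, dict):
        return config_section.get(key, default)
    return getattr(config_section, key, default)

# Both flavours resolve the same way at the call site:
# cntx_left = cfg_get(g_config, 'cntx_left')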
64 changes: 42 additions & 22 deletions medcat-trainer/webapp/api/api/models.py
@@ -45,52 +45,72 @@ class ModelPack(models.Model):
last_modified_by = models.ForeignKey(settings.AUTH_USER_MODEL, on_delete=models.CASCADE, default=None, null=True)

@transaction.atomic
def save(self, *args, **kwargs):
def save(self, *args, skip_load=False, **kwargs):
is_new = self._state.adding
if is_new:
super().save(*args, **kwargs)

if skip_load:
super().save(*args, **kwargs)
return

# Process the model pack
logger.info('Loading model pack: %s', self.model_pack)
model_pack_name = str(self.model_pack).replace(".zip", "")
model_pack_path = self.model_pack.path

try:
CAT.attempt_unpack(self.model_pack.path)
CAT.attempt_unpack(model_pack_path)
except BadZipFile as exc:
# potential for CRC-32 errors in Trainer process - ignore and still use
logger.warning(f'Possibly corrupt cdb.dat decompressing {self.model_pack}\nFull Exception: {exc}')
unpacked_model_pack_path = self.model_pack.path.replace('.zip', '')
# attempt to load cdb

unpacked_model_pack_path = model_pack_path.replace('.zip', '')
unpacked_file_name = self.model_pack.file.name.replace('.zip', '')

# attempt to load cdb - use absolute paths for file existence checks
try:
CAT.load_cdb(unpacked_model_pack_path)
# Check for v2 (directory) first
cdb_path_abs = os.path.join(unpacked_model_pack_path, 'cdb')
if os.path.exists(cdb_path_abs) and os.path.isdir(cdb_path_abs):
# v2: CDB is a directory
cdb_path_rel = os.path.join(unpacked_file_name, 'cdb')
else:
# v1: CDB is a file
cdb_path_abs = os.path.join(unpacked_model_pack_path, 'cdb.dat')
cdb_path_rel = os.path.join(unpacked_file_name, 'cdb.dat')
if not os.path.exists(cdb_path_abs):
raise FileNotFoundError(f'CDB not found in model pack: {unpacked_model_pack_path}')

# Validate CDB can be loaded
CDB.load(cdb_path_abs)
Collaborator:
This can be quite a heavy task. Perhaps there should be a validation method on some of these to validate that MedCAT knows how to load them? Though it wouldn't necessarily guarantee anything...

Though obviously not much has changed from before. So nothing major.

Member (Author):
Yeah, agree - this was purposely done here so it fails early, rather than assuming it just works and then failing when it tries to load on a project load. Ideally, yes, there would be some way to guarantee it loads without actually loading it... (a lightweight structural-check sketch follows this thread).
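As a middle ground between the full CDB.load() above and no validation at all, a cheap structural check could fail fast on obviously broken packs without deserialising anything. A minimal sketch, assuming only the v1/v2 layouts mirrored by the save() logic in this diff (a cdb/ directory for v2, a pickled cdb.dat for v1); this helper is illustrative and not MedCAT API:

import os

def looks_like_cdb(unpacked_model_pack_path: str) -> bool:
    """Cheap structural check: does the unpacked pack contain something
    shaped like a CDB, without actually loading it?"""
    v2_dir = os.path.join(unpacked_model_pack_path, 'cdb')
    if os.path.isdir(v2_dir) and os.listdir(v2_dir):
        return True  # v2: CDB is a non-empty directory

    v1_file = os.path.join(unpacked_model_pack_path, 'cdb.dat')
    if os.path.isfile(v1_file) and os.path.getsize(v1_file) > 0:
        # v1: cdb.dat is a (dill) pickle; protocol >= 2 streams start
        # with the PROTO opcode b'\x80'.
        with open(v1_file, 'rb') as f:
            return f.read(1) == b'\x80'
    return False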


# Create ConceptDB model with relative path (to MEDIA_ROOT)
concept_db = ConceptDB()
unpacked_file_name = self.model_pack.file.name.replace('.zip', '')
# cdb path for v2
cdb_path = os.path.join(unpacked_file_name, 'cdb')
if not os.path.exists(cdb_path):
# cdb path for v1
cdb_path = os.path.join(unpacked_file_name, 'cdb.dat')
concept_db.cdb_file.name = cdb_path
concept_db.cdb_file.name = cdb_path_rel
concept_db.name = f'{self.name}_CDB'
concept_db.save(skip_load=True)
self.concept_db = concept_db
except Exception as exc:
raise FileNotFoundError(f'Error loading the CDB from this model pack: {self.model_pack.path}') from exc

# Load Vocab, v2
vocab_path = os.path.join(unpacked_model_pack_path, "vocab")
if not os.path.exists(vocab_path):
# Load Vocab, v2 (directory) or v1 (file)
vocab_path_abs = os.path.join(unpacked_model_pack_path, "vocab")
vocab_path_rel = os.path.join(unpacked_file_name, "vocab")
if not os.path.exists(vocab_path_abs):
# v1
vocab_path = os.path.join(unpacked_model_pack_path, "vocab.dat")
if os.path.exists(vocab_path):
Vocab.load(vocab_path)
vocab_path_abs = os.path.join(unpacked_model_pack_path, "vocab.dat")
vocab_path_rel = os.path.join(unpacked_file_name, "vocab.dat")
if os.path.exists(vocab_path_abs):
Vocab.load(vocab_path_abs)
vocab = Vocabulary()
vocab.vocab_file.name = vocab_path.replace(f'{MEDIA_ROOT}/', '')
# Use relative path for saving to model
vocab.vocab_file.name = vocab_path_rel
vocab.save(skip_load=True)
self.vocab = vocab
else:
# DeID model packs do not have a vocab.dat file
logger.warn('Error loading the Vocab from this model pack - '
f'if this is a DeID model pack, this is expected: {vocab_path}')
f'if this is a DeID model pack, this is expected: {vocab_path_abs}')

# load MetaCATs
try:
@@ -462,7 +482,7 @@ class Meta:
help_text='Use a remote MedCAT service API for document processing instead of local models'\
'(note: interim model training is not supported for remote model service projects)')
model_service_url = models.CharField(max_length=500, blank=True, null=True,
help_text='URL of the remote MedCAT service API (e.g., http://medcat-service:8003)')
help_text='URL of the remote MedCAT service API (e.g., http://medcat-service:8000)')

def save(self, *args, **kwargs):
# If using remote model service, skip local model validation
18 changes: 3 additions & 15 deletions medcat-trainer/webapp/api/api/tests/test_data_utils.py
@@ -172,26 +172,14 @@ def test_upload_projects_export_with_modelpack(self, mock_exists, mock_vocab_loa
# Mock all file operations
mock_exists.return_value = False
mock_load_addons.return_value = []
# Create a model pack - the save will be mocked to avoid actual file operations
# Create a model pack - use skip_load to avoid file validation
from django.core.files.uploadedfile import SimpleUploadedFile
modelpack = ModelPack(
name='test_modelpack',
model_pack=SimpleUploadedFile('test_modelpack.zip', b'fake zip')
)
# Save with mocked file operations - it will fail on file loading but that's ok
try:
modelpack.save()
except (FileNotFoundError, Exception):
# If save fails, create it directly in the database
from django.utils import timezone
ModelPack.objects.filter(name='test_modelpack').delete()
modelpack = ModelPack.objects.create(
name='test_modelpack',
model_pack='test_modelpack.zip'
)
# Manually set the file field to avoid save() being called again
ModelPack.objects.filter(id=modelpack.id).update(model_pack='test_modelpack.zip')
modelpack.refresh_from_db()
# Save with skip_load=True to skip file validation
modelpack.save(skip_load=True)

# Call the function
upload_projects_export(
6 changes: 5 additions & 1 deletion medcat-trainer/webapp/api/api/views.py
@@ -3,6 +3,7 @@
import traceback
from smtplib import SMTPException
from tempfile import NamedTemporaryFile
from typing import Any

from background_task.models import Task, CompletedTask
from django.contrib.auth.views import PasswordResetView
@@ -781,7 +782,8 @@ def serialize_task(task, state):
'projects': task.verbose_name.split('-')[1].split('_'),
'created_user': task.creator.username,
'create_time': task.run_at.strftime(dt_fmt),
'status': state
'error_msg': '\n'.join(task.last_error.split('\n')[-2:]),
'status': state,
}
running_reports = [serialize_task(t, 'running') for t in running_metrics_tasks_qs]
for r, t in zip(running_reports, running_metrics_tasks_qs):
@@ -790,6 +792,8 @@

comp_reports = [serialize_task(t, 'complete') for t in completed_metrics_tasks]
for comp_task, comp_rep in zip(completed_metrics_tasks, comp_reports):
if comp_task.has_error():
comp_rep['status'] = 'Failed'
pm_obj = ProjectMetrics.objects.filter(report_name_generated=comp_task.verbose_name).first()
if pm_obj is not None and pm_obj.report_name is not None:
comp_rep['report_name'] = pm_obj.report_name
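For reference, the new error_msg field keeps only the tail of the stored traceback: the [-2:] slice yields the last frame's source line plus the exception line, which is what the frontend tooltip surfaces. A worked example with a hypothetical last_error value:

# Hypothetical last_error value stored by the background task runner:
last_error = (
    "Traceback (most recent call last):\n"
    '  File "metrics.py", line 283, in _eval\n'
    "    data = data[category_name]\n"
    "KeyError: 'Status'"
)
print('\n'.join(last_error.split('\n')[-2:]))
# Output:
#     data = data[category_name]
# KeyError: 'Status'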
68 changes: 66 additions & 2 deletions medcat-trainer/webapp/frontend/src/views/MetricsHome.vue
@@ -5,7 +5,18 @@
:hover="true"
@click:row="loadMetrics"
hide-default-footer
:items-per-page="-1">
:items-per-page="-1"
:row-props="getRowProps">
<template #item.report_id="{ item }">
<v-tooltip v-if="item.status === 'Failed' && item.error_msg" location="top">
<template #activator="{ props }">
<span v-bind="props" class="failed-row-tooltip-trigger">{{ item.report_id }}</span>
</template>
<span>{{ item.error_msg }}</span>
</v-tooltip>
<span v-else>{{ item.report_id }}</span>
</template>

<template #item.projects="{ item }" >
<div @click.stop>
<v-runtime-template :template="projectsFormatter(item.projects)"></v-runtime-template>
@@ -22,6 +33,17 @@
<span v-if="item.status === 'complete'">Complete
<font-awesome-icon icon="check" class="status-icon success"></font-awesome-icon>
</span>
<v-tooltip v-if="item.status === 'Failed' && item.error_msg" location="top">
<template #activator="{ props }">
<span v-bind="props" class="failed-row-tooltip-trigger">Failed
<font-awesome-icon icon="times-circle" class="status-icon error"></font-awesome-icon>
</span>
</template>
<span>{{ item.error_msg }}</span>
</v-tooltip>
<span v-else-if="item.status === 'Failed'">Failed
<font-awesome-icon icon="times-circle" class="status-icon error"></font-awesome-icon>
</span>
</template>
<template #item.cleanup="{ item }">
<div @click.stop>
@@ -50,6 +72,10 @@
Confirm complete metrics report deletion. To regenerate the report reselect the projects on the
home screen and select metrics.
</div>
<div v-if="confDeleteReportModal.status === 'Failed'">
Confirm failed metrics report deletion. To regenerate the report reselect the projects on the
home screen and select metrics.
</div>
</template>
<template #footer>
<button class="btn btn-danger" @click="confRemoval">Confirm</button>
@@ -116,13 +142,24 @@ export default {
})
},
loadMetrics (_, { item }) {
// Don't navigate if the report has failed
if (item.status === 'Failed') {
return
}
this.$router.push({
name: 'metrics',
params: {
reportId: item.report_id
}
})
},
getRowProps (data) {
// Apply disabled-row class to failed reports
if (data.item.status === 'Failed') {
return { class: 'disabled-row' }
}
return { class: '' }
},
confRemoval () {
const item = this.confDeleteReportModal
this.$http.delete(`/api/metrics-job/${this.confDeleteReportModal.report_id}/`).then(_ => {
@@ -166,8 +203,35 @@
padding-left: 3px;
}

.status-icon.error {
color: $danger;
}

:deep(.v-table > .v-table__wrapper > table > tbody > tr.disabled-row) {
pointer-events: none;
opacity: 0.6;
cursor: not-allowed !important;
}

:deep(.v-table > .v-table__wrapper > table > tbody > tr.disabled-row:hover) {
background-color: inherit !important;
}

:deep(.v-table > .v-table__wrapper > table > tbody > tr.disabled-row > td) {
pointer-events: none;
}

:deep(.v-table > .v-table__wrapper > table > tbody > tr.disabled-row > td > div) {
pointer-events: auto;
}

:deep(.v-table > .v-table__wrapper > table > tbody > tr.disabled-row .failed-row-tooltip-trigger) {
pointer-events: auto;
cursor: help;
}

.project-links {
color: #005EB8;
color: $primary-alt;

&:hover {
color: #fff;
8 changes: 5 additions & 3 deletions v1/medcat/medcat/utils/meta_cat/data_utils.py
@@ -193,9 +193,11 @@ def find_alternate_classname(category_value2id: Dict, category_values: Set, alte
failed_to_find = True
if failed_to_find:
raise Exception("The classes set in the config are not the same as the one found in the data. "
"The classes present in the config vs the ones found in the data - "
f"{set(category_value2id.keys())}, {category_values}. Additionally, ensure the "
"populate the 'alternative_class_names' attribute to accommodate for variations.")
f"The classes present in the config: {set(category_value2id.keys())} vs "
f"the ones found in the data: {category_values}. Additionally, ensure to "
"populate the 'alternative_class_names' attribute to accommodate for variations. "
" This also could be due to the data not including enough examples of a given class"
" after train / test splitting.")
category_value2id = copy.deepcopy(updated_category_value2id)
logger.info("Updated categoryvalue2id mapping - %s", category_value2id)
return category_value2id
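The added sentence about train/test splitting points at a common failure mode: a class present in the full export but too rare to appear on both sides of the split. A quick pre-flight count over the labels catches this before training; a minimal, library-agnostic sketch (the threshold heuristic and the example labels are illustrative, not MedCAT's):

from collections import Counter

def check_class_coverage(labels, test_size=0.2, min_per_side=2):
    """Flag classes too rare to reliably appear in both train and test."""
    counts = Counter(labels)
    smaller_side = min(test_size, 1 - test_size)
    at_risk = {cls: n for cls, n in counts.items()
               if n * smaller_side < min_per_side}
    if at_risk:
        print(f"Classes at risk of vanishing after the split: {at_risk}")
    return at_risk

# 'Hypothetical' has a single annotation in the whole export:
check_class_coverage(["Confirmed"] * 40 + ["Negated"] * 12 + ["Hypothetical"])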