diff --git a/medcat-trainer/webapp/api/api/metrics.py b/medcat-trainer/webapp/api/api/metrics.py index 1cdfaf93d..b6e17e44b 100644 --- a/medcat-trainer/webapp/api/api/metrics.py +++ b/medcat-trainer/webapp/api/api/metrics.py @@ -245,7 +245,7 @@ def rename_meta_anns(self, meta_anns2rename=dict(), meta_ann_values2rename=dict( self.annotations = self._annotations() return - def _eval_model(self, model: nn.Module, data: List, config: ConfigMetaCAT, tokenizer: TokenizerWrapperBase) -> Dict: + def _eval_model(self, model: nn.Module, data: List, config: ConfigMetaCAT) -> Dict: device = torch.device(config.general.device) # Create a torch device batch_size_eval = config.general.batch_size_eval pad_id = config.model.padding_idx @@ -267,7 +267,7 @@ def _eval_model(self, model: nn.Module, data: List, config: ConfigMetaCAT, token with torch.no_grad(): for i in range(num_batches): - x, cpos, attention_mask, y = create_batch_piped_data(data, + x, cpos, _, y = create_batch_piped_data(data, i*batch_size_eval, (i+1)*batch_size_eval, device=device, @@ -283,23 +283,22 @@ def _eval_model(self, model: nn.Module, data: List, config: ConfigMetaCAT, token return predictions def _eval(self, metacat_model, mct_export): - # TODO: Should be moved into g_config = metacat_model.config.general - t_config = metacat_model.config.train - t_config['test_size'] = 0 - t_config['shuffle_data'] = False - t_config['prerequisites'] = {} - t_config['cui_filter'] = {} # Prepare the data - assert metacat_model.tokenizer is not None - data = prepare_from_json(mct_export, g_config['cntx_left'], g_config['cntx_right'], metacat_model.tokenizer, - cui_filter=t_config['cui_filter'], - replace_center=g_config['replace_center'], prerequisites=t_config['prerequisites'], - lowercase=g_config['lowercase']) + data = prepare_from_json( + mct_export, + g_config.cntx_left, + g_config.cntx_right, + metacat_model.mc.tokenizer, + cui_filter={}, + replace_center=g_config.replace_center, + prerequisites={}, + lowercase=g_config.lowercase + ) # Check is the name there - category_name = g_config['category_name'] + category_name = g_config.category_name if category_name not in data: warnings.warn(f"The meta_model {category_name} does not exist in this MedCATtrainer export.", UserWarning) return {category_name: f"{category_name} does not exist"} @@ -307,12 +306,11 @@ def _eval(self, metacat_model, mct_export): data = data[category_name] # We already have everything, just get the data - category_value2id = g_config['category_value2id'] + category_value2id = g_config.category_value2id data, _, _ = encode_category_values(data, existing_category_value2id=category_value2id) logger.info(_) # Run evaluation - assert metacat_model.tokenizer is not None - result = self._eval_model(metacat_model.model, data, config=metacat_model.config, tokenizer=metacat_model.tokenizer) + result = self._eval_model(metacat_model.mc.model, data, config=metacat_model.mc.config) return {'predictions': result, 'meta_values': _} diff --git a/medcat-trainer/webapp/api/api/models.py b/medcat-trainer/webapp/api/api/models.py index d8df36223..e89459752 100644 --- a/medcat-trainer/webapp/api/api/models.py +++ b/medcat-trainer/webapp/api/api/models.py @@ -45,52 +45,72 @@ class ModelPack(models.Model): last_modified_by = models.ForeignKey(settings.AUTH_USER_MODEL, on_delete=models.CASCADE, default=None, null=True) @transaction.atomic - def save(self, *args, **kwargs): + def save(self, *args, skip_load=False, **kwargs): is_new = self._state.adding if is_new: super().save(*args, **kwargs) + if skip_load: + super().save(*args, **kwargs) + return + # Process the model pack logger.info('Loading model pack: %s', self.model_pack) - model_pack_name = str(self.model_pack).replace(".zip", "") + model_pack_path = self.model_pack.path + try: - CAT.attempt_unpack(self.model_pack.path) + CAT.attempt_unpack(model_pack_path) except BadZipFile as exc: # potential for CRC-32 errors in Trainer process - ignore and still use logger.warning(f'Possibly corrupt cdb.dat decompressing {self.model_pack}\nFull Exception: {exc}') - unpacked_model_pack_path = self.model_pack.path.replace('.zip', '') - # attempt to load cdb + + unpacked_model_pack_path = model_pack_path.replace('.zip', '') + unpacked_file_name = self.model_pack.file.name.replace('.zip', '') + + # attempt to load cdb - use absolute paths for file existence checks try: - CAT.load_cdb(unpacked_model_pack_path) + # Check for v2 (directory) first + cdb_path_abs = os.path.join(unpacked_model_pack_path, 'cdb') + if os.path.exists(cdb_path_abs) and os.path.isdir(cdb_path_abs): + # v2: CDB is a directory + cdb_path_rel = os.path.join(unpacked_file_name, 'cdb') + else: + # v1: CDB is a file + cdb_path_abs = os.path.join(unpacked_model_pack_path, 'cdb.dat') + cdb_path_rel = os.path.join(unpacked_file_name, 'cdb.dat') + if not os.path.exists(cdb_path_abs): + raise FileNotFoundError(f'CDB not found in model pack: {unpacked_model_pack_path}') + + # Validate CDB can be loaded + CDB.load(cdb_path_abs) + + # Create ConceptDB model with relative path (to MEDIA_ROOT) concept_db = ConceptDB() - unpacked_file_name = self.model_pack.file.name.replace('.zip', '') - # cdb path for v2 - cdb_path = os.path.join(unpacked_file_name, 'cdb') - if not os.path.exists(cdb_path): - # cdb path for v1 - cdb_path = os.path.join(unpacked_file_name, 'cdb.dat') - concept_db.cdb_file.name = cdb_path + concept_db.cdb_file.name = cdb_path_rel concept_db.name = f'{self.name}_CDB' concept_db.save(skip_load=True) self.concept_db = concept_db except Exception as exc: raise FileNotFoundError(f'Error loading the CDB from this model pack: {self.model_pack.path}') from exc - # Load Vocab, v2 - vocab_path = os.path.join(unpacked_model_pack_path, "vocab") - if not os.path.exists(vocab_path): + # Load Vocab, v2 (directory) or v1 (file) + vocab_path_abs = os.path.join(unpacked_model_pack_path, "vocab") + vocab_path_rel = os.path.join(unpacked_file_name, "vocab") + if not os.path.exists(vocab_path_abs): # v1 - vocab_path = os.path.join(unpacked_model_pack_path, "vocab.dat") - if os.path.exists(vocab_path): - Vocab.load(vocab_path) + vocab_path_abs = os.path.join(unpacked_model_pack_path, "vocab.dat") + vocab_path_rel = os.path.join(unpacked_file_name, "vocab.dat") + if os.path.exists(vocab_path_abs): + Vocab.load(vocab_path_abs) vocab = Vocabulary() - vocab.vocab_file.name = vocab_path.replace(f'{MEDIA_ROOT}/', '') + # Use relative path for saving to model + vocab.vocab_file.name = vocab_path_rel vocab.save(skip_load=True) self.vocab = vocab else: # DeID model packs do not have a vocab.dat file logger.warn('Error loading the Vocab from this model pack - ' - f'if this is a DeID model pack, this is expected: {vocab_path}') + f'if this is a DeID model pack, this is expected: {vocab_path_abs}') # load MetaCATs try: @@ -462,7 +482,7 @@ class Meta: help_text='Use a remote MedCAT service API for document processing instead of local models'\ '(note: interim model training is not supported for remote model service projects)') model_service_url = models.CharField(max_length=500, blank=True, null=True, - help_text='URL of the remote MedCAT service API (e.g., http://medcat-service:8003)') + help_text='URL of the remote MedCAT service API (e.g., http://medcat-service:8000)') def save(self, *args, **kwargs): # If using remote model service, skip local model validation diff --git a/medcat-trainer/webapp/api/api/tests/test_data_utils.py b/medcat-trainer/webapp/api/api/tests/test_data_utils.py index 87d230ee4..a7c5832bb 100644 --- a/medcat-trainer/webapp/api/api/tests/test_data_utils.py +++ b/medcat-trainer/webapp/api/api/tests/test_data_utils.py @@ -172,26 +172,14 @@ def test_upload_projects_export_with_modelpack(self, mock_exists, mock_vocab_loa # Mock all file operations mock_exists.return_value = False mock_load_addons.return_value = [] - # Create a model pack - the save will be mocked to avoid actual file operations + # Create a model pack - use skip_load to avoid file validation from django.core.files.uploadedfile import SimpleUploadedFile modelpack = ModelPack( name='test_modelpack', model_pack=SimpleUploadedFile('test_modelpack.zip', b'fake zip') ) - # Save with mocked file operations - it will fail on file loading but that's ok - try: - modelpack.save() - except (FileNotFoundError, Exception): - # If save fails, create it directly in the database - from django.utils import timezone - ModelPack.objects.filter(name='test_modelpack').delete() - modelpack = ModelPack.objects.create( - name='test_modelpack', - model_pack='test_modelpack.zip' - ) - # Manually set the file field to avoid save() being called again - ModelPack.objects.filter(id=modelpack.id).update(model_pack='test_modelpack.zip') - modelpack.refresh_from_db() + # Save with skip_load=True to skip file validation + modelpack.save(skip_load=True) # Call the function upload_projects_export( diff --git a/medcat-trainer/webapp/api/api/views.py b/medcat-trainer/webapp/api/api/views.py index f9d9e0d36..713938e52 100644 --- a/medcat-trainer/webapp/api/api/views.py +++ b/medcat-trainer/webapp/api/api/views.py @@ -3,6 +3,7 @@ import traceback from smtplib import SMTPException from tempfile import NamedTemporaryFile +from typing import Any from background_task.models import Task, CompletedTask from django.contrib.auth.views import PasswordResetView @@ -781,7 +782,8 @@ def serialize_task(task, state): 'projects': task.verbose_name.split('-')[1].split('_'), 'created_user': task.creator.username, 'create_time': task.run_at.strftime(dt_fmt), - 'status': state + 'error_msg': '\n'.join(task.last_error.split('\n')[-2:]), + 'status': state, } running_reports = [serialize_task(t, 'running') for t in running_metrics_tasks_qs] for r, t in zip(running_reports, running_metrics_tasks_qs): @@ -790,6 +792,8 @@ def serialize_task(task, state): comp_reports = [serialize_task(t, 'complete') for t in completed_metrics_tasks] for comp_task, comp_rep in zip(completed_metrics_tasks, comp_reports): + if comp_task.has_error(): + comp_rep['status'] = 'Failed' pm_obj = ProjectMetrics.objects.filter(report_name_generated=comp_task.verbose_name).first() if pm_obj is not None and pm_obj.report_name is not None: comp_rep['report_name'] = pm_obj.report_name diff --git a/medcat-trainer/webapp/frontend/src/views/MetricsHome.vue b/medcat-trainer/webapp/frontend/src/views/MetricsHome.vue index c4a83fb15..bc584da8e 100644 --- a/medcat-trainer/webapp/frontend/src/views/MetricsHome.vue +++ b/medcat-trainer/webapp/frontend/src/views/MetricsHome.vue @@ -5,7 +5,18 @@ :hover="true" @click:row="loadMetrics" hide-default-footer - :items-per-page="-1"> + :items-per-page="-1" + :row-props="getRowProps"> + +