From 4d221427490859c8f8fc8995ff00f0d9dd9b0035 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rapha=C3=ABl=20Bournhonesque?= Date: Tue, 1 Oct 2024 14:44:51 +0200 Subject: [PATCH] fix: remove old object detection models (#1423) Also use '_' as word separator for all model_name for object detection By doing this, we don't have anymore a mismatch between the enum (that cannot have '-' in the name) and the Triton model name. --- models/triton/nutriscore-yolo/labels.txt | 5 -- models/triton/nutrition-table/labels.txt | 4 -- .../labels.txt | 0 .../labels.txt | 0 robotoff/app/api.py | 9 ++-- robotoff/cli/main.py | 6 +-- robotoff/insights/extraction.py | 6 +-- robotoff/insights/importer.py | 3 +- robotoff/prediction/object_detection/core.py | 48 +++++++++---------- robotoff/types.py | 20 -------- robotoff/workers/tasks/import_image.py | 4 +- scripts/insert_image_predictions.py | 2 +- tests/integration/insights/test_extraction.py | 2 +- tests/integration/insights/test_importer.py | 4 +- tests/integration/models_utils.py | 2 +- 15 files changed, 42 insertions(+), 73 deletions(-) delete mode 100644 models/triton/nutriscore-yolo/labels.txt delete mode 100644 models/triton/nutrition-table/labels.txt rename models/triton/{nutrition-table-yolo => nutrition_table}/labels.txt (100%) rename models/triton/{universal-logo-detector => universal_logo_detector}/labels.txt (100%) diff --git a/models/triton/nutriscore-yolo/labels.txt b/models/triton/nutriscore-yolo/labels.txt deleted file mode 100644 index a1d0bf40a0..0000000000 --- a/models/triton/nutriscore-yolo/labels.txt +++ /dev/null @@ -1,5 +0,0 @@ -nutriscore-a -nutriscore-b -nutriscore-c -nutriscore-d -nutriscore-e \ No newline at end of file diff --git a/models/triton/nutrition-table/labels.txt b/models/triton/nutrition-table/labels.txt deleted file mode 100644 index 58305ff874..0000000000 --- a/models/triton/nutrition-table/labels.txt +++ /dev/null @@ -1,4 +0,0 @@ -nutrition-table -nutrition-table-small -nutrition-table-small-energy -nutrition-table-text \ No newline at end of file diff --git a/models/triton/nutrition-table-yolo/labels.txt b/models/triton/nutrition_table/labels.txt similarity index 100% rename from models/triton/nutrition-table-yolo/labels.txt rename to models/triton/nutrition_table/labels.txt diff --git a/models/triton/universal-logo-detector/labels.txt b/models/triton/universal_logo_detector/labels.txt similarity index 100% rename from models/triton/universal-logo-detector/labels.txt rename to models/triton/universal_logo_detector/labels.txt diff --git a/robotoff/app/api.py b/robotoff/app/api.py index e562211f62..388e4ba6bd 100644 --- a/robotoff/app/api.py +++ b/robotoff/app/api.py @@ -80,6 +80,7 @@ InsightType, JSONType, NeuralCategoryClassifierModel, + ObjectDetectionModel, PredictionType, ProductIdentifier, ServerType, @@ -859,8 +860,8 @@ def on_get(self, req: falcon.Request, resp: falcon.Response): image_url = req.get_param("image_url", required=True) models: list[str] = req.get_param_as_list("models", required=True) - available_object_detection_models = ( - ObjectDetectionModelRegistry.get_available_models() + available_object_detection_models = list( + ObjectDetectionModel.__members__.keys() ) available_clf_models = list(ImageClassificationModel.__members__.keys()) available_models = available_object_detection_models + available_clf_models @@ -901,7 +902,9 @@ def on_get(self, req: falcon.Request, resp: falcon.Response): for model_name in models: if model_name in available_object_detection_models: - model = ObjectDetectionModelRegistry.get(model_name) + model = ObjectDetectionModelRegistry.get( + ObjectDetectionModel[model_name] + ) result = model.detect_from_image(image, output_image=output_image) if output_image: diff --git a/robotoff/cli/main.py b/robotoff/cli/main.py index 47f0e7e11d..a4992a6046 100644 --- a/robotoff/cli/main.py +++ b/robotoff/cli/main.py @@ -541,9 +541,9 @@ def run_object_detection_model( if model_name == ObjectDetectionModel.universal_logo_detector: func: Callable = run_logo_object_detection - elif model_name == ObjectDetectionModel.nutrition_table_yolo: + elif model_name == ObjectDetectionModel.nutrition_table: func = run_nutrition_table_object_detection - elif model_name == ObjectDetectionModel.nutriscore_yolo: + elif model_name == ObjectDetectionModel.nutriscore: func = run_nutriscore_object_detection else: raise ValueError(f"unsupported model: {model_name}") @@ -568,7 +568,7 @@ def run_object_detection_model( JOIN.LEFT_OUTER, on=( (ImagePrediction.image_id == ImageModel.id) - & (ImagePrediction.model_name == model_name.value) + & (ImagePrediction.model_name == model_name.name) ), ) .where( diff --git a/robotoff/insights/extraction.py b/robotoff/insights/extraction.py index 6b7cb72592..1877b5ea4b 100644 --- a/robotoff/insights/extraction.py +++ b/robotoff/insights/extraction.py @@ -77,7 +77,7 @@ def run_object_detection_model( """ if ( existing_image_prediction := ImagePrediction.get_or_none( - image=image_model, model_name=model_name.get_type() + image=image_model, model_name=model_name.name ) ) is not None: if return_null_if_exist: @@ -85,7 +85,7 @@ def run_object_detection_model( return existing_image_prediction timestamp = datetime.datetime.now(datetime.timezone.utc) - results = ObjectDetectionModelRegistry.get(model_name.value).detect_from_image( + results = ObjectDetectionModelRegistry.get(model_name).detect_from_image( image, output_image=False, triton_uri=triton_uri, threshold=threshold ) data = results.to_json() @@ -93,7 +93,7 @@ def run_object_detection_model( return ImagePrediction.create( image=image_model, type="object_detection", - model_name=model_name.get_type(), + model_name=model_name.name, model_version=OBJECT_DETECTION_MODEL_VERSION[model_name], data={"objects": data}, timestamp=timestamp, diff --git a/robotoff/insights/importer.py b/robotoff/insights/importer.py index 317180e7f1..bc4c757904 100644 --- a/robotoff/insights/importer.py +++ b/robotoff/insights/importer.py @@ -1222,8 +1222,7 @@ def get_nutrition_table_predictions( .where( ImageModel.barcode == product_id.barcode, ImageModel.server_type == product_id.server_type.name, - ImagePrediction.model_name - == ObjectDetectionModel.nutrition_table_yolo.get_type(), + ImagePrediction.model_name == ObjectDetectionModel.nutrition_table.name, ImagePrediction.max_confidence >= min_score, ) .tuples() diff --git a/robotoff/prediction/object_detection/core.py b/robotoff/prediction/object_detection/core.py index fd609464f1..d48ae37e60 100644 --- a/robotoff/prediction/object_detection/core.py +++ b/robotoff/prediction/object_detection/core.py @@ -21,10 +21,8 @@ OBJECT_DETECTION_MODEL_VERSION = { - ObjectDetectionModel.nutriscore: "tf-nutriscore-1.0", - ObjectDetectionModel.nutriscore_yolo: "yolo-nutriscore-1.0", - ObjectDetectionModel.nutrition_table: "tf-nutrition-table-1.0", - ObjectDetectionModel.nutrition_table_yolo: "yolo-nutrition-table-1.0", + ObjectDetectionModel.nutriscore: "yolo-nutriscore-1.0", + ObjectDetectionModel.nutrition_table: "yolo-nutrition-table-1.0", ObjectDetectionModel.universal_logo_detector: "tf-universal-logo-detector-1.0", } @@ -383,45 +381,43 @@ def detect_from_image( class ObjectDetectionModelRegistry: - models: dict[str, RemoteModel] = {} + models: dict[ObjectDetectionModel, RemoteModel] = {} _loaded = False - @classmethod - def get_available_models(cls) -> list[str]: - cls.load_all() - return list(cls.models.keys()) - @classmethod def load_all(cls): if cls._loaded: return for model in ObjectDetectionModel: - model_name = model.value - file_path = settings.TRITON_MODELS_DIR / model_name - if file_path.is_dir(): - logger.info("Model %s found", model_name) - cls.models[model_name] = cls.load(model_name, file_path) - else: - logger.info("Missing model: %s", model_name) + model_dir = settings.TRITON_MODELS_DIR / model.name + if not model_dir.exists(): + logger.warning("Model directory %s does not exist", model_dir) + continue + cls.models[model] = cls.load(model, model_dir) cls._loaded = True @classmethod - def load(cls, name: str, model_dir: pathlib.Path) -> RemoteModel: - # To keep compatibility with the old models, we temporarily use the - # model name as a heuristic to determine the backend + def load(cls, model: ObjectDetectionModel, model_dir: pathlib.Path) -> RemoteModel: + # To keep compatibility with the old models, we temporarily specify + # here the backend to use for each model (Tensorflow Object Detection or YOLO). # Tensorflow Object Detection models are going to be phased out in # favor of YOLO models - backend = "yolo" if "yolo" in name else "tf" + backend = ( + "yolo" + if model + in (ObjectDetectionModel.nutriscore, ObjectDetectionModel.nutrition_table) + else "tf" + ) label_names = list(text_file_iter(model_dir / LABEL_NAMES_FILENAME)) if backend == "tf": label_names.insert(0, "NULL") - model = RemoteModel(name, label_names, backend=backend) - cls.models[name] = model - return model + remote_model = RemoteModel(model.name, label_names, backend=backend) + cls.models[model] = remote_model + return remote_model @classmethod - def get(cls, name: str) -> RemoteModel: + def get(cls, model: ObjectDetectionModel) -> RemoteModel: cls.load_all() - return cls.models[name] + return cls.models[model] diff --git a/robotoff/types.py b/robotoff/types.py index 0444f03084..e924def714 100644 --- a/robotoff/types.py +++ b/robotoff/types.py @@ -11,28 +11,8 @@ class ObjectDetectionModel(enum.Enum): nutriscore = "nutriscore" - nutriscore_yolo = "nutriscore-yolo" universal_logo_detector = "universal-logo-detector" nutrition_table = "nutrition-table" - nutrition_table_yolo = "nutrition-table-yolo" - - def get_type(self) -> str: - """This helper function is useful as long as we have two model (yolo - and tf) for each type of detection. - Once we've migrated all models to Yolo, we can remove this function. - """ - if self in ( - ObjectDetectionModel.nutriscore, - ObjectDetectionModel.nutriscore_yolo, - ): - return "nutriscore" - if self in ( - ObjectDetectionModel.nutrition_table, - ObjectDetectionModel.nutrition_table_yolo, - ): - return "nutrition-table" - - return "universal-logo-detector" class ImageClassificationModel(str, enum.Enum): diff --git a/robotoff/workers/tasks/import_image.py b/robotoff/workers/tasks/import_image.py index 7473d2e616..1a7ef458d8 100644 --- a/robotoff/workers/tasks/import_image.py +++ b/robotoff/workers/tasks/import_image.py @@ -267,7 +267,7 @@ def run_nutrition_table_object_detection( source_image=source_image, server_type=product_id.server_type.name ): run_object_detection_model( - ObjectDetectionModel.nutrition_table_yolo, + ObjectDetectionModel.nutrition_table, image, image_model, triton_uri=triton_uri, @@ -399,7 +399,7 @@ def run_nutriscore_object_detection( return image_prediction = run_object_detection_model( - ObjectDetectionModel.nutriscore_yolo, + ObjectDetectionModel.nutriscore, image, image_model, triton_uri=triton_uri, diff --git a/scripts/insert_image_predictions.py b/scripts/insert_image_predictions.py index 557c84a53c..f58f2f8271 100644 --- a/scripts/insert_image_predictions.py +++ b/scripts/insert_image_predictions.py @@ -13,7 +13,7 @@ DATA_PATH = settings.DATASET_DIR / "logos-paperspace.jsonl.gz" -MODEL_NAME = "universal-logo-detector" +MODEL_NAME = "universal_logo_detector" MODEL_VERSION = "tf-universal-logo-detector-1.0" TYPE = "object_detection" SERVER_TYPE = ServerType.off diff --git a/tests/integration/insights/test_extraction.py b/tests/integration/insights/test_extraction.py index 2bec5e9a60..7a1fec1993 100644 --- a/tests/integration/insights/test_extraction.py +++ b/tests/integration/insights/test_extraction.py @@ -71,7 +71,7 @@ def test_run_object_detection_model(mocker, image_model, model_name, label_names ) assert isinstance(image_prediction, ImagePrediction) assert image_prediction.type == "object_detection" - assert image_prediction.model_name == model_name.value + assert image_prediction.model_name == model_name.name assert image_prediction.data == { "objects": [ {"bounding_box": (1, 2, 3, 4), "score": 0.8, "label": label_names[1]} diff --git a/tests/integration/insights/test_importer.py b/tests/integration/insights/test_importer.py index 3978a22337..a9cba63e87 100644 --- a/tests/integration/insights/test_importer.py +++ b/tests/integration/insights/test_importer.py @@ -52,8 +52,8 @@ def generate_image_prediction( return ImagePredictionFactory( image=image, data={"objects": objects}, - model_name="nutrition-table", - model_version="tf-nutrition-table-1.0", + model_name="nutrition_table", + model_version="yolo-nutrition-table-1.0", max_confidence=max_confidence, ) diff --git a/tests/integration/models_utils.py b/tests/integration/models_utils.py index ddc9f06425..2918dd740f 100644 --- a/tests/integration/models_utils.py +++ b/tests/integration/models_utils.py @@ -111,7 +111,7 @@ class Meta: model = ImagePrediction type = "object_detection" - model_name = "universal-logo-detector" + model_name = "universal_logo_detector" model_version = "tf-universal-logo-detector-1.0" data = { "objects": [