diff --git a/Dockerfile b/Dockerfile
index bcfd35ddde..374f16ba7d 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -5,7 +5,6 @@ ARG PYTHON_VERSION=3.11
 FROM python:$PYTHON_VERSION-slim AS python-base
 RUN apt-get update && \
     apt-get install --no-install-suggests --no-install-recommends -y gettext curl build-essential && \
-    apt-get install ffmpeg libsm6 libxext6 -y && \
     apt-get autoremove --purge && \
     apt-get clean && \
     rm -rf /var/lib/apt/lists/*
diff --git a/poetry.lock b/poetry.lock
index 182f4908ff..3433fb8648 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -2183,19 +2183,19 @@ files = [
 ]
 
 [[package]]
-name = "opencv-contrib-python-headless"
+name = "opencv-python-headless"
 version = "4.10.0.84"
 description = "Wrapper package for OpenCV python bindings."
 optional = false
 python-versions = ">=3.6"
 files = [
-    {file = "opencv-contrib-python-headless-4.10.0.84.tar.gz", hash = "sha256:6351250db97e1f91f31afdec2436afb1c89594e3da02851e0f01e20ea16bbd9e"},
-    {file = "opencv_contrib_python_headless-4.10.0.84-cp37-abi3-macosx_11_0_arm64.whl", hash = "sha256:be91c6c81e839613c6f3b15755bf71789839289d0e3440fab093e0708516ffcf"},
-    {file = "opencv_contrib_python_headless-4.10.0.84-cp37-abi3-macosx_12_0_x86_64.whl", hash = "sha256:252df47a7e1da280cef26ee0ecc1799841015ce3718214634bb15bc22d4cb308"},
-    {file = "opencv_contrib_python_headless-4.10.0.84-cp37-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:77eb20ee077ac0955704d391c00639df6063cb67cb62606c07b97d8b635feff6"},
-    {file = "opencv_contrib_python_headless-4.10.0.84-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:89c16eb5f888aee7bf664106e12c423705d29d1b094876b66aa4e33d4e8ec905"},
-    {file = "opencv_contrib_python_headless-4.10.0.84-cp37-abi3-win32.whl", hash = "sha256:7581d7ffb7fff953436797dca2dfc5e70e100f721ea18ab84ebf11417ea21d0c"},
-    {file = "opencv_contrib_python_headless-4.10.0.84-cp37-abi3-win_amd64.whl", hash = "sha256:660ded6b77b07f875f56065016677bbb6a3abca13903b9320164691a46474a7d"},
+    {file = "opencv-python-headless-4.10.0.84.tar.gz", hash = "sha256:f2017c6101d7c2ef8d7bc3b414c37ff7f54d64413a1847d89970b6b7069b4e1a"},
+    {file = "opencv_python_headless-4.10.0.84-cp37-abi3-macosx_11_0_arm64.whl", hash = "sha256:a4f4bcb07d8f8a7704d9c8564c224c8b064c63f430e95b61ac0bffaa374d330e"},
+    {file = "opencv_python_headless-4.10.0.84-cp37-abi3-macosx_12_0_x86_64.whl", hash = "sha256:5ae454ebac0eb0a0b932e3406370aaf4212e6a3fdb5038cc86c7aea15a6851da"},
+    {file = "opencv_python_headless-4.10.0.84-cp37-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:46071015ff9ab40fccd8a163da0ee14ce9846349f06c6c8c0f2870856ffa45db"},
+    {file = "opencv_python_headless-4.10.0.84-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:377d08a7e48a1405b5e84afcbe4798464ce7ee17081c1c23619c8b398ff18295"},
+    {file = "opencv_python_headless-4.10.0.84-cp37-abi3-win32.whl", hash = "sha256:9092404b65458ed87ce932f613ffbb1106ed2c843577501e5768912360fc50ec"},
+    {file = "opencv_python_headless-4.10.0.84-cp37-abi3-win_amd64.whl", hash = "sha256:afcf28bd1209dd58810d33defb622b325d3cbe49dcd7a43a902982c33e5fad05"},
 ]
 
 [package.dependencies]
@@ -2206,13 +2206,13 @@ numpy = [
 
 [[package]]
 name = "openfoodfacts"
-version = "1.1.5"
+version = "2.3.4"
 description = "Official Python SDK of Open Food Facts"
 optional = false
 python-versions = "<4.0,>=3.8.1"
 files = [
-    {file = "openfoodfacts-1.1.5-py3-none-any.whl", hash = "sha256:5fa86ebc88090f53f8f1c938470a885edd44fa1e58c14a825e8b0ef2a98c9652"},
-    {file = "openfoodfacts-1.1.5.tar.gz", hash = "sha256:ea05d2f4acaf684ea0590754b0110a7c2d2686225add593fd405c0b6da9bae70"},
+    {file = "openfoodfacts-2.3.4-py3-none-any.whl", hash = "sha256:81d029ba0a7c31d883401408b97102f597d40c78a197efcfdcc0e14c024cbb12"},
+    {file = "openfoodfacts-2.3.4.tar.gz", hash = "sha256:5bb88e05eb6cb554251fcc3a36216438dd68e1415d79047b1a6ed43587e9b2b7"},
 ]
 
 [package.dependencies]
@@ -2221,7 +2221,8 @@ requests = ">=2.20.0"
 tqdm = ">=4.0.0,<5.0.0"
 
 [package.extras]
-pillow = ["Pillow (>=9.3,<10.4)"]
+ml = ["Pillow (>=9.3,<11)", "opencv-python-headless (>4.0.0,<5.0.0)", "tritonclient[grpc] (>2.0.0,<3.0.0)"]
+pillow = ["Pillow (>=9.3,<11)"]
 redis = ["redis[hiredis] (>=5.1.0,<5.2.0)"]
 
 [[package]]
@@ -4654,4 +4655,4 @@ watchdog = ["watchdog (>=2.3)"]
 
 [metadata]
 lock-version = "2.0"
 python-versions = "^3.11"
-content-hash = "c24e974e990052b282c69f6081fa5d00d77852a1b041390739fd11937c19cf5a"
+content-hash = "b9d2783882888459efadaf2c5c410baf14628ef19bc9a6c6af54eccdfef536f8"
diff --git a/pyproject.toml b/pyproject.toml
index 80bea2aec4..14115dc2ad 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -72,9 +72,9 @@ python-redis-lock = "~4.0.0"
 transformers = "~4.44.2"
 lark = "~1.1.4"
 h5py = "~3.8.0"
-opencv-contrib-python-headless = "~4.10.0.84"
+opencv-python-headless = "~4.10.0.84"
 toml = "~0.10.2"
-openfoodfacts = "1.1.5"
+openfoodfacts = "2.3.4"
 imagehash = "~4.3.1"
 peewee-migrate = "~1.12.2"
 diskcache = "~5.6.3"
diff --git a/robotoff/app/api.py b/robotoff/app/api.py
index fc7ec691d4..a486c449e6 100644
--- a/robotoff/app/api.py
+++ b/robotoff/app/api.py
@@ -875,6 +875,7 @@ class ImagePredictorResource:
     def on_get(self, req: falcon.Request, resp: falcon.Response):
         image_url = req.get_param("image_url", required=True)
         models: list[str] = req.get_param_as_list("models", required=True)
+        threshold: float = req.get_param_as_float("threshold", default=0.5)
 
         available_object_detection_models = list(
             ObjectDetectionModel.__members__.keys()
@@ -921,14 +922,16 @@ def on_get(self, req: falcon.Request, resp: falcon.Response):
                 model = ObjectDetectionModelRegistry.get(
                     ObjectDetectionModel[model_name]
                 )
-                result = model.detect_from_image(image, output_image=output_image)
+                result = model.detect_from_image(
+                    image, output_image=output_image, threshold=threshold
+                )
 
                 if output_image:
                     boxed_image = cast(Image.Image, result.boxed_image)
                     image_response(boxed_image, resp)
                     return
                 else:
-                    predictions[model_name] = result.to_json()
+                    predictions[model_name] = result.to_list()
             else:
                 model_enum = ImageClassificationModel[model_name]
                 classifier = image_classifier.ImageClassifier(
diff --git a/robotoff/cli/main.py b/robotoff/cli/main.py
index c6fdfdcf0f..cbca9caf07 100644
--- a/robotoff/cli/main.py
+++ b/robotoff/cli/main.py
@@ -925,7 +925,7 @@ def import_logos(
     """
    from robotoff.cli import logos
    from robotoff.models import db
-    from robotoff.prediction.object_detection import OBJECT_DETECTION_MODEL_VERSION
+    from robotoff.prediction.object_detection import MODELS_CONFIG
    from robotoff.utils import get_logger
 
     logger = get_logger()
@@ -935,9 +935,7 @@ def import_logos(
         imported = logos.import_logos(
             data_path,
             ObjectDetectionModel.universal_logo_detector.value,
-            OBJECT_DETECTION_MODEL_VERSION[
-                ObjectDetectionModel.universal_logo_detector
-            ],
+            MODELS_CONFIG[ObjectDetectionModel.universal_logo_detector].model_version,
             batch_size,
             server_type,
         )
diff --git a/robotoff/insights/extraction.py b/robotoff/insights/extraction.py
index 1877b5ea4b..44a5f51862 100644
--- a/robotoff/insights/extraction.py
+++ b/robotoff/insights/extraction.py
@@ -8,7 +8,7 @@ from robotoff.off import get_source_from_url
 from robotoff.prediction import ocr
 from robotoff.prediction.object_detection import (
-    OBJECT_DETECTION_MODEL_VERSION,
+    MODELS_CONFIG,
     ObjectDetectionModelRegistry,
 )
 from robotoff.types import (
@@ -88,13 +88,13 @@ def run_object_detection_model(
     results = ObjectDetectionModelRegistry.get(model_name).detect_from_image(
         image, output_image=False, triton_uri=triton_uri, threshold=threshold
     )
-    data = results.to_json()
+    data = results.to_list()
     max_confidence = max((item["score"] for item in data), default=None)
     return ImagePrediction.create(
         image=image_model,
         type="object_detection",
         model_name=model_name.name,
-        model_version=OBJECT_DETECTION_MODEL_VERSION[model_name],
+        model_version=MODELS_CONFIG[model_name].model_version,
         data={"objects": data},
         timestamp=timestamp,
         max_confidence=max_confidence,
diff --git a/robotoff/prediction/object_detection/__init__.py b/robotoff/prediction/object_detection/__init__.py
index 8e9058d5e6..c95a460658 100644
--- a/robotoff/prediction/object_detection/__init__.py
+++ b/robotoff/prediction/object_detection/__init__.py
@@ -1,6 +1,2 @@
 # flake8: noqa
-from .core import (
-    OBJECT_DETECTION_MODEL_VERSION,
-    ObjectDetectionModelRegistry,
-    ObjectDetectionRawResult,
-)
+from .core import MODELS_CONFIG, ObjectDetectionModelRegistry, ObjectDetectionResult
diff --git a/robotoff/prediction/object_detection/core.py b/robotoff/prediction/object_detection/core.py
index d48ae37e60..40d15a54e4 100644
--- a/robotoff/prediction/object_detection/core.py
+++ b/robotoff/prediction/object_detection/core.py
@@ -1,99 +1,130 @@
 import dataclasses
-import pathlib
+import logging
 import time
-from typing import Optional
+from typing import Literal
 
 import numpy as np
-from cv2 import dnn
+from openfoodfacts.ml.object_detection import ObjectDetectionRawResult, ObjectDetector
+from openfoodfacts.ml.utils import resize_image
 from PIL import Image
+from pydantic import BaseModel, Field
 from tritonclient.grpc import service_pb2
 
 from robotoff import settings
 from robotoff.prediction.object_detection.utils import visualization_utils as vis_util
 from robotoff.triton import get_triton_inference_stub
-from robotoff.types import JSONType, ObjectDetectionModel
-from robotoff.utils import get_logger, text_file_iter
+from robotoff.types import ObjectDetectionModel
 from robotoff.utils.image import convert_image_to_array
 
-logger = get_logger(__name__)
+logger = logging.getLogger(__name__)
 
-LABEL_NAMES_FILENAME = "labels.txt"
 
+class ModelConfig(BaseModel):
+    """Configuration of an object detection model."""
 
-OBJECT_DETECTION_MODEL_VERSION = {
-    ObjectDetectionModel.nutriscore: "yolo-nutriscore-1.0",
-    ObjectDetectionModel.nutrition_table: "yolo-nutrition-table-1.0",
-    ObjectDetectionModel.universal_logo_detector: "tf-universal-logo-detector-1.0",
+    model_name: str = Field(
+        ...,
+        description="The name of the model, it will be used as "
+        "`model_name` field in `image_prediction` table",
+    )
+    model_version: str = Field(
+        ...,
+        description="The version of the model, it will be used as "
+        "`model_version` field in `image_prediction` table",
+    )
+    triton_version: str = Field(
+        ...,
+        description="The version of the model used on Triton Inference Server (eg: `1`)",
+    )
+    triton_model_name: str = Field(
+        ..., description="The name of the model on Triton Inference Server"
+    )
+    image_size: int = Field(
+        ...,
+        description="The size of the image expected by the model. "
+        "The original image will be resized to this size.",
+    )
+    label_names: list[str] = Field(
+        ...,
+        description="The names of the labels used by the model. "
+        "The order of the labels must match the order of the classes in the model.",
+    )
+    backend: Literal["tf", "yolo"] = Field(
+        ...,
+        description="The backend used by the model. It can be either `tf` for "
+        "Tensorflow models or `yolo` for Ultralytics models. Tensorflow models "
+        "are deprecated and should be replaced by Ultralytics models.",
+    )
+
+
+MODELS_CONFIG = {
+    ObjectDetectionModel.nutriscore: ModelConfig(
+        model_name=ObjectDetectionModel.nutriscore.name,
+        model_version="yolo-nutriscore-1.0",
+        triton_version="1",
+        triton_model_name="nutriscore",
+        image_size=640,
+        label_names=[
+            "nutriscore-a",
+            "nutriscore-b",
+            "nutriscore-c",
+            "nutriscore-d",
+            "nutriscore-e",
+        ],
+        backend="yolo",
+    ),
+    ObjectDetectionModel.nutrition_table: ModelConfig(
+        model_name=ObjectDetectionModel.nutrition_table.name,
+        model_version="yolo-nutrition-table-1.0",
+        triton_version="1",
+        triton_model_name="nutrition_table",
+        image_size=640,
+        label_names=["nutrition-table"],
+        backend="yolo",
+    ),
+    ObjectDetectionModel.universal_logo_detector: ModelConfig(
+        model_name=ObjectDetectionModel.universal_logo_detector.name,
+        model_version="tf-universal-logo-detector-1.0",
+        triton_version="1",
+        triton_model_name="universal_logo_detector",
+        image_size=1024,
+        label_names=["NULL", "brand", "label"],
+        backend="tf",
+    ),
+    ObjectDetectionModel.price_tag_detection: ModelConfig(
+        model_name=ObjectDetectionModel.price_tag_detection.name,
+        model_version="price-tag-detection-1.0",
+        triton_version="1",
+        triton_model_name="price_tag_detection",
+        image_size=960,
+        label_names=["price-tag"],
+        backend="yolo",
+    ),
 }
 
 
-@dataclasses.dataclass
-class ObjectDetectionRawResult:
-    num_detections: int
-    detection_boxes: np.ndarray
-    detection_scores: np.ndarray
-    detection_classes: np.ndarray
-    label_names: list[str]
-    detection_masks: Optional[np.ndarray] = None
-    boxed_image: Optional[Image.Image] = None
-
-    def to_json(self) -> list[JSONType]:
-        """Convert the detection results to a JSON serializable format."""
-        results = []
-        for bounding_box, score, label in zip(
-            self.detection_boxes, self.detection_scores, self.detection_classes
-        ):
-            label_int = int(label)
-            label_str = self.label_names[label_int]
-            if label_str is not None:
-                result = {
-                    "bounding_box": tuple(bounding_box.tolist()),  # type: ignore
-                    "score": float(score),
-                    "label": label_str,
-                }
-                results.append(result)
-        return results
-
-
-def add_boxes_and_labels(image_array: np.ndarray, raw_result: ObjectDetectionRawResult):
+class ObjectDetectionResult(ObjectDetectionRawResult):
+    boxed_image: Image.Image | None
+
+
+def add_boxes_and_labels(image_array: np.ndarray, result: ObjectDetectionResult):
     vis_util.visualize_boxes_and_labels_on_image_array(
         image_array,
-        raw_result.detection_boxes,
-        raw_result.detection_classes,
-        raw_result.detection_scores,
-        raw_result.label_names,
-        instance_masks=raw_result.detection_masks,
+        result.detection_boxes,
+        result.detection_classes,
+        result.detection_scores,
+        result.label_names,
+        instance_masks=None,
         use_normalized_coordinates=True,
         line_thickness=5,
     )
     image_with_boxes = Image.fromarray(image_array)
-    raw_result.boxed_image = image_with_boxes
-
-
-def resize_image(image: Image.Image, max_size: tuple[int, int]) -> Image.Image:
-    """Resize an image to fit within the specified dimensions.
-
-    :param image: the input image
-    :param max_size: the maximum width and height as a tuple
-    :return: the resized image, or the original image if it fits within the
-        specified dimensions
-    """
-    width, height = image.size
-    max_width, max_height = max_size
-
-    if width > max_width or height > max_height:
-        new_image = image.copy()
-        new_image.thumbnail((max_width, max_height))
-        return new_image
-
-    return image
+    result.boxed_image = image_with_boxes
 
 
 class RemoteModel:
-    def __init__(self, name: str, label_names: list[str], backend: str):
-        self.name: str = name
-        self.label_names = label_names
-        self.backend = backend
+    def __init__(self, config: ModelConfig):
+        self.config = config
 
     def detect_from_image_tf(
         self,
@@ -101,7 +132,7 @@ def detect_from_image_tf(
         output_image: bool = False,
         triton_uri: str | None = None,
         threshold: float = 0.5,
-    ) -> ObjectDetectionRawResult:
+    ) -> ObjectDetectionResult:
         """Run A Tensorflow object detection model on an image.
 
         The model must have been trained with the Tensorflow Object Detection
@@ -122,7 +153,7 @@ def detect_from_image_tf(
         image_array = convert_image_to_array(resized_image).astype(np.uint8)
         grpc_stub = get_triton_inference_stub(triton_uri)
         request = service_pb2.ModelInferRequest()
-        request.model_name = self.name
+        request.model_name = self.config.triton_model_name
 
         image_input = service_pb2.ModelInferRequest().InferInputTensor()
         image_input.name = "inputs"
@@ -144,7 +175,9 @@ def detect_from_image_tf(
         start_time = time.monotonic()
         response = grpc_stub.ModelInfer(request)
         logger.debug(
-            "Inference time for %s: %s", self.name, time.monotonic() - start_time
+            "Inference time for %s: %s",
+            self.config.triton_model_name,
+            time.monotonic() - start_time,
         )
 
         if len(response.outputs) != 4:
@@ -178,13 +211,12 @@ def detect_from_image_tf(
         detection_boxes = detection_boxes[threshold_mask]
         detection_classes = detection_classes[threshold_mask]
 
-        result = ObjectDetectionRawResult(
+        result = ObjectDetectionResult(
             num_detections=len(detection_scores),
             detection_classes=detection_classes,
             detection_boxes=detection_boxes,
             detection_scores=detection_scores,
-            detection_masks=None,
-            label_names=self.label_names,
+            label_names=self.config.label_names,
         )
 
         if output_image:
@@ -195,163 +227,27 @@ def detect_from_image_tf(
     def detect_from_image_yolo(
         self,
         image: Image.Image,
-        output_image: bool = False,
         triton_uri: str | None = None,
         threshold: float = 0.5,
-    ) -> ObjectDetectionRawResult:
+    ) -> ObjectDetectionResult:
         """Run an object detection model on an image.
 
         The model must have been trained with Ultralytics library.
 
         :param image: the input Pillow image
-        :param output_image: if True, the image with boxes and labels is
-            returned in the result
         :param triton_uri: URI of the Triton Inference Server, defaults to
             None. If not provided, the default value from settings is used.
         :threshold: the minimum score for a detection to be considered,
             defaults to 0.5.
         :return: the detection result
         """
-        # YoloV8 object detection models expect an image with dimensions
-        # up to 640x640
-        height, width = image.size
-        # Prepare a square image for inference
-        max_size = max(height, width)
-        # We paste the original image into a larger square image, as the model
-        # expects a 640x640 input.
-        # We paste it in the upper-left corner, on a black background.
-        squared_image = Image.new("RGB", (max_size, max_size), color="black")
-        squared_image.paste(image, (0, 0))
-        resized_image = squared_image.resize((640, 640))
-
-        # As we don't process the original image but a modified version of it,
-        # we need to compute the scale factor for the x and y axis.
-        image_ratio = width / height
-        scale_x: float
-        scale_y: float
-        if image_ratio > 1:
-            scale_x = 640 / image_ratio
-            scale_y = 640
-        else:
-            scale_x = 640
-            scale_y = 640 * image_ratio
-
-        # Preprocess the image and prepare blob for model
-        image_array = (
-            convert_image_to_array(resized_image)
-            .transpose((2, 0, 1))
-            .astype(np.float32)
-        )
-        image_array = image_array / 255.0
-        image_array = np.expand_dims(image_array, axis=0)
-
-        grpc_stub = get_triton_inference_stub(triton_uri)
-        request = service_pb2.ModelInferRequest()
-        request.model_name = self.name
-
-        image_input = service_pb2.ModelInferRequest().InferInputTensor()
-        image_input.name = "images"
-
-        image_input.datatype = "FP32"
-
-        image_input.shape.extend([1, 3, 640, 640])
-        request.inputs.extend([image_input])
-
-        output = service_pb2.ModelInferRequest().InferRequestedOutputTensor()
-        output.name = "output0"
-        request.outputs.extend([output])
-
-        request.raw_input_contents.extend([image_array.tobytes()])
-        start_time = time.monotonic()
-        response = grpc_stub.ModelInfer(request)
-        latency = time.monotonic() - start_time
-
-        logger.debug("Inference time for %s: %s", self.name, latency)
-
-        start_time = time.monotonic()
-        if len(response.outputs) != 1:
-            raise Exception(f"expected 1 output, got {len(response.outputs)}")
-
-        if len(response.raw_output_contents) != 1:
-            raise Exception(
-                f"expected 1 raw output content, got {len(response.raw_output_contents)}"
-            )
-
-        output_index = {output.name: i for i, output in enumerate(response.outputs)}
-        output = np.frombuffer(
-            response.raw_output_contents[output_index["output0"]],
-            dtype=np.float32,
-        ).reshape((1, len(self.label_names) + 4, -1))[0]
-
-        # output is of shape (num_classes + 4, num_detections)
-        rows = output.shape[1]
-        raw_detection_classes = np.zeros(rows, dtype=int)
-        raw_detection_scores = np.zeros(rows, dtype=np.float32)
-        raw_detection_boxes = np.zeros((rows, 4), dtype=np.float32)
-
-        for i in range(rows):
-            classes_scores = output[4:, i]
-            max_cls_idx = np.argmax(classes_scores)
-            max_score = classes_scores[max_cls_idx]
-            if max_score < threshold:
-                continue
-            raw_detection_classes[i] = max_cls_idx
-            raw_detection_scores[i] = max_score
-
-            # The bounding box is in the format (x, y, width, height) in
-            # relative coordinates
-            # x and y are the coordinates of the center of the bounding box
-            bbox_width = output[2, i]
-            bbox_height = output[3, i]
-            x_min = output[0, i] - 0.5 * bbox_width
-            y_min = output[1, i] - 0.5 * bbox_height
-            x_max = x_min + bbox_width
-            y_max = y_min + bbox_height
-
-            # We save the bounding box in the format
-            # (y_min, x_min, y_max, x_max) in relative coordinates
-            # Scale the bounding boxes back to the original image size
-            raw_detection_boxes[i, 0] = max(0.0, min(1.0, y_min / scale_y))
-            raw_detection_boxes[i, 1] = max(0.0, min(1.0, x_min / scale_x))
-            raw_detection_boxes[i, 2] = max(0.0, min(1.0, y_max / scale_y))
-            raw_detection_boxes[i, 3] = max(0.0, min(1.0, x_max / scale_x))
-
-        # Perform NMS (Non Maximum Suppression)
-        detection_box_indices = dnn.NMSBoxes(
-            raw_detection_boxes,  # type: ignore
-            raw_detection_scores,  # type: ignore
-            score_threshold=threshold,
-            # the following values are copied from Ultralytics settings
-            nms_threshold=0.45,
-            eta=0.5,
-        )
-        detection_classes = np.zeros(len(detection_box_indices), dtype=int)
-        detection_scores = np.zeros(len(detection_box_indices), dtype=np.float32)
-        detection_boxes = np.zeros((len(detection_box_indices), 4), dtype=np.float32)
-
-        for i, idx in enumerate(detection_box_indices):
-            detection_classes[i] = raw_detection_classes[idx]
-            detection_scores[i] = raw_detection_scores[idx]
-            detection_boxes[i] = raw_detection_boxes[idx]
-
-        result = ObjectDetectionRawResult(
-            num_detections=rows,
-            detection_classes=detection_classes,
-            detection_boxes=detection_boxes,
-            detection_scores=detection_scores,
-            detection_masks=None,
-            label_names=self.label_names,
-        )
-        latency = time.monotonic() - start_time
-        logger.debug("Post-processing time for %s: %s", self.name, latency)
-
-        if output_image:
-            add_boxes_and_labels(
-                convert_image_to_array(image).astype(np.uint8),
-                result,
-            )
-
-        return result
+        triton_uri = triton_uri or settings.TRITON_URI
+        result = ObjectDetector(
+            model_name=self.config.triton_model_name,
+            label_names=self.config.label_names,
+            image_size=self.config.image_size,
+        ).detect_from_image(image=image, triton_uri=triton_uri, threshold=threshold)
+        return ObjectDetectionResult(**dataclasses.asdict(result))
 
     def detect_from_image(
         self,
@@ -359,7 +255,7 @@ def detect_from_image(
         output_image: bool = False,
         triton_uri: str | None = None,
         threshold: float = 0.5,
-    ) -> ObjectDetectionRawResult:
+    ) -> ObjectDetectionResult:
         """Run an object detection model on an image.
 
         :param image: the input Pillow image
@@ -370,14 +266,24 @@ def detect_from_image(
         :threshold: the minimum score for a detection to be considered.
         :return: the detection result
         """
-        if self.backend == "tf":
-            return self.detect_from_image_tf(image, output_image, triton_uri, threshold)
-        elif self.backend == "yolo":
-            return self.detect_from_image_yolo(
+        if self.config.backend == "tf":
+            result = self.detect_from_image_tf(
                 image, output_image, triton_uri, threshold
             )
+
+        elif self.config.backend == "yolo":
+            result = self.detect_from_image_yolo(
+                image=image, triton_uri=triton_uri, threshold=threshold
+            )
         else:
-            raise ValueError(f"Unknown backend: {self.backend}")
+            raise ValueError(f"Unknown backend: {self.config.backend}")
+
+        if output_image:
+            add_boxes_and_labels(
+                convert_image_to_array(image).astype(np.uint8),
+                result,
+            )
+
+        return result
 
 
 class ObjectDetectionModelRegistry:
@@ -388,32 +294,13 @@ class ObjectDetectionModelRegistry:
     def load_all(cls):
         if cls._loaded:
             return
-        for model in ObjectDetectionModel:
-            model_dir = settings.TRITON_MODELS_DIR / model.name
-            if not model_dir.exists():
-                logger.warning("Model directory %s does not exist", model_dir)
-                continue
-            cls.models[model] = cls.load(model, model_dir)
+        for model, config in MODELS_CONFIG.items():
+            cls.models[model] = cls.load(model, config)
         cls._loaded = True
 
     @classmethod
-    def load(cls, model: ObjectDetectionModel, model_dir: pathlib.Path) -> RemoteModel:
-        # To keep compatibility with the old models, we temporarily specify
-        # here the backend to use for each model (Tensorflow Object Detection or YOLO).
-        # Tensorflow Object Detection models are going to be phased out in
-        # favor of YOLO models
-        backend = (
-            "yolo"
-            if model
-            in (ObjectDetectionModel.nutriscore, ObjectDetectionModel.nutrition_table)
-            else "tf"
-        )
-        label_names = list(text_file_iter(model_dir / LABEL_NAMES_FILENAME))
-
-        if backend == "tf":
-            label_names.insert(0, "NULL")
-
-        remote_model = RemoteModel(model.name, label_names, backend=backend)
+    def load(cls, model: ObjectDetectionModel, config: ModelConfig) -> RemoteModel:
+        remote_model = RemoteModel(config)
         cls.models[model] = remote_model
         return remote_model
diff --git a/robotoff/types.py b/robotoff/types.py
index 1338fd10fc..67efe083b6 100644
--- a/robotoff/types.py
+++ b/robotoff/types.py
@@ -12,10 +12,11 @@
 JSONType = dict[str, Any]
 
 
-class ObjectDetectionModel(enum.Enum):
+class ObjectDetectionModel(str, enum.Enum):
     nutriscore = "nutriscore"
-    universal_logo_detector = "universal-logo-detector"
-    nutrition_table = "nutrition-table"
+    universal_logo_detector = "universal_logo_detector"
+    nutrition_table = "nutrition_table"
+    price_tag_detection = "price_tag_detection"
 
 
 class ImageClassificationModel(str, enum.Enum):
diff --git a/tests/integration/insights/test_extraction.py b/tests/integration/insights/test_extraction.py
index 7a1fec1993..1cb2c912d7 100644
--- a/tests/integration/insights/test_extraction.py
+++ b/tests/integration/insights/test_extraction.py
@@ -4,10 +4,7 @@
 from robotoff.insights.extraction import run_object_detection_model
 from robotoff.models import ImagePrediction
-from robotoff.prediction.object_detection.core import (
-    ObjectDetectionRawResult,
-    RemoteModel,
-)
+from robotoff.prediction.object_detection.core import ObjectDetectionResult, RemoteModel
 from robotoff.types import ObjectDetectionModel
 
 from ..models_utils import ImageModelFactory, clean_db
@@ -22,7 +19,7 @@ def image_model(peewee_db):
 
 
 class FakeNutriscoreModel(RemoteModel):
-    def __init__(self, raw_result: ObjectDetectionRawResult):
+    def __init__(self, raw_result: ObjectDetectionResult):
         self.raw_result = raw_result
 
     def detect_from_image(
@@ -31,7 +28,7 @@ def detect_from_image(
         output_image: bool = False,
         triton_uri: str | None = None,
         threshold: float = 0.5,
-    ) -> ObjectDetectionRawResult:
+    ) -> ObjectDetectionResult:
         return self.raw_result
 
 
@@ -52,7 +49,7 @@ def detect_from_image(
     ],
 )
 def test_run_object_detection_model(mocker, image_model, model_name, label_names):
-    raw_result = ObjectDetectionRawResult(
+    result = ObjectDetectionResult(
         num_detections=1,
         detection_boxes=np.array([[1, 2, 3, 4]]),
         detection_scores=np.array([0.8]),
         detection_classes=np.array([1]),
         label_names=label_names,
     )
     mocker.patch(
         "robotoff.prediction.object_detection.core.ObjectDetectionModelRegistry.get",
-        return_value=FakeNutriscoreModel(raw_result),
+        return_value=FakeNutriscoreModel(result),
     )
     image_prediction = run_object_detection_model(
-        model_name,
-        None,
+        model_name=model_name,
+        image=None,
         image_model=image_model,
         threshold=0.1,
     )
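
Reviewer note (not part of the patch): a minimal sketch of how the refactored API is expected to be exercised after this change, assuming a running Triton Inference Server. It only uses names the diff itself introduces (MODELS_CONFIG, ObjectDetectionModelRegistry.get, detect_from_image, to_list); the image file name is hypothetical.

    from PIL import Image

    from robotoff.prediction.object_detection import (
        MODELS_CONFIG,
        ObjectDetectionModelRegistry,
    )
    from robotoff.types import ObjectDetectionModel

    # Models are now described by MODELS_CONFIG instead of label files on
    # disk; model_version is what gets stored in the image_prediction table.
    print(MODELS_CONFIG[ObjectDetectionModel.price_tag_detection].model_version)

    # Run a detection; "price_tag.jpg" is a hypothetical local file.
    model = ObjectDetectionModelRegistry.get(ObjectDetectionModel.price_tag_detection)
    result = model.detect_from_image(Image.open("price_tag.jpg"), threshold=0.5)

    # to_list() replaces the former to_json() and returns one dict per
    # detection: {"bounding_box": ..., "score": ..., "label": ...}
    print(result.to_list())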