Skip to content

Commit

Permalink
fix: remove old object detection models (#1423)
Browse files Browse the repository at this point in the history
Also use '_' as word separator for all model_name for object detection
By doing this, we don't have anymore a mismatch between the enum
(that cannot have '-' in the name) and the Triton model name.
  • Loading branch information
raphael0202 authored Oct 1, 2024
1 parent 137c9e3 commit 4d22142
Show file tree
Hide file tree
Showing 15 changed files with 42 additions and 73 deletions.
5 changes: 0 additions & 5 deletions models/triton/nutriscore-yolo/labels.txt

This file was deleted.

4 changes: 0 additions & 4 deletions models/triton/nutrition-table/labels.txt

This file was deleted.

File renamed without changes.
9 changes: 6 additions & 3 deletions robotoff/app/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@
InsightType,
JSONType,
NeuralCategoryClassifierModel,
ObjectDetectionModel,
PredictionType,
ProductIdentifier,
ServerType,
Expand Down Expand Up @@ -859,8 +860,8 @@ def on_get(self, req: falcon.Request, resp: falcon.Response):
image_url = req.get_param("image_url", required=True)
models: list[str] = req.get_param_as_list("models", required=True)

available_object_detection_models = (
ObjectDetectionModelRegistry.get_available_models()
available_object_detection_models = list(
ObjectDetectionModel.__members__.keys()
)
available_clf_models = list(ImageClassificationModel.__members__.keys())
available_models = available_object_detection_models + available_clf_models
Expand Down Expand Up @@ -901,7 +902,9 @@ def on_get(self, req: falcon.Request, resp: falcon.Response):

for model_name in models:
if model_name in available_object_detection_models:
model = ObjectDetectionModelRegistry.get(model_name)
model = ObjectDetectionModelRegistry.get(
ObjectDetectionModel[model_name]
)
result = model.detect_from_image(image, output_image=output_image)

if output_image:
Expand Down
6 changes: 3 additions & 3 deletions robotoff/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -541,9 +541,9 @@ def run_object_detection_model(

if model_name == ObjectDetectionModel.universal_logo_detector:
func: Callable = run_logo_object_detection
elif model_name == ObjectDetectionModel.nutrition_table_yolo:
elif model_name == ObjectDetectionModel.nutrition_table:
func = run_nutrition_table_object_detection
elif model_name == ObjectDetectionModel.nutriscore_yolo:
elif model_name == ObjectDetectionModel.nutriscore:
func = run_nutriscore_object_detection
else:
raise ValueError(f"unsupported model: {model_name}")
Expand All @@ -568,7 +568,7 @@ def run_object_detection_model(
JOIN.LEFT_OUTER,
on=(
(ImagePrediction.image_id == ImageModel.id)
& (ImagePrediction.model_name == model_name.value)
& (ImagePrediction.model_name == model_name.name)
),
)
.where(
Expand Down
6 changes: 3 additions & 3 deletions robotoff/insights/extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,23 +77,23 @@ def run_object_detection_model(
"""
if (
existing_image_prediction := ImagePrediction.get_or_none(
image=image_model, model_name=model_name.get_type()
image=image_model, model_name=model_name.name
)
) is not None:
if return_null_if_exist:
return None
return existing_image_prediction

timestamp = datetime.datetime.now(datetime.timezone.utc)
results = ObjectDetectionModelRegistry.get(model_name.value).detect_from_image(
results = ObjectDetectionModelRegistry.get(model_name).detect_from_image(
image, output_image=False, triton_uri=triton_uri, threshold=threshold
)
data = results.to_json()
max_confidence = max((item["score"] for item in data), default=None)
return ImagePrediction.create(
image=image_model,
type="object_detection",
model_name=model_name.get_type(),
model_name=model_name.name,
model_version=OBJECT_DETECTION_MODEL_VERSION[model_name],
data={"objects": data},
timestamp=timestamp,
Expand Down
3 changes: 1 addition & 2 deletions robotoff/insights/importer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1222,8 +1222,7 @@ def get_nutrition_table_predictions(
.where(
ImageModel.barcode == product_id.barcode,
ImageModel.server_type == product_id.server_type.name,
ImagePrediction.model_name
== ObjectDetectionModel.nutrition_table_yolo.get_type(),
ImagePrediction.model_name == ObjectDetectionModel.nutrition_table.name,
ImagePrediction.max_confidence >= min_score,
)
.tuples()
Expand Down
48 changes: 22 additions & 26 deletions robotoff/prediction/object_detection/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,8 @@


OBJECT_DETECTION_MODEL_VERSION = {
ObjectDetectionModel.nutriscore: "tf-nutriscore-1.0",
ObjectDetectionModel.nutriscore_yolo: "yolo-nutriscore-1.0",
ObjectDetectionModel.nutrition_table: "tf-nutrition-table-1.0",
ObjectDetectionModel.nutrition_table_yolo: "yolo-nutrition-table-1.0",
ObjectDetectionModel.nutriscore: "yolo-nutriscore-1.0",
ObjectDetectionModel.nutrition_table: "yolo-nutrition-table-1.0",
ObjectDetectionModel.universal_logo_detector: "tf-universal-logo-detector-1.0",
}

Expand Down Expand Up @@ -383,45 +381,43 @@ def detect_from_image(


class ObjectDetectionModelRegistry:
models: dict[str, RemoteModel] = {}
models: dict[ObjectDetectionModel, RemoteModel] = {}
_loaded = False

@classmethod
def get_available_models(cls) -> list[str]:
cls.load_all()
return list(cls.models.keys())

@classmethod
def load_all(cls):
if cls._loaded:
return
for model in ObjectDetectionModel:
model_name = model.value
file_path = settings.TRITON_MODELS_DIR / model_name
if file_path.is_dir():
logger.info("Model %s found", model_name)
cls.models[model_name] = cls.load(model_name, file_path)
else:
logger.info("Missing model: %s", model_name)
model_dir = settings.TRITON_MODELS_DIR / model.name
if not model_dir.exists():
logger.warning("Model directory %s does not exist", model_dir)
continue
cls.models[model] = cls.load(model, model_dir)
cls._loaded = True

@classmethod
def load(cls, name: str, model_dir: pathlib.Path) -> RemoteModel:
# To keep compatibility with the old models, we temporarily use the
# model name as a heuristic to determine the backend
def load(cls, model: ObjectDetectionModel, model_dir: pathlib.Path) -> RemoteModel:
# To keep compatibility with the old models, we temporarily specify
# here the backend to use for each model (Tensorflow Object Detection or YOLO).
# Tensorflow Object Detection models are going to be phased out in
# favor of YOLO models
backend = "yolo" if "yolo" in name else "tf"
backend = (
"yolo"
if model
in (ObjectDetectionModel.nutriscore, ObjectDetectionModel.nutrition_table)
else "tf"
)
label_names = list(text_file_iter(model_dir / LABEL_NAMES_FILENAME))

if backend == "tf":
label_names.insert(0, "NULL")

model = RemoteModel(name, label_names, backend=backend)
cls.models[name] = model
return model
remote_model = RemoteModel(model.name, label_names, backend=backend)
cls.models[model] = remote_model
return remote_model

@classmethod
def get(cls, name: str) -> RemoteModel:
def get(cls, model: ObjectDetectionModel) -> RemoteModel:
cls.load_all()
return cls.models[name]
return cls.models[model]
20 changes: 0 additions & 20 deletions robotoff/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,28 +11,8 @@

class ObjectDetectionModel(enum.Enum):
nutriscore = "nutriscore"
nutriscore_yolo = "nutriscore-yolo"
universal_logo_detector = "universal-logo-detector"
nutrition_table = "nutrition-table"
nutrition_table_yolo = "nutrition-table-yolo"

def get_type(self) -> str:
"""This helper function is useful as long as we have two model (yolo
and tf) for each type of detection.
Once we've migrated all models to Yolo, we can remove this function.
"""
if self in (
ObjectDetectionModel.nutriscore,
ObjectDetectionModel.nutriscore_yolo,
):
return "nutriscore"
if self in (
ObjectDetectionModel.nutrition_table,
ObjectDetectionModel.nutrition_table_yolo,
):
return "nutrition-table"

return "universal-logo-detector"


class ImageClassificationModel(str, enum.Enum):
Expand Down
4 changes: 2 additions & 2 deletions robotoff/workers/tasks/import_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ def run_nutrition_table_object_detection(
source_image=source_image, server_type=product_id.server_type.name
):
run_object_detection_model(
ObjectDetectionModel.nutrition_table_yolo,
ObjectDetectionModel.nutrition_table,
image,
image_model,
triton_uri=triton_uri,
Expand Down Expand Up @@ -399,7 +399,7 @@ def run_nutriscore_object_detection(
return

image_prediction = run_object_detection_model(
ObjectDetectionModel.nutriscore_yolo,
ObjectDetectionModel.nutriscore,
image,
image_model,
triton_uri=triton_uri,
Expand Down
2 changes: 1 addition & 1 deletion scripts/insert_image_predictions.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@


DATA_PATH = settings.DATASET_DIR / "logos-paperspace.jsonl.gz"
MODEL_NAME = "universal-logo-detector"
MODEL_NAME = "universal_logo_detector"
MODEL_VERSION = "tf-universal-logo-detector-1.0"
TYPE = "object_detection"
SERVER_TYPE = ServerType.off
Expand Down
2 changes: 1 addition & 1 deletion tests/integration/insights/test_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def test_run_object_detection_model(mocker, image_model, model_name, label_names
)
assert isinstance(image_prediction, ImagePrediction)
assert image_prediction.type == "object_detection"
assert image_prediction.model_name == model_name.value
assert image_prediction.model_name == model_name.name
assert image_prediction.data == {
"objects": [
{"bounding_box": (1, 2, 3, 4), "score": 0.8, "label": label_names[1]}
Expand Down
4 changes: 2 additions & 2 deletions tests/integration/insights/test_importer.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ def generate_image_prediction(
return ImagePredictionFactory(
image=image,
data={"objects": objects},
model_name="nutrition-table",
model_version="tf-nutrition-table-1.0",
model_name="nutrition_table",
model_version="yolo-nutrition-table-1.0",
max_confidence=max_confidence,
)

Expand Down
2 changes: 1 addition & 1 deletion tests/integration/models_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ class Meta:
model = ImagePrediction

type = "object_detection"
model_name = "universal-logo-detector"
model_name = "universal_logo_detector"
model_version = "tf-universal-logo-detector-1.0"
data = {
"objects": [
Expand Down

0 comments on commit 4d22142

Please sign in to comment.