Skip to content

Commit

Permalink
feat: Improve nutrition extraction (#1484)
Browse files Browse the repository at this point in the history
* fix(nutrisight): improve post-processing

- fix postprocessing bug for addition of unit
- correct OCR error for serving_size when 'g' was mistaken as '9'

* fix: optimize rerun_import_all_images job
  • Loading branch information
raphael0202 authored Dec 5, 2024
1 parent 4228e2f commit 2430741
Show file tree
Hide file tree
Showing 3 changed files with 113 additions and 13 deletions.
16 changes: 13 additions & 3 deletions robotoff/prediction/nutrition_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,9 @@ def postprocess_aggregated_entities(
return postprocessed_entities


# Detects an OCR artifact in serving-size strings where the unit 'g' was
# read as the digit '9' (e.g. "42.5 9" instead of "42.5 g").
# Group 1 captures the numeric value, with an optional ',' or '.' decimal
# separator. NOTE(review): the pattern is deliberately ASCII-only ([0-9],
# not \d) and is used with .match(), so it anchors at the start of the
# string only — trailing text after the '9' is not rejected; confirm this
# is the intended behavior for multi-token serving sizes.
SERVING_SIZE_MISSING_G = re.compile(
    r"""
    ( [0-9]+ [,.]? [0-9]* )   # numeric value, optional decimal separator
    \s* 9                     # '9' misread by OCR in place of the unit 'g'
    """,
    re.VERBOSE,
)


def postprocess_aggregated_entities_single(entity: JSONType) -> JSONType:
"""Postprocess a single aggregated entity and return an entity with the extracted
information. This is the first step in the postprocessing of aggregated entities.
Expand Down Expand Up @@ -466,6 +469,11 @@ def postprocess_aggregated_entities_single(entity: JSONType) -> JSONType:

if entity_label == "serving_size":
value = words_str
# Sometimes the unit 'g' in the `serving_size is detected as a '9'
# In such cases, we replace the '9' with 'g'
match = SERVING_SIZE_MISSING_G.match(value)
if match:
value = f"{match.group(1)} g"
elif words_str in ("trace", "traces"):
value = "traces"
else:
Expand Down Expand Up @@ -549,13 +557,15 @@ def match_nutrient_value(
for target in (
"proteins",
"sugars",
"added-sugars",
"carbohydrates",
"fat",
"saturated-fat",
"fiber",
"salt",
"trans-fat",
# we use "_" here as separator as '-' is only used in
# Product Opener, the label names are all separated by '_'
"saturated_fat",
"added_sugars",
"trans_fat",
)
)
and value.endswith("9")
Expand Down
53 changes: 43 additions & 10 deletions robotoff/workers/tasks/import_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,10 @@ def rerun_import_all_images(
where_clauses.append(ImageModel.server_type == server_type.name)
query = (
ImageModel.select(
ImageModel.barcode, ImageModel.image_id, ImageModel.server_type
ImageModel.id,
ImageModel.barcode,
ImageModel.image_id,
ImageModel.server_type,
)
.where(*where_clauses)
.order_by(ImageModel.uploaded_at.desc())
Expand All @@ -104,18 +107,16 @@ def rerun_import_all_images(
if return_count:
return query.count()

for barcode, image_id, server_type_str in query:
for image_model_id, barcode, image_id, server_type_str in query:
if not isinstance(barcode, str) and not barcode.isdigit():
raise ValueError("Invalid barcode: %s" % barcode)

product_id = ProductIdentifier(barcode, ServerType[server_type_str])
image_url = generate_image_url(product_id, image_id)
ocr_url = generate_json_ocr_url(product_id, image_id)
enqueue_job(
run_import_image_job,
get_high_queue(product_id),
job_kwargs={"result_ttl": 0},
run_import_image(
product_id=product_id,
image_model_id=image_model_id,
image_url=image_url,
ocr_url=ocr_url,
flags=flags,
Expand Down Expand Up @@ -144,16 +145,16 @@ def run_import_image_job(
What tasks are performed can be controlled using the `flags` parameter. By
default, all tasks are performed. A new rq job is enqueued for each task.
Before running the tasks, the image is downloaded and stored in the Robotoff
DB.
:param product_id: the product identifier
:param image_url: the URL of the image to import
:param ocr_url: the URL of the OCR JSON file
:param flags: the list of flags to run, defaults to None (all)
"""
logger.info("Running `import_image` for %s, image %s", product_id, image_url)

if flags is None:
flags = [flag for flag in ImportImageFlag]

source_image = get_source_from_url(image_url)
product = get_product_store(product_id.server_type)[product_id]
if product is None and settings.ENABLE_MONGODB_ACCESS:
Expand Down Expand Up @@ -185,13 +186,45 @@ def run_import_image_job(
ImageModel.bulk_update([image_model], fields=["deleted"])
return

run_import_image(
product_id=product_id,
image_model_id=image_model.id,
image_url=image_url,
ocr_url=ocr_url,
flags=flags,
)


def run_import_image(
product_id: ProductIdentifier,
image_model_id: int,
image_url: str,
ocr_url: str,
flags: list[ImportImageFlag] | None = None,
) -> None:
"""Launch all extraction tasks on an image.
We assume that the image exists in the Robotoff DB.
What tasks are performed can be controlled using the `flags` parameter. By
default, all tasks are performed. A new rq job is enqueued for each task.
:param product_id: the product identifier
:param image_model_id: the DB ID of the image
:param image_url: the URL of the image to import
:param ocr_url: the URL of the OCR JSON file
:param flags: the list of flags to run, defaults to None (all)
"""
if flags is None:
flags = [flag for flag in ImportImageFlag]

if ImportImageFlag.add_image_fingerprint in flags:
# Compute image fingerprint, this job is low priority
enqueue_job(
add_image_fingerprint_job,
low_queue,
job_kwargs={"result_ttl": 0},
image_model_id=image_model.id,
image_model_id=image_model_id,
)

if product_id.server_type.is_food():
Expand Down
57 changes: 57 additions & 0 deletions tests/unit/prediction/test_nutrition_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
aggregate_entities,
match_nutrient_value,
postprocess_aggregated_entities,
postprocess_aggregated_entities_single,
)


Expand Down Expand Up @@ -392,8 +393,64 @@ def test_aggregate_entities_multiple_entities(self):
("25.9", "iron_100g", ("25.9", None, True)),
("O g", "salt_100g", ("0", "g", True)),
("O", "salt_100g", ("0", None, True)),
("0,19", "saturated_fat_100g", ("0.1", "g", True)),
],
)
def test_match_nutrient_value(words_str: str, entity_label: str, expected_output):

assert match_nutrient_value(words_str, entity_label) == expected_output


@pytest.mark.parametrize(
    "aggregated_entity,expected_output",
    [
        # Nutrient value whose trailing '9' is an OCR error for the unit 'g':
        # "0,19" for saturated fat should yield value "0.1" with unit "g".
        (
            {
                "entity": "SATURATED_FAT_100G",
                "words": ["0,19\n"],
                "score": 0.9985358715057373,
                "start": 89,
                "end": 90,
                "char_start": 454,
                "char_end": 459,
            },
            {
                "entity": "saturated-fat_100g",
                "text": "0,19",
                "value": "0.1",
                "unit": "g",
                "valid": True,
                "score": 0.9985358715057373,
                "start": 89,
                "end": 90,
                "char_start": 454,
                "char_end": 459,
            },
        ),
        # Serving size where the unit 'g' was OCR-ed as the digit '9':
        # "42.5 9" should be corrected to "42.5 g".
        (
            {
                "entity": "SERVING_SIZE",
                "words": ["42.5 9"],
                "score": 0.9985358715057373,
                "start": 90,
                "end": 92,
                "char_start": 454,
                "char_end": 460,
            },
            {
                "entity": "serving_size",
                "text": "42.5 9",
                "value": "42.5 g",
                "unit": None,
                "valid": True,
                "score": 0.9985358715057373,
                "start": 90,
                "end": 92,
                "char_start": 454,
                "char_end": 460,
            },
        ),
    ],
)
def test_postprocess_aggregated_entities_single(aggregated_entity, expected_output):
    """Check that a single aggregated entity is postprocessed into the
    expected extracted-information dict, including OCR '9' -> 'g' fixes."""
    result = postprocess_aggregated_entities_single(aggregated_entity)
    assert result == expected_output

0 comments on commit 2430741

Please sign in to comment.