Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: fix nutrition extraction insight generation #1438

Merged
merged 3 commits into from
Oct 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 10 additions & 26 deletions robotoff/insights/importer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1540,33 +1540,17 @@ def generate_candidates(
predictions: list[Prediction],
product_id: ProductIdentifier,
) -> Iterator[ProductInsight]:
# Don't generate candidates if the product already has nutrients
if (
product is not None
and product.nutriments
# If we delete all nutrient values, these computed values are still
# present. We therefore ignore these keys.
and bool(
set(
key
for key in product.nutriments.keys()
if not (
key.startswith("carbon-footprint-from-known-ingredients")
or key.startswith(
"fruits-vegetables-legumes-estimate-from-ingredients"
)
or key.startswith(
"fruits-vegetables-nuts-estimate-from-ingredients"
)
or key.startswith("nova-group")
or key.startswith("nutrition-score-fr")
)
)
)
):
return

for prediction in predictions:
if product is not None and product.nutriments:
current_keys = set(key for key in product.nutriments.keys())
prediction_keys = set(prediction.data["nutrients"].keys())

# If the prediction brings a nutrient value that is missing in
# the product, we generate an insight, otherwise we
# skip it
if not len(prediction_keys - current_keys):
continue

yield ProductInsight(**prediction.to_dict())

@classmethod
Expand Down
11 changes: 10 additions & 1 deletion robotoff/prediction/nutrition_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,8 +439,17 @@ def postprocess_aggregated_entities(
logger.warning("Could not extract nutrient value from %s", words_str)
is_valid = False

if entity["entity"] == "SERVING_SIZE":
entity_label = "serving_size"
else:
# Reformat the nutrient name so that it matches Open Food Facts format
# Ex: "ENERGY_KCAL_100G" -> "energy-kcal_100g"
entity_label = entity["entity"].lower()
entity_base, entity_per = entity_label.rsplit("_", 1)
entity_base = entity_base.replace("_", "-")
entity_label = f"{entity_base}_{entity_per}"
postprocessed_entity = {
"entity": entity["entity"].lower(),
"entity": entity_label,
"text": words_str,
"value": value,
"unit": unit,
Expand Down
2 changes: 1 addition & 1 deletion tests/ml/test_nutrition_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def test_predict(
output_filename = output_url.split("/")[-1]

with (output_dir / output_filename).open("wt") as f:
json.dump(dataclasses.asdict(result), f)
json.dump(dataclasses.asdict(result), f, indent=4)
elif is_output_available:
r = get_asset_from_url(output_url)
assert r is not None
Expand Down
145 changes: 145 additions & 0 deletions tests/unit/insights/test_importer.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
ExpirationDateImporter,
InsightImporter,
LabelInsightImporter,
NutrientExtractionImporter,
NutritionImageImporter,
PackagerCodeInsightImporter,
PackagingImporter,
Expand Down Expand Up @@ -1531,6 +1532,150 @@ def generate_candidates_for_image(
)


class TestNutrientExtractionImporter:
    """Tests for ``NutrientExtractionImporter.generate_candidates``.

    Three scenarios are exercised:

    * the product has an empty ``nutriments`` mapping,
    * the prediction only contains a nutrient key already present on the
      product,
    * the prediction contains at least one nutrient key absent from the
      product.
    """

    @staticmethod
    def _product_with_energy_and_fat():
        """Build a product whose nutriments hold energy-kj and fat entries."""
        return Product(
            {
                "code": DEFAULT_BARCODE,
                "nutriments": {
                    "energy-kj_100g": "100",
                    "energy-kj_unit": "kJ",
                    "fat_100g": "10",
                    "fat_unit": "g",
                },
            }
        )

    @staticmethod
    def _run_importer(product, payload):
        """Wrap ``payload`` in a single prediction and collect the candidates."""
        extraction = Prediction(
            type=PredictionType.nutrient_extraction,
            data=payload,
            barcode=DEFAULT_BARCODE,
            source_image=DEFAULT_SOURCE_IMAGE,
            predictor="nutrition_extractor",
            predictor_version="nutrition_extractor-1.0",
            automatic_processing=False,
        )
        generated = NutrientExtractionImporter.generate_candidates(
            product, [extraction], DEFAULT_PRODUCT_ID
        )
        return list(generated)

    def test_generate_candidates_no_nutrient(self):
        # Empty nutriments on the product: one candidate is expected and it
        # must mirror the prediction fields.
        empty_product = Product({"code": DEFAULT_BARCODE, "nutriments": {}})
        data = {
            "nutrients": {
                "energy-kj_100g": {
                    "entity": "energy-kj_100g",
                    "value": "100",
                    "unit": "kj",
                    "text": "100 kj",
                    "start": 0,
                    "end": 1,
                    "char_start": 0,
                    "char_end": 6,
                }
            }
        }

        results = self._run_importer(empty_product, data)

        assert len(results) == 1
        insight = results[0]
        assert insight.type == "nutrient_extraction"
        assert insight.barcode == DEFAULT_BARCODE
        assert insight.type == InsightType.nutrient_extraction.name
        assert insight.value_tag is None
        assert insight.data == data
        assert insight.source_image == DEFAULT_SOURCE_IMAGE
        assert insight.automatic_processing is False
        assert insight.predictor == "nutrition_extractor"
        assert insight.predictor_version == "nutrition_extractor-1.0"

    def test_generate_candidates_no_new_nutrient(self):
        # The only predicted key is already among the product nutriments:
        # no candidate is expected.
        data = {
            "nutrients": {
                "energy-kj_100g": {
                    "entity": "energy-kj_100g",
                    "value": "100",
                    "unit": "kj",
                    "text": "100 kj",
                    "start": 0,
                    "end": 2,
                    "char_start": 0,
                    "char_end": 6,
                }
            }
        }

        results = self._run_importer(self._product_with_energy_and_fat(), data)

        assert len(results) == 0

    def test_generate_candidates_new_nutrient(self):
        # "saturated-fat_100g" is not among the product nutriments: one
        # candidate is expected.
        data = {
            "nutrients": {
                "energy-kj_100g": {
                    "entity": "energy-kj_100g",
                    "value": "100",
                    "unit": "kj",
                    "text": "100 kj",
                    "start": 0,
                    "end": 2,
                    "char_start": 0,
                    "char_end": 6,
                },
                "saturated-fat_100g": {
                    "entity": "saturated-fat_100g",
                    "value": "5",
                    "unit": "g",
                    "text": "5 g",
                    "start": 3,
                    "end": 4,
                    "char_start": 7,
                    "char_end": 10,
                },
            }
        }

        results = self._run_importer(self._product_with_energy_and_fat(), data)

        assert len(results) == 1


class TestImportInsightsForProducts:
def test_import_insights_no_element(self, mocker):
get_product_predictions_mock = mocker.patch(
Expand Down
14 changes: 7 additions & 7 deletions tests/unit/prediction/test_nutrition_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def test_postprocess_aggregated_entities_single_entity(self):
]
expected_output = [
{
"entity": "energy_kcal_100g",
"entity": "energy-kcal_100g",
"text": "525 kcal",
"value": "525",
"unit": "kcal",
Expand Down Expand Up @@ -56,7 +56,7 @@ def test_postprocess_aggregated_entities_multiple_entities(self):
]
expected_output = [
{
"entity": "energy_kcal_100g",
"entity": "energy-kcal_100g",
"text": "525 kcal",
"value": "525",
"unit": "kcal",
Expand All @@ -69,7 +69,7 @@ def test_postprocess_aggregated_entities_multiple_entities(self):
"invalid_reason": "multiple_entities",
},
{
"entity": "energy_kcal_100g",
"entity": "energy-kcal_100g",
"text": "126 kcal",
"value": "126",
"unit": "kcal",
Expand All @@ -87,7 +87,7 @@ def test_postprocess_aggregated_entities_multiple_entities(self):
def test_postprocess_aggregated_entities_no_value(self):
aggregated_entities = [
{
"entity": "FAT_SERVING",
"entity": "SATURATED_FAT_SERVING",
"words": ["fat"],
"score": 0.85,
"start": 0,
Expand All @@ -98,7 +98,7 @@ def test_postprocess_aggregated_entities_no_value(self):
]
expected_output = [
{
"entity": "fat_serving",
"entity": "saturated-fat_serving",
"text": "fat",
"value": None,
"unit": None,
Expand Down Expand Up @@ -219,7 +219,7 @@ def test_postprocess_aggregated_entities_merged_kcal_kj(self):
]
expected_output = [
{
"entity": "energy_kj_100g",
"entity": "energy-kj_100g",
"text": "525",
"value": "525",
"unit": "kj",
Expand All @@ -231,7 +231,7 @@ def test_postprocess_aggregated_entities_merged_kcal_kj(self):
"valid": True,
},
{
"entity": "energy_kcal_100g",
"entity": "energy-kcal_100g",
"text": "126 kcal",
"value": "126",
"unit": "kcal",
Expand Down
Loading