diff --git a/robotoff/prediction/ocr/product_weight.py b/robotoff/prediction/ocr/product_weight.py
index 65a5d37fc9..3fd0cef000 100644
--- a/robotoff/prediction/ocr/product_weight.py
+++ b/robotoff/prediction/ocr/product_weight.py
@@ -94,7 +94,23 @@ def is_valid_weight(weight_value: str) -> bool:
     return True
 
 
-def is_extreme_weight(normalized_value: float, unit: str) -> bool:
+def is_extreme_weight(
+    normalized_value: float, unit: str, count: int | None = None
+) -> bool:
+    """Return True if the weight is extreme, i.e. is likely wrongly detected.
+
+    If considered extreme, a prediction won't be generated.
+
+    :param normalized_value: the normalized weight value
+    :param unit: the normalized weight unit
+    :param count: the number of items in the pack, if any
+    :return: True if the weight is extreme, False otherwise
+    """
+    if count is not None and int(count) > 20:
+        # More than 20 items in a pack is quite unlikely for
+        # a consumer product
+        return True
+
     if unit == "g":
         # weights above 10 kg
         return normalized_value >= 10000 or normalized_value <= 10
@@ -200,7 +216,7 @@ def process_multi_packaging(match) -> Optional[dict]:
     normalized_value, normalized_unit = normalize_weight(value, unit)
 
     # Check that the weight is not extreme
-    if is_extreme_weight(normalized_value, normalized_unit):
+    if is_extreme_weight(normalized_value, normalized_unit, count):
         return None
 
     text = f"{count} x {value} {unit}"
diff --git a/tests/unit/prediction/ocr/test_product_weight.py b/tests/unit/prediction/ocr/test_product_weight.py
index 96df0b6615..40e1b9a6ac 100644
--- a/tests/unit/prediction/ocr/test_product_weight.py
+++ b/tests/unit/prediction/ocr/test_product_weight.py
@@ -92,22 +92,23 @@ def test_is_valid_weight(value: str, is_valid: bool):
 
 
 @pytest.mark.parametrize(
-    "value,unit,expected",
+    "value,unit,count,expected",
     [
-        (10000, "g", True),
-        (10000, "ml", True),
-        (9999, "ml", False),
-        (9999, "g", False),
-        (100, "g", False),
-        (100, "ml", False),
-        (10, "ml", True),
-        (3, "ml", True),
-        (10, "g", True),
-        (2, "g", True),
+        (10000, "g", None, True),
+        (10000, "ml", None, True),
+        (9999, "ml", None, False),
+        (9999, "g", None, False),
+        (100, "g", None, False),
+        (100, "ml", None, False),
+        (10, "ml", None, True),
+        (3, "ml", None, True),
+        (10, "g", None, True),
+        (2, "g", None, True),
+        (200, "g", 21, True),
     ],
 )
-def test_is_extreme_weight(value: float, unit: str, expected: bool):
-    assert is_extreme_weight(value, unit) is expected
+def test_is_extreme_weight(value: float, unit: str, count: int | None, expected: bool):
+    assert is_extreme_weight(value, unit, count) is expected
 
 
 @pytest.mark.parametrize(