Skip to content

Commit

Permalink
fix: normalize barcode in all API routes
Browse files Browse the repository at this point in the history
  • Loading branch information
raphael0202 committed Oct 8, 2024
1 parent 78cd532 commit 7ca87de
Show file tree
Hide file tree
Showing 4 changed files with 91 additions and 87 deletions.
32 changes: 21 additions & 11 deletions robotoff/app/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import requests
from falcon.media.validators import jsonschema
from openfoodfacts import OCRResult
from openfoodfacts.barcode import normalize_barcode
from openfoodfacts.images import extract_barcode_from_url
from openfoodfacts.ocr import OCRParsingException, OCRResultGenerationException
from openfoodfacts.types import COUNTRY_CODE_TO_NAME, Country
Expand Down Expand Up @@ -148,6 +149,13 @@ def _get_skip_voted_on(
return SkipVotedOn(SkipVotedType.USERNAME, username)


def normalize_req_barcode(barcode: str | None) -> str | None:
"""Normalize the `barcode` parameter provided in the Falcon request."""
if not barcode:
return None
return normalize_barcode(barcode)


###########
# IMPORTANT: remember to update documentation at doc/references/api.yml if you
# change API
Expand All @@ -156,6 +164,7 @@ def _get_skip_voted_on(

class ProductInsightResource:
def on_get(self, req: falcon.Request, resp: falcon.Response, barcode: str):
barcode = normalize_barcode(barcode)
response: JSONType = {}
server_type = get_server_type_from_req(req)
insights = [
Expand Down Expand Up @@ -192,7 +201,7 @@ def on_get(self, req: falcon.Request, resp: falcon.Response):
keep_types: Optional[list[str]] = req.get_param_as_list(
"insight_types", required=False
)
barcode: Optional[str] = req.get_param("barcode")
barcode: Optional[str] = normalize_req_barcode(req.get_param("barcode"))
annotated: Optional[bool] = req.get_param_as_bool("annotated")
annotation: Optional[int] = req.get_param_as_int("annotation")
value_tag: str = req.get_param("value_tag")
Expand Down Expand Up @@ -399,7 +408,7 @@ def on_post(self, req: falcon.Request, resp: falcon.Response):

class NutritionPredictorResource:
def on_get(self, req: falcon.Request, resp: falcon.Response):
barcode = req.get_param("barcode", required=True)
barcode = normalize_barcode(req.get_param("barcode", required=True))
# we transform image IDs to int to be sure to have "raw" image IDs as
# input
image_ids = req.get_param_as_list("image_ids", required=False, transform=int)
Expand Down Expand Up @@ -594,7 +603,7 @@ def on_post(self, req: falcon.Request, resp: falcon.Response):

if "barcode" in media:
# Fetch product from DB
barcode: str = media["barcode"]
barcode: str = normalize_barcode(media["barcode"])
product = get_product(ProductIdentifier(barcode, server_type)) or {}
if not product:
raise falcon.HTTPNotFound(description=f"product {barcode} not found")
Expand Down Expand Up @@ -695,7 +704,7 @@ class ProductLanguagePredictorResource:
def on_get(self, req: falcon.Request, resp: falcon.Response):
"""Predict the languages displayed on the product images, using
`image_lang` predictions as input."""
barcode = req.get_param("barcode", required=True)
barcode = normalize_barcode(req.get_param("barcode", required=True))
server_type = get_server_type_from_req(req)
counts: dict[str, int] = defaultdict(int)
image_ids: list[int] = []
Expand Down Expand Up @@ -808,7 +817,7 @@ def on_get(self, req: falcon.Request, resp: falcon.Response):
model_name: Optional[str] = req.get_param("model_name")
type_: Optional[str] = req.get_param("type")
model_version: Optional[str] = req.get_param("model_version")
barcode: Optional[str] = req.get_param("barcode")
barcode: Optional[str] = normalize_req_barcode(req.get_param("barcode"))
min_confidence: Optional[float] = req.get_param_as_float("min_confidence")
server_type = get_server_type_from_req(req)

Expand Down Expand Up @@ -959,7 +968,7 @@ def on_get(self, req: falcon.Request, resp: falcon.Response):
"count", min_value=1, max_value=2000, default=25
)
type_: Optional[str] = req.get_param("type")
barcode: Optional[str] = req.get_param("barcode")
barcode: Optional[str] = normalize_req_barcode(req.get_param("barcode"))
value: Optional[str] = req.get_param("value")
taxonomy_value: Optional[str] = req.get_param("taxonomy_value")
min_confidence: Optional[float] = req.get_param_as_float("min_confidence")
Expand Down Expand Up @@ -1343,6 +1352,7 @@ class ProductQuestionsResource:
"""

def on_get(self, req: falcon.Request, resp: falcon.Response, barcode: str):
barcode = normalize_barcode(barcode)
response: JSONType = {}
count: int = req.get_param_as_int("count", min_value=1, default=25)
lang: str = req.get_param("lang", default="en")
Expand Down Expand Up @@ -1526,7 +1536,7 @@ def on_get(self, req: falcon.Request, resp: falcon.Response):
keep_types: list[str] = req.get_param_as_list(
"insight_types", required=False, default=[]
)[:10]
barcode = req.get_param("barcode")
barcode = normalize_req_barcode(req.get_param("barcode"))
annotated = req.get_param_as_bool("annotated", blank_as_true=False)
value_tag = req.get_param("value_tag")
count = req.get_param_as_int("count", min_value=0, max_value=10_000)
Expand Down Expand Up @@ -1586,7 +1596,7 @@ def on_get(self, req: falcon.Request, resp: falcon.Response):
with_predictions: Optional[bool] = req.get_param_as_bool(
"with_predictions", default=False
)
barcode: Optional[str] = req.get_param("barcode")
barcode: Optional[str] = normalize_req_barcode(req.get_param("barcode"))
server_type = get_server_type_from_req(req)

get_images_ = functools.partial(
Expand Down Expand Up @@ -1614,7 +1624,7 @@ class PredictionCollection:
def on_get(self, req: falcon.Request, resp: falcon.Response):
page: int = req.get_param_as_int("page", min_value=1, default=1)
count: int = req.get_param_as_int("count", min_value=1, default=25)
barcode: Optional[str] = req.get_param("barcode")
barcode: Optional[str] = normalize_req_barcode(req.get_param("barcode"))
value_tag: str = req.get_param("value_tag")
keep_types: Optional[list[str]] = req.get_param_as_list("types", required=False)
server_type = get_server_type_from_req(req)
Expand Down Expand Up @@ -1702,7 +1712,7 @@ def on_get(self, req: falcon.Request, resp: falcon.Response):
count: int = req.get_param_as_int("count", min_value=1, default=25)
page: int = req.get_param_as_int("page", min_value=1, default=1)
with_logo: Optional[bool] = req.get_param_as_bool("with_logo", default=False)
barcode: Optional[str] = req.get_param("barcode")
barcode: Optional[str] = normalize_req_barcode(req.get_param("barcode"))
type: Optional[str] = req.get_param("type")
server_type = get_server_type_from_req(req)

Expand Down Expand Up @@ -1736,7 +1746,7 @@ def on_get(self, req: falcon.Request, resp: falcon.Response):
class LogoAnnotationCollection:
def on_get(self, req: falcon.Request, resp: falcon.Response):
response: JSONType = {}
barcode: Optional[str] = req.get_param("barcode")
barcode: Optional[str] = normalize_req_barcode(req.get_param("barcode"))
server_type = get_server_type_from_req(req)
keep_types: Optional[list[str]] = req.get_param_as_list("types", required=False)
value_tag: str = req.get_param("value_tag")
Expand Down
2 changes: 1 addition & 1 deletion robotoff/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -1027,11 +1027,11 @@ def launch_normalize_barcode_job(
launch_insight: bool = True,
launch_image: bool = True,
) -> None:
from openfoodfacts.barcode import normalize_barcode
from openfoodfacts.images import generate_image_path
from peewee import fn

from robotoff.models import ImageModel, Prediction, ProductInsight, db
from robotoff.off import normalize_barcode
from robotoff.utils import get_logger

logger = get_logger()
Expand Down
19 changes: 0 additions & 19 deletions robotoff/off.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,25 +148,6 @@ def is_valid_image(product_id: ProductIdentifier, image_id: str) -> bool:
return image_id in images


def normalize_barcode(barcode: str) -> str:
"""Normalize the barcode.
First, we remove leading zeros, then we pad the barcode with zeros to
reach 8 digits.
If the barcode is longer than 8 digits, we pad it to 13 digits.
:param barcode: the barcode to normalize
:return: the normalized barcode
"""
barcode = barcode.lstrip("0").zfill(8)

if len(barcode) > 8:
barcode = barcode.zfill(13)

return barcode


def off_credentials() -> dict[str, str]:
return {"user_id": settings._off_user, "password": settings._off_password}

Expand Down
Loading

0 comments on commit 7ca87de

Please sign in to comment.