From 4adeb6b7e03daa29f5e611c3ae8704d8e11f9320 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rapha=C3=ABl=20Bournhonesque?= Date: Tue, 8 Oct 2024 13:41:12 +0200 Subject: [PATCH] fix: add script to normalize barcodes in DB --- robotoff/cli/main.py | 119 +++++++++++++++++++++++++++++++++++++++++++ robotoff/off.py | 19 +++++++ 2 files changed, 138 insertions(+) diff --git a/robotoff/cli/main.py b/robotoff/cli/main.py index 24db4cba4a..98d9c51847 100644 --- a/robotoff/cli/main.py +++ b/robotoff/cli/main.py @@ -1020,5 +1020,124 @@ def launch_spellcheck_batch_job( ) +@app.command() +def launch_normalize_barcode_job(): + from openfoodfacts.images import generate_image_path + from peewee import fn + + from robotoff.models import ImageModel, Prediction, ProductInsight, db + from robotoff.off import normalize_barcode + from robotoff.utils import get_logger + + logger = get_logger() + logger.info("Starting barcode normalization job") + + with db.connection_context(): + updated = 0 + min_id = 0 + max_id = Prediction.select(fn.MAX(Prediction.id)).scalar() + with db.atomic() as tsx: + while min_id < max_id: + prediction = None + for prediction in ( + Prediction.select() + .where(Prediction.id >= min_id) + .order_by(Prediction.id.asc()) + .limit(10_000) + ): + barcode = normalize_barcode(prediction.barcode) + source_image = ( + generate_image_path( + prediction.barcode, Path(prediction.source_image).stem + ) + if prediction.source_image + else None + ) + is_updated = (barcode != prediction.barcode) or ( + source_image != prediction.source_image + ) + if is_updated: + prediction.barcode = barcode + prediction.source_image = source_image + prediction.save() + updated += 1 + + tsx.commit() + logger.info("Current ID: %s, Updated %d predictions", min_id, updated) + if prediction is not None: + min_id = prediction.id + else: + break + + logger.info("Updated %d predictions", updated) + + updated = 0 + min_id = ProductInsight.select(fn.MIN(ProductInsight.timestamp)).scalar() + max_id = ProductInsight.select(fn.MAX(ProductInsight.timestamp)).scalar() + with db.atomic() as tsx: + while min_id < max_id: + insight = None + for insight in ( + ProductInsight.select() + .where(ProductInsight.timestamp >= min_id) + .order_by(ProductInsight.timestamp.asc()) + .limit(10_000) + ): + barcode = normalize_barcode(insight.barcode) + source_image = generate_image_path( + insight.barcode, Path(insight.source_image).stem + ) + is_updated = (barcode != insight.barcode) or ( + source_image != insight.source_image + ) + if is_updated: + insight.barcode = barcode + insight.source_image = source_image + insight.save() + updated += 1 + + tsx.commit() + logger.info("Current ID: %s, Updated %d insights", min_id, updated) + if insight is not None: + min_id = insight.timestamp + else: + break + + logger.info("Updated %d insights", updated) + + updated = 0 + min_id = ImageModel.select(fn.MIN(ImageModel.id)).scalar() + max_id = ImageModel.select(fn.MAX(ImageModel.id)).scalar() + with db.atomic() as tsx: + while min_id < max_id: + image = None + for image in ( + ImageModel.select() + .where(ImageModel.id >= min_id) + .order_by(ImageModel.id.asc()) + .limit(10_000) + ): + barcode = normalize_barcode(image.barcode) + source_image = generate_image_path( + image.barcode, Path(image.source_image).stem + ) + is_updated = (barcode != image.barcode) or ( + source_image != image.source_image + ) + if is_updated: + image.barcode = barcode + image.source_image = source_image + image.save() + updated += 1 + + tsx.commit() + logger.info("Current ID: %s, Updated %d images", min_id, updated) + if image is not None: + min_id = image.id + else: + break + logger.info("Updated %d images", updated) + + def main() -> None: app() diff --git a/robotoff/off.py b/robotoff/off.py index efbb47b89f..40ce9f6f7d 100644 --- a/robotoff/off.py +++ b/robotoff/off.py @@ -148,6 +148,25 @@ def is_valid_image(product_id: ProductIdentifier, image_id: str) -> bool: return image_id in images +def normalize_barcode(barcode: str) -> str: + """Normalize the barcode. + + First, we remove leading zeros, then we pad the barcode with zeros to + reach 8 digits. + + If the barcode is longer than 8 digits, we pad it to 13 digits. + + :param barcode: the barcode to normalize + :return: the normalized barcode + """ + barcode = barcode.lstrip("0").zfill(8) + + if len(barcode) > 8: + barcode = barcode.zfill(13) + + return barcode + + def off_credentials() -> dict[str, str]: return {"user_id": settings._off_user, "password": settings._off_password}