Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Upgrade nutrition extractor model #1511

Merged
merged 2 commits into from
Dec 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docker/ml-gpu.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ services:
# https://github.com/triton-inference-server/server/blob/main/docs/user_guide/model_management.md
# for more information
entrypoint: "/opt/nvidia/nvidia_entrypoint.sh tritonserver --model-repository=/models --model-control-mode=explicit --load-model=*"
mem_limit: 20g
mem_limit: 30g
runtime: nvidia
deploy:
resources:
Expand Down
2 changes: 1 addition & 1 deletion docker/ml.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ services:
# https://github.com/triton-inference-server/server/blob/main/docs/user_guide/model_management.md
# for more information
entrypoint: "tritonserver --model-repository=/models --model-control-mode=explicit --load-model=*"
mem_limit: 20g
mem_limit: 30g

fasttext:
restart: $RESTART_POLICY
Expand Down
7 changes: 6 additions & 1 deletion robotoff/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -670,6 +670,9 @@ def run_nutrition_extraction(
None,
help="URI of the Triton Inference Server to use. If not provided, the default value from settings is used.",
),
model_version: Optional[str] = typer.Option(
None, help="Version of the model to use, defaults to the latest"
),
) -> None:
"""Run nutrition extraction on a product image.

Expand All @@ -693,7 +696,9 @@ def run_nutrition_extraction(

image = cast(Image.Image, get_image_from_url(image_url))
ocr_result = cast(OCRResult, OCRResult.from_url(image_url.replace(".jpg", ".json")))
prediction = predict(image, ocr_result, triton_uri=triton_uri)
prediction = predict(
image, ocr_result, model_version=model_version, triton_uri=triton_uri
)
if prediction is not None:
pprint(prediction)
else:
Expand Down
16 changes: 9 additions & 7 deletions robotoff/prediction/nutrition_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ class NutritionExtractionPrediction:
def predict(
image: Image.Image,
ocr_result: OCRResult,
model_version: str = "1",
model_version: str | None = None,
triton_uri: str | None = None,
) -> NutritionExtractionPrediction | None:
"""Predict the nutrient values from an image and an OCR result.
Expand All @@ -77,7 +77,7 @@ def predict(

:param image: the *original* image (not resized)
:param ocr_result: the OCR result
:param model_version: the version of the model to use, defaults to "1"
:param model_version: the version of the model to use, defaults to None (latest)
:param triton_uri: the URI of the Triton Inference Server, if not provided, the
default value from settings is used
:return: a `NutritionExtractionPrediction` object
Expand Down Expand Up @@ -619,7 +619,7 @@ def send_infer_request(
pixel_values: np.ndarray,
model_name: str,
triton_stub: GRPCInferenceServiceStub,
model_version: str = "1",
model_version: str | None = None,
) -> np.ndarray:
"""Send a NER infer request to the Triton inference server.

Expand All @@ -634,7 +634,7 @@ def send_infer_request(
:param pixel_values: pixel values of the image, generated using the
transformers tokenizer.
:param model_name: the name of the model to use
:param model_version: version of the model model to use, defaults to "1"
:param model_version: version of the model to use, defaults to None (latest).
:return: the predicted logits
"""
request = build_triton_request(
Expand All @@ -660,7 +660,7 @@ def build_triton_request(
bbox: np.ndarray,
pixel_values: np.ndarray,
model_name: str,
model_version: str = "1",
model_version: str | None = None,
):
"""Build a Triton ModelInferRequest gRPC request for LayoutLMv3 models.

Expand All @@ -672,12 +672,14 @@ def build_triton_request(
:param pixel_values: pixel values of the image, generated using the
transformers tokenizer.
:param model_name: the name of the model to use.
:param model_version: version of the model model to use, defaults to "1".
:param model_version: version of the model to use, defaults to None (latest).
:return: the gRPC ModelInferRequest
"""
request = service_pb2.ModelInferRequest()
request.model_name = model_name
request.model_version = model_version

if model_version:
request.model_version = model_version

add_triton_infer_input_tensor(request, "input_ids", input_ids, "INT64")
add_triton_infer_input_tensor(request, "attention_mask", attention_mask, "INT64")
Expand Down
Loading