Skip to content

Commit

Permalink
feat: Upgrade nutrition extractor model (#1511)
Browse files Browse the repository at this point in the history
* feat: use latest model by default for nutrition extractor

* chore: increase allowed memory for Triton to 30GB
  • Loading branch information
raphael0202 authored Dec 26, 2024
1 parent 6e4c127 commit 4f13fbb
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 10 deletions.
2 changes: 1 addition & 1 deletion docker/ml-gpu.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ services:
# https://github.com/triton-inference-server/server/blob/main/docs/user_guide/model_management.md
# for more information
entrypoint: "/opt/nvidia/nvidia_entrypoint.sh tritonserver --model-repository=/models --model-control-mode=explicit --load-model=*"
mem_limit: 20g
mem_limit: 30g
runtime: nvidia
deploy:
resources:
Expand Down
2 changes: 1 addition & 1 deletion docker/ml.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ services:
# https://github.com/triton-inference-server/server/blob/main/docs/user_guide/model_management.md
# for more information
entrypoint: "tritonserver --model-repository=/models --model-control-mode=explicit --load-model=*"
mem_limit: 20g
mem_limit: 30g

fasttext:
restart: $RESTART_POLICY
Expand Down
7 changes: 6 additions & 1 deletion robotoff/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -670,6 +670,9 @@ def run_nutrition_extraction(
None,
help="URI of the Triton Inference Server to use. If not provided, the default value from settings is used.",
),
model_version: Optional[str] = typer.Option(
None, help="Version of the model to use, defaults to the latest"
),
) -> None:
"""Run nutrition extraction on a product image.
Expand All @@ -693,7 +696,9 @@ def run_nutrition_extraction(

image = cast(Image.Image, get_image_from_url(image_url))
ocr_result = cast(OCRResult, OCRResult.from_url(image_url.replace(".jpg", ".json")))
prediction = predict(image, ocr_result, triton_uri=triton_uri)
prediction = predict(
image, ocr_result, model_version=model_version, triton_uri=triton_uri
)
if prediction is not None:
pprint(prediction)
else:
Expand Down
16 changes: 9 additions & 7 deletions robotoff/prediction/nutrition_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ class NutritionExtractionPrediction:
def predict(
image: Image.Image,
ocr_result: OCRResult,
model_version: str = "1",
model_version: str | None = None,
triton_uri: str | None = None,
) -> NutritionExtractionPrediction | None:
"""Predict the nutrient values from an image and an OCR result.
Expand All @@ -77,7 +77,7 @@ def predict(
:param image: the *original* image (not resized)
:param ocr_result: the OCR result
:param model_version: the version of the model to use, defaults to "1"
:param model_version: the version of the model to use, defaults to None (latest)
:param triton_uri: the URI of the Triton Inference Server, if not provided, the
default value from settings is used
:return: a `NutritionExtractionPrediction` object
Expand Down Expand Up @@ -619,7 +619,7 @@ def send_infer_request(
pixel_values: np.ndarray,
model_name: str,
triton_stub: GRPCInferenceServiceStub,
model_version: str = "1",
model_version: str | None = None,
) -> np.ndarray:
"""Send a NER infer request to the Triton inference server.
Expand All @@ -634,7 +634,7 @@ def send_infer_request(
:param pixel_values: pixel values of the image, generated using the
transformers tokenizer.
:param model_name: the name of the model to use
:param model_version: version of the model model to use, defaults to "1"
:param model_version: version of the model model to use, defaults to None (latest).
:return: the predicted logits
"""
request = build_triton_request(
Expand All @@ -660,7 +660,7 @@ def build_triton_request(
bbox: np.ndarray,
pixel_values: np.ndarray,
model_name: str,
model_version: str = "1",
model_version: str | None = None,
):
"""Build a Triton ModelInferRequest gRPC request for LayoutLMv3 models.
Expand All @@ -672,12 +672,14 @@ def build_triton_request(
:param pixel_values: pixel values of the image, generated using the
transformers tokenizer.
:param model_name: the name of the model to use.
:param model_version: version of the model model to use, defaults to "1".
:param model_version: version of the model model to use, defaults to None (latest).
:return: the gRPC ModelInferRequest
"""
request = service_pb2.ModelInferRequest()
request.model_name = model_name
request.model_version = model_version

if model_version:
request.model_version = model_version

add_triton_infer_input_tensor(request, "input_ids", input_ids, "INT64")
add_triton_infer_input_tensor(request, "attention_mask", attention_mask, "INT64")
Expand Down

0 comments on commit 4f13fbb

Please sign in to comment.