Skip to content

Commit c008acd

Browse files
committed
feat: add command to download models from HF
1 parent 826d536 commit c008acd

File tree

2 files changed

+70
-0
lines changed

2 files changed

+70
-0
lines changed

robotoff/cli/triton.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,3 +68,13 @@ def unload_model(model_name: str):
6868
typer.echo("Done.")
6969
typer.echo("**Current models (after) **")
7070
list_models()
71+
72+
73+
@app.command()
def download_models():
    """Download all models."""
    # Project imports are done lazily, inside the command body, rather than
    # at module import time.
    from robotoff import triton
    from robotoff.utils import get_logger

    # The return value is intentionally discarded; the call is presumably made
    # for its logging-setup side effect so that the progress messages logged by
    # `triton.download_models` are displayed -- TODO confirm.
    get_logger()
    triton.download_models()

robotoff/triton.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,18 @@
11
import functools
22
import json
3+
import shutil
34
import struct
5+
import tempfile
6+
from pathlib import Path
47

58
import grpc
69
import numpy as np
710
from google.protobuf.json_format import MessageToJson
11+
from huggingface_hub import snapshot_download
812
from more_itertools import chunked
913
from openfoodfacts.types import JSONType
1014
from PIL import Image
15+
from pydantic import BaseModel
1116
from transformers import CLIPImageProcessor
1217
from tritonclient.grpc import service_pb2, service_pb2_grpc
1318
from tritonclient.grpc.service_pb2_grpc import GRPCInferenceServiceStub
@@ -24,6 +29,30 @@
2429
# Get model config: /v2/models/{MODEL_NAME}/config
2530

2631

32+
class HuggingFaceModel(BaseModel):
    """Description of one model version hosted on the Hugging Face Hub."""

    # Model name; `download_models` uses it as the directory name under
    # `settings.TRITON_MODELS_DIR`.
    name: str
    # Version number; used as the sub-directory name below the model
    # directory.
    version: int
    # Hugging Face Hub repository identifier
    # (e.g. "openfoodfacts/nutrition-extractor").
    repo_id: str
    # Folder of the repository to fetch: only files matching "<subfolder>/*"
    # are downloaded.
    subfolder: str = "onnx"
    # Repository revision to download. The entries of `HUGGINGFACE_MODELS`
    # pin 40-character hashes here (presumably commit SHAs).
    revision: str = "main"
38+
39+
40+
# Models fetched by `download_models`. Each entry pins an explicit `revision`
# of the repository instead of relying on the default ("main"), so a given
# (name, version) pair always refers to the same snapshot.
HUGGINGFACE_MODELS = [
    HuggingFaceModel(
        name="nutrition_extractor",
        version=1,
        repo_id="openfoodfacts/nutrition-extractor",
        revision="dea426bf3c3d289ad7b65d29a7744ea6851632a6",
    ),
    HuggingFaceModel(
        name="nutrition_extractor",
        version=2,
        repo_id="openfoodfacts/nutrition-extractor",
        revision="7a43f38725f50f37a8c7bce417fc75741bea49fe",
    ),
]
54+
55+
2756
@functools.cache
2857
def get_triton_inference_stub(
2958
triton_uri: str | None = None,
@@ -264,3 +293,34 @@ def get_model_config(
264293

265294
response = triton_stub.ModelConfig(request)
266295
return response.config
296+
297+
298+
def download_models():
    """Download every model listed in ``HUGGINGFACE_MODELS`` from the
    Hugging Face Hub.

    For each entry, the ``subfolder`` of the pinned ``revision`` is fetched
    into a temporary directory and then moved to
    ``settings.TRITON_MODELS_DIR / <name> / <version> / model.onnx``.
    Entries whose destination path already exists are skipped.
    """
    for model in HUGGINGFACE_MODELS:
        model_root = settings.TRITON_MODELS_DIR / model.name
        model_root.mkdir(parents=True, exist_ok=True)
        version_dir = model_root / str(model.version)
        # The downloaded sub-folder itself is moved to this path, so
        # "model.onnx" ends up being a directory rather than a single file.
        onnx_target = version_dir / "model.onnx"

        if onnx_target.exists():
            logger.info(
                f"Model {model.name} version {model.version} already downloaded"
            )
            continue

        # Download into a throw-away directory that is removed when the
        # `with` block exits; only the requested sub-folder is moved out.
        with tempfile.TemporaryDirectory() as scratch:
            logger.info(f"Temporary cache directory: {scratch}")
            scratch_path = Path(scratch)
            snapshot_download(
                repo_id=model.repo_id,
                allow_patterns=[f"{model.subfolder}/*"],
                revision=model.revision,
                local_dir=scratch_path,
            )
            version_dir.mkdir(parents=True, exist_ok=True)
            logger.info(f"Copying model files to {onnx_target}")
            shutil.move(scratch_path / model.subfolder, onnx_target)

0 commit comments

Comments
 (0)