|
1 | 1 | import functools
|
2 | 2 | import json
|
| 3 | +import shutil |
3 | 4 | import struct
|
| 5 | +import tempfile |
| 6 | +from pathlib import Path |
4 | 7 |
|
5 | 8 | import grpc
|
6 | 9 | import numpy as np
|
7 | 10 | from google.protobuf.json_format import MessageToJson
|
| 11 | +from huggingface_hub import snapshot_download |
8 | 12 | from more_itertools import chunked
|
9 | 13 | from openfoodfacts.types import JSONType
|
10 | 14 | from PIL import Image
|
| 15 | +from pydantic import BaseModel |
11 | 16 | from transformers import CLIPImageProcessor
|
12 | 17 | from tritonclient.grpc import service_pb2, service_pb2_grpc
|
13 | 18 | from tritonclient.grpc.service_pb2_grpc import GRPCInferenceServiceStub
|
|
24 | 29 | # Get model config: /v2/models/{MODEL_NAME}/config
|
25 | 30 |
|
26 | 31 |
|
class HuggingFaceModel(BaseModel):
    """Description of one model version hosted on the Hugging Face Hub.

    Each instance ties a (name, version) pair, used to build the destination
    path under the Triton models directory, to the Hub repository, revision
    and subfolder the files are fetched from.
    """

    # Directory name of the model under the Triton models directory.
    name: str
    # Version number; used as the sub-directory name below `name`.
    version: int
    # Hugging Face Hub repository identifier passed to `snapshot_download`.
    repo_id: str
    # Folder inside the repository whose content is downloaded.
    subfolder: str = "onnx"
    # Revision of the repository to download.
    revision: str = "main"
| 38 | + |
| 39 | + |
# Models fetched by `download_models`. Each entry pins an explicit revision,
# and both versions of `nutrition_extractor` come from the same repository.
HUGGINGFACE_MODELS: list[HuggingFaceModel] = [
    HuggingFaceModel(
        name="nutrition_extractor",
        version=1,
        repo_id="openfoodfacts/nutrition-extractor",
        revision="dea426bf3c3d289ad7b65d29a7744ea6851632a6",
    ),
    HuggingFaceModel(
        name="nutrition_extractor",
        version=2,
        repo_id="openfoodfacts/nutrition-extractor",
        revision="7a43f38725f50f37a8c7bce417fc75741bea49fe",
    ),
]
| 54 | + |
| 55 | + |
27 | 56 | @functools.cache
|
28 | 57 | def get_triton_inference_stub(
|
29 | 58 | triton_uri: str | None = None,
|
@@ -264,3 +293,34 @@ def get_model_config(
|
264 | 293 |
|
265 | 294 | response = triton_stub.ModelConfig(request)
|
266 | 295 | return response.config
|
| 296 | + |
| 297 | + |
def download_models():
    """Fetch every model listed in `HUGGINGFACE_MODELS` from the Hugging Face Hub.

    For each entry, the content of the model's subfolder is downloaded into a
    temporary directory and then moved to
    `settings.TRITON_MODELS_DIR / <name> / <version> / "model.onnx"`.
    Entries whose destination path already exists are skipped.
    """
    for model in HUGGINGFACE_MODELS:
        model_root = settings.TRITON_MODELS_DIR / model.name
        model_root.mkdir(parents=True, exist_ok=True)
        onnx_target = model_root / str(model.version) / "model.onnx"

        # An existing destination means this version was fetched previously.
        if onnx_target.exists():
            logger.info(
                f"Model {model.name} version {model.version} already downloaded"
            )
            continue

        # Download into a throw-away directory so that only the requested
        # subfolder ends up in the Triton models directory.
        with tempfile.TemporaryDirectory() as temp_dir_str:
            logger.info(f"Temporary cache directory: {temp_dir_str}")
            cache_dir = Path(temp_dir_str)
            snapshot_download(
                repo_id=model.repo_id,
                allow_patterns=[f"{model.subfolder}/*"],
                revision=model.revision,
                local_dir=cache_dir,
            )
            # Make sure the version directory exists before moving into it.
            onnx_target.parent.mkdir(parents=True, exist_ok=True)
            logger.info(f"Copying model files to {onnx_target}")
            downloaded_subfolder = cache_dir / model.subfolder
            shutil.move(downloaded_subfolder, onnx_target)
0 commit comments