1
1
#!/usr/bin/env python
2
- # Copyright (c) 2024 Oracle and/or its affiliates.
2
+ # Copyright (c) 2024, 2025 Oracle and/or its affiliates.
3
3
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
4
4
import os
5
5
import pathlib
15
15
from ads .aqua import ODSC_MODEL_COMPARTMENT_OCID , logger
16
16
from ads .aqua .app import AquaApp
17
17
from ads .aqua .common .enums import (
18
+ CustomInferenceContainerTypeFamily ,
18
19
FineTuningContainerTypeFamily ,
19
20
InferenceContainerTypeFamily ,
20
21
Tags ,
23
24
from ads .aqua .common .utils import (
24
25
LifecycleStatus ,
25
26
_build_resource_identifier ,
27
+ cleanup_local_hf_model_artifact ,
26
28
copy_model_config ,
27
29
create_word_icon ,
28
30
generate_tei_cmd_var ,
@@ -376,8 +378,10 @@ def delete_model(self, model_id):
376
378
f"Failed to delete model:{ model_id } . Only registered models or finetuned model can be deleted."
377
379
)
378
380
379
- @telemetry (entry_point = "plugin=model&action=delete" , name = "aqua" )
380
- def edit_registered_model (self , id , inference_container , enable_finetuning , task ):
381
+ @telemetry (entry_point = "plugin=model&action=edit" , name = "aqua" )
382
+ def edit_registered_model (
383
+ self , id , inference_container , inference_container_uri , enable_finetuning , task
384
+ ):
381
385
"""Edits the default config of unverified registered model.
382
386
383
387
Parameters
@@ -386,6 +390,8 @@ def edit_registered_model(self, id, inference_container, enable_finetuning, task
386
390
The model OCID.
387
391
inference_container: str.
388
392
The inference container family name
393
+ inference_container_uri: str
394
+ The inference container uri for embedding models
389
395
enable_finetuning: str
390
396
Flag to enable or disable finetuning over the model. Defaults to None
391
397
task:
@@ -401,19 +407,44 @@ def edit_registered_model(self, id, inference_container, enable_finetuning, task
401
407
if ds_model .freeform_tags .get (Tags .BASE_MODEL_CUSTOM , None ):
402
408
if ds_model .freeform_tags .get (Tags .AQUA_SERVICE_MODEL_TAG , None ):
403
409
raise AquaRuntimeError (
404
- f"Failed to edit model: { id } . Only registered unverified models can be edited."
410
+ " Only registered unverified models can be edited."
405
411
)
406
412
else :
407
413
custom_metadata_list = ds_model .custom_metadata_list
408
414
freeform_tags = ds_model .freeform_tags
409
415
if inference_container :
410
- custom_metadata_list .add (
411
- key = ModelCustomMetadataFields .DEPLOYMENT_CONTAINER ,
412
- value = inference_container ,
413
- category = MetadataCustomCategory .OTHER ,
414
- description = "Deployment container mapping for SMC" ,
415
- replace = True ,
416
- )
416
+ if (
417
+ inference_container in CustomInferenceContainerTypeFamily
418
+ and inference_container_uri is None
419
+ ):
420
+ raise AquaRuntimeError (
421
+ "Inference container URI must be provided."
422
+ )
423
+ else :
424
+ custom_metadata_list .add (
425
+ key = ModelCustomMetadataFields .DEPLOYMENT_CONTAINER ,
426
+ value = inference_container ,
427
+ category = MetadataCustomCategory .OTHER ,
428
+ description = "Deployment container mapping for SMC" ,
429
+ replace = True ,
430
+ )
431
+ if inference_container_uri :
432
+ if (
433
+ inference_container in CustomInferenceContainerTypeFamily
434
+ or inference_container is None
435
+ ):
436
+ custom_metadata_list .add (
437
+ key = ModelCustomMetadataFields .DEPLOYMENT_CONTAINER_URI ,
438
+ value = inference_container_uri ,
439
+ category = MetadataCustomCategory .OTHER ,
440
+ description = f"Inference container URI for { ds_model .display_name } " ,
441
+ replace = True ,
442
+ )
443
+ else :
444
+ raise AquaRuntimeError (
445
+ f"Inference container URI can be edited only with container values: { CustomInferenceContainerTypeFamily .values ()} "
446
+ )
447
+
417
448
if enable_finetuning is not None :
418
449
if enable_finetuning .lower () == "true" :
419
450
custom_metadata_list .add (
@@ -448,9 +479,7 @@ def edit_registered_model(self, id, inference_container, enable_finetuning, task
448
479
)
449
480
AquaApp ().update_model (id , update_model_details )
450
481
else :
451
- raise AquaRuntimeError (
452
- f"Failed to edit model:{ id } . Only registered unverified models can be edited."
453
- )
482
+ raise AquaRuntimeError ("Only registered unverified models can be edited." )
454
483
455
484
def _fetch_metric_from_metadata (
456
485
self ,
@@ -869,8 +898,7 @@ def _create_model_catalog_entry(
869
898
# only add cmd vars if inference container is not an SMC
870
899
if (
871
900
inference_container not in smc_container_set
872
- and inference_container
873
- == InferenceContainerTypeFamily .AQUA_TEI_CONTAINER_FAMILY
901
+ and inference_container in CustomInferenceContainerTypeFamily .values ()
874
902
):
875
903
cmd_vars = generate_tei_cmd_var (os_path )
876
904
metadata .add (
@@ -1322,20 +1350,20 @@ def _download_model_from_hf(
1322
1350
Returns
1323
1351
-------
1324
1352
model_artifact_path (str): Location where the model artifacts are downloaded.
1325
-
1326
1353
"""
1327
1354
# Download the model from hub
1328
- if not local_dir :
1329
- local_dir = os .path .join (os .path .expanduser ("~" ), "cached-model" )
1330
- local_dir = os .path .join (local_dir , model_name )
1331
- os .makedirs (local_dir , exist_ok = True )
1332
- snapshot_download (
1355
+ if local_dir :
1356
+ local_dir = os .path .join (local_dir , model_name )
1357
+ os .makedirs (local_dir , exist_ok = True )
1358
+
1359
+ # if local_dir is not set, the return value points to the cached data folder
1360
+ local_dir = snapshot_download (
1333
1361
repo_id = model_name ,
1334
1362
local_dir = local_dir ,
1335
1363
allow_patterns = allow_patterns ,
1336
1364
ignore_patterns = ignore_patterns ,
1337
1365
)
1338
- # Upload to object storage and skip .cache/huggingface/ folder
1366
+ # Upload to object storage
1339
1367
model_artifact_path = upload_folder (
1340
1368
os_path = os_path ,
1341
1369
local_dir = local_dir ,
@@ -1365,6 +1393,8 @@ def register(
1365
1393
ignore_patterns (list): Model files matching any of the patterns are not downloaded.
1366
1394
Example: ["*.json"] will ignore all .json files. ["folder/*"] will ignore all files under `folder`.
1367
1395
Patterns are Standard Wildcards (globbing patterns) and rules can be found here: https://docs.python.org/3/library/fnmatch.html
1396
+ cleanup_model_cache (bool): Deletes downloaded files from local machine after model is successfully
1397
+ registered. Set to True by default.
1368
1398
1369
1399
Returns:
1370
1400
AquaModel:
@@ -1474,6 +1504,14 @@ def register(
1474
1504
detail = validation_result .telemetry_model_name ,
1475
1505
)
1476
1506
1507
+ if (
1508
+ import_model_details .download_from_hf
1509
+ and import_model_details .cleanup_model_cache
1510
+ ):
1511
+ cleanup_local_hf_model_artifact (
1512
+ model_name = model_name , local_dir = import_model_details .local_dir
1513
+ )
1514
+
1477
1515
return AquaModel (** aqua_model_attributes )
1478
1516
1479
1517
def _if_show (self , model : DataScienceModel ) -> bool :
0 commit comments