diff --git a/src/sagemaker/jumpstart/hub/utils.py b/src/sagemaker/jumpstart/hub/utils.py index 77540926c6..edc3c08fa7 100644 --- a/src/sagemaker/jumpstart/hub/utils.py +++ b/src/sagemaker/jumpstart/hub/utils.py @@ -106,7 +106,7 @@ def construct_hub_model_reference_arn_from_inputs( info = get_info_from_hub_resource_arn(hub_arn) arn = ( f"arn:{info.partition}:sagemaker:{info.region}:{info.account_id}:hub-content/" - f"{info.hub_name}/{HubContentType.MODEL_REFERENCE}/{model_name}/{version}" + f"{info.hub_name}/{HubContentType.MODEL_REFERENCE.value}/{model_name}/{version}" ) return arn diff --git a/tests/unit/sagemaker/jumpstart/hub/test_utils.py b/tests/unit/sagemaker/jumpstart/hub/test_utils.py index 22bc527b18..6dbb1340f4 100644 --- a/tests/unit/sagemaker/jumpstart/hub/test_utils.py +++ b/tests/unit/sagemaker/jumpstart/hub/test_utils.py @@ -96,6 +96,23 @@ def test_construct_hub_model_arn_from_inputs(): ) +def test_construct_hub_model_reference_arn_from_inputs(): + model_name, version = "pytorch-ic-imagenet-v2", "1.0.2" + hub_arn = "arn:aws:sagemaker:us-west-2:123456789123:hub/my-mock-hub" + hub_content_arn_prefix = "arn:aws:sagemaker:us-west-2:123456789123:hub-content/my-mock-hub" + + assert ( + utils.construct_hub_model_reference_arn_from_inputs(hub_arn, model_name, version) + == hub_content_arn_prefix + "/ModelReference/pytorch-ic-imagenet-v2/1.0.2" + ) + + version = "*" + assert ( + utils.construct_hub_model_reference_arn_from_inputs(hub_arn, model_name, version) + == hub_content_arn_prefix + "/ModelReference/pytorch-ic-imagenet-v2/*" + ) + + def test_generate_hub_arn_for_init_kwargs(): hub_name = "my-hub-name" hub_arn = "arn:aws:sagemaker:us-west-2:12346789123:hub/my-awesome-hub"