Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions ads/aqua/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1111,6 +1111,11 @@ def list(
aqua_models = []
inference_containers = self.get_container_config().to_dict().get("inference")
for model in models:
# Skip models without required tags early
freeform_tags = model.freeform_tags or {}
if Tags.AQUA_TAG.lower() not in {tag.lower() for tag in freeform_tags}:
continue

aqua_models.append(
AquaModelSummary(
**self._process_model(
Expand All @@ -1121,6 +1126,8 @@ def list(
project_id=project_id or UNKNOWN,
)
)

# Adds service models to cache
if category == SERVICE:
self._service_models_cache.__setitem__(
key=AQUA_SERVICE_MODELS, value=aqua_models
Expand Down
41 changes: 41 additions & 0 deletions tests/unitary/with_extras/aqua/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1654,3 +1654,44 @@ def test_build_search_text(self, description, tags, expected_output):
self.app._build_search_text(tags=tags, description=description)
== expected_output
)

@pytest.mark.parametrize(
"remove_indices, expected_len",
[
([], 2), # All models have AQUA_TAG -> include both
([1], 1), # Second model missing AQUA_TAG -> include first only
([0, 1], 0), # Both missing AQUA_TAG -> include none
],
)
@patch.object(AquaApp, "get_container_config")
def test_list_service_models_filters_missing_aqua_tag(
self,
mock_get_container_config,
remove_indices,
expected_len,
):
"""Ensure list() excludes models that do not have AQUA_TAG in freeform_tags."""
mock_get_container_config.return_value = get_container_config()

import copy

items = copy.deepcopy(TestDataset.model_summary_objects)
for idx in remove_indices:
# remove AQUA tag entirely to validate filter behavior
items[idx]["freeform_tags"].pop("OCI_AQUA", None)

self.app.list_resource = MagicMock(
return_value=[
oci.data_science.models.ModelSummary(**item) for item in items
]
)

# Clear service models cache
self.app.clear_model_list_cache()

results = self.app.list(
compartment_id=TestDataset.SERVICE_COMPARTMENT_ID,
category=ads.config.SERVICE,
)

assert len(results) == expected_len
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
#!/usr/bin/env python

# Copyright (c) 2021, 2023 Oracle and/or its affiliates.
# Copyright (c) 2021, 2025 Oracle and/or its affiliates.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/

"""Unit tests for model frameworks. Includes tests for:
- PyTorchModel
- PyTorchModel
"""

import base64
import os
import shutil
import uuid
from io import BytesIO

import numpy as np
Expand All @@ -19,7 +21,7 @@
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import uuid

from ads.model.framework.pytorch_model import PyTorchModel
from ads.model.serde.model_serializer import (
PyTorchOnnxModelSaveSERDE,
Expand Down Expand Up @@ -146,6 +148,9 @@ def test_serialize_with_incorrect_model_file_name_onnx(self):
as_onnx=True, model_file_name="model.xxx"
)

@pytest.mark.skip(
reason="ODSC-79463: Fix Missing onnxscript Dependency Causing ONNX Serialization Test Failures"
)
def test_serialize_using_pytorch_without_modelname(self):
"""
Test serialize_model using pytorch without model_file_name
Expand All @@ -157,6 +162,9 @@ def test_serialize_using_pytorch_without_modelname(self):
test_pytorch_model.serialize_model(as_onnx=False)
assert os.path.isfile(tmp_model_dir + "model.pt")

@pytest.mark.skip(
reason="ODSC-79463: Fix Missing onnxscript Dependency Causing ONNX Serialization Test Failures"
)
def test_serialize_using_pytorch_with_modelname(self):
"""
Test serialize_model using pytorch with correct model_file_name
Expand All @@ -169,6 +177,9 @@ def test_serialize_using_pytorch_with_modelname(self):
test_pytorch_model.serialize_model(as_onnx=False)
assert os.path.isfile(tmp_model_dir + "test1.pt")

@pytest.mark.skip(
reason="ODSC-79463: Fix Missing onnxscript Dependency Causing ONNX Serialization Test Failures"
)
def test_serialize_using_onnx_without_modelname(self):
"""
Test serialize_model using onnx without model_file_name
Expand All @@ -183,6 +194,9 @@ def test_serialize_using_onnx_without_modelname(self):
)
assert os.path.exists(os.path.join(tmp_model_dir, "model.onnx"))

@pytest.mark.skip(
reason="ODSC-79463: Fix Missing onnxscript Dependency Causing ONNX Serialization Test Failures"
)
def test_serialize_using_onnx_with_modelname(self):
"""
Test serialize_model using onnx with correct model_file_name
Expand All @@ -200,6 +214,9 @@ def test_serialize_using_onnx_with_modelname(self):
os.path.join(tmp_model_dir, test_pytorch_model.model_file_name)
)

@pytest.mark.skip(
reason="ODSC-79463: Fix Missing onnxscript Dependency Causing ONNX Serialization Test Failures"
)
def test_to_onnx(self):
"""
Test if PytorchOnnxModelSerializer.serialize generate onnx model result.
Expand All @@ -216,6 +233,9 @@ def test_to_onnx(self):
)
assert os.path.exists(os.path.join(tmp_model_dir, model_file_name))

@pytest.mark.skip(
reason="ODSC-79463: Fix Missing onnxscript Dependency Causing ONNX Serialization Test Failures"
)
def test_to_onnx_reload(self):
"""
Test if PytorchOnnxModelSerializer.serialize generate onnx model result.
Expand All @@ -235,6 +255,9 @@ def test_to_onnx_reload(self):
is not None
)

@pytest.mark.skip(
reason="ODSC-79463: Fix Missing onnxscript Dependency Causing ONNX Serialization Test Failures"
)
def test_to_onnx_without_dummy_input(self):
"""
Test if PytorchOnnxModelSerializer.serialize raise expected error
Expand Down Expand Up @@ -324,6 +347,9 @@ def test_prepare_default(self):
)
assert os.path.exists(tmp_model_dir + "model.pt")

@pytest.mark.skip(
reason="ODSC-79463: Fix Missing onnxscript Dependency Causing ONNX Serialization Test Failures"
)
def test_prepare_onnx(self):
test_pytorch_model = PyTorchModel(self.myPyTorchModel, tmp_model_dir)
test_pytorch_model.prepare(
Expand All @@ -335,6 +361,9 @@ def test_prepare_onnx(self):
)
assert os.path.exists(tmp_model_dir + "model.onnx")

@pytest.mark.skip(
reason="ODSC-79463: Fix Missing onnxscript Dependency Causing ONNX Serialization Test Failures"
)
def test_prepare_onnx_with_X_sample(self):
test_pytorch_model = PyTorchModel(self.myPyTorchModel, tmp_model_dir)
test_pytorch_model.prepare(
Expand All @@ -346,6 +375,9 @@ def test_prepare_onnx_with_X_sample(self):
)
assert isinstance(test_pytorch_model.verify([1, 2, 3, 4]), dict)

@pytest.mark.skip(
reason="ODSC-79463: Fix Missing onnxscript Dependency Causing ONNX Serialization Test Failures"
)
def test_prepare_onnx_without_input(self):
test_pytorch_model = PyTorchModel(self.myPyTorchModel, tmp_model_dir)
with pytest.raises(ValueError):
Expand All @@ -356,6 +388,9 @@ def test_prepare_onnx_without_input(self):
as_onnx=True,
)

@pytest.mark.skip(
reason="ODSC-79463: Fix Missing onnxscript Dependency Causing ONNX Serialization Test Failures"
)
def test_verify_onnx(self):
"""
Test if PyTorchModel.verify in onnx serialization
Expand Down
Loading