Skip to content

Commit 0458dcf

Browse files
shubhiroyPouyanpi
andcommitted
feat: Add Azure OpenAI embedding provider (#702)
--------- Signed-off-by: Pouyan <[email protected]> Co-authored-by: Pouyanpi <[email protected]>
1 parent 05afc1b commit 0458dcf

File tree

7 files changed

+453
-1
lines changed

7 files changed

+453
-1
lines changed

docs/user-guides/configuration-guide.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -538,6 +538,7 @@ The following tables lists the supported embedding providers:
538538
| OpenAI | `openai` | `text-embedding-ada-002`, etc. |
539539
| SentenceTransformers | `SentenceTransformers` | `all-MiniLM-L6-v2`, etc. |
540540
| NVIDIA AI Endpoints | `nvidia_ai_endpoints` | `nv-embed-v1`, etc. |
541+
| AzureOpenAI | `AzureOpenAI` | `text-embedding-ada-002`, etc.
541542

542543
```{note}
543544
You can use any of the supported models for any of the supported embedding providers.

nemoguardrails/embeddings/providers/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
from typing import Optional, Type
2020

21-
from . import fastembed, nim, openai, sentence_transformers
21+
from . import azureopenai, fastembed, nim, openai, sentence_transformers
2222
from .base import EmbeddingModel
2323
from .registry import EmbeddingProviderRegistry
2424

@@ -65,6 +65,7 @@ def register_embedding_provider(
6565

6666
register_embedding_provider(fastembed.FastEmbedEmbeddingModel)
6767
register_embedding_provider(openai.OpenAIEmbeddingModel)
68+
register_embedding_provider(azureopenai.AzureEmbeddingModel)
6869
register_embedding_provider(sentence_transformers.SentenceTransformerEmbeddingModel)
6970
register_embedding_provider(nim.NIMEmbeddingModel)
7071
register_embedding_provider(nim.NVIDIAAIEndpointsEmbeddingModel)
Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
import asyncio
17+
import os
18+
from typing import List
19+
20+
from .base import EmbeddingModel
21+
22+
23+
def get_executor():
24+
from . import embeddings_executor
25+
26+
return embeddings_executor
27+
28+
29+
class AzureEmbeddingModel(EmbeddingModel):
30+
"""Embedding model using Azure OpenAI.
31+
32+
This class represents an embedding model that utilizes the Azure OpenAI API
33+
for generating text embeddings.
34+
35+
Args:
36+
embedding_model (str): The name of the Azure OpenAI deployment model (e.g., "text-embedding-ada-002").
37+
"""
38+
39+
engine_name = "AzureOpenAI"
40+
41+
# Lookup table for model embedding dimensions
42+
MODEL_DIMENSIONS = {
43+
"text-embedding-ada-002": 1536,
44+
# Add more models and their dimensions here if needed
45+
}
46+
47+
def __init__(self, embedding_model: str):
48+
try:
49+
from openai import AzureOpenAI
50+
except ImportError:
51+
raise ImportError(
52+
"Could not import openai, please install it with "
53+
"`pip install openai`."
54+
)
55+
# Set Azure OpenAI API credentials
56+
self.client = AzureOpenAI(
57+
api_key=os.getenv("AZURE_OPENAI_API_KEY"),
58+
api_version=os.getenv("AZURE_OPENAI_API_VERSION"),
59+
azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
60+
)
61+
62+
self.embedding_model = embedding_model
63+
self.embedding_size = self._get_embedding_dimension()
64+
65+
def _get_embedding_dimension(self):
66+
"""Retrieve the embedding dimension for the specified model."""
67+
if self.embedding_model in self.MODEL_DIMENSIONS:
68+
return self.MODEL_DIMENSIONS[self.embedding_model]
69+
else:
70+
embedding_size = len(self.encode(["test"])[0])
71+
return embedding_size
72+
73+
async def encode_async(self, documents: List[str]) -> List[List[float]]:
74+
"""Asynchronously encode a list of documents into their corresponding embeddings.
75+
76+
Args:
77+
documents (List[str]): The list of documents to be encoded.
78+
79+
Returns:
80+
List[List[float]]: The list of embeddings, where each embedding is a list of floats.
81+
"""
82+
loop = asyncio.get_running_loop()
83+
result = await loop.run_in_executor(get_executor(), self.encode, documents)
84+
return result
85+
86+
def encode(self, documents: List[str]) -> List[List[float]]:
87+
"""Encode a list of documents into their corresponding embeddings.
88+
89+
Args:
90+
documents (List[str]): The list of documents to be encoded.
91+
92+
Returns:
93+
List[List[float]]: The list of embeddings, where each embedding is a list of floats.
94+
95+
Raises:
96+
RuntimeError: If the API call fails.
97+
"""
98+
try:
99+
response = self.client.embeddings.create(
100+
model=self.embedding_model, input=documents
101+
)
102+
embeddings = [record.embedding for record in response.data]
103+
return embeddings
104+
except Exception as e:
105+
raise RuntimeError(f"Failed to retrieve embeddings: {e}")
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
define user ask capabilities
2+
"What can you do?"
3+
"What can you help me with?"
4+
"tell me what you can do"
5+
"tell me about you"
6+
7+
define bot inform capabilities
8+
"I am an AI assistant that helps answer questions."
9+
10+
define flow
11+
user ask capabilities
12+
bot inform capabilities
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
models:
2+
- type: main
3+
engine: azure
4+
model: gpt-4o
5+
6+
- type: embeddings
7+
engine: AzureOpenAI
8+
model: text-embedding-ada-002
Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
import os
17+
18+
import pytest
19+
20+
from nemoguardrails import LLMRails, RailsConfig
21+
22+
try:
23+
from nemoguardrails.embeddings.providers.azureopenai import AzureEmbeddingModel
24+
except ImportError:
25+
# Ignore this if running in test environment when azureopenai not installed.
26+
AzureEmbeddingModel = None
27+
CONFIGS_FOLDER = os.path.join(os.path.dirname(__file__), ".", "test_configs")
28+
29+
LIVE_TEST_MODE = os.environ.get("LIVE_TEST")
30+
31+
32+
@pytest.fixture
33+
def app():
34+
"""Load the configuration where we replace FastEmbed with AzureOpenAI."""
35+
config = RailsConfig.from_path(
36+
os.path.join(CONFIGS_FOLDER, "with_azureopenai_embeddings")
37+
)
38+
39+
return LLMRails(config)
40+
41+
42+
@pytest.mark.skipif(not LIVE_TEST_MODE, reason="Not in live mode.")
43+
def test_custom_llm_registration(app):
44+
assert isinstance(
45+
app.llm_generation_actions.flows_index._model, AzureEmbeddingModel
46+
)
47+
48+
49+
@pytest.mark.skipif(not LIVE_TEST_MODE, reason="Not in live mode.")
50+
@pytest.mark.asyncio
51+
async def test_live_query_async():
52+
config = RailsConfig.from_path(
53+
os.path.join(CONFIGS_FOLDER, "with_azureopenai_embeddings")
54+
)
55+
app = LLMRails(config)
56+
57+
result = await app.generate_async(
58+
messages=[{"role": "user", "content": "tell me what you can do"}]
59+
)
60+
61+
assert result == {
62+
"role": "assistant",
63+
"content": "I am an AI assistant that helps answer questions.",
64+
}
65+
66+
67+
@pytest.mark.skipif(not LIVE_TEST_MODE, reason="Not in live mode.")
68+
def test_live_query_sync(app):
69+
result = app.generate(
70+
messages=[{"role": "user", "content": "tell me what you can do"}]
71+
)
72+
73+
assert result == {
74+
"role": "assistant",
75+
"content": "I am an AI assistant that helps answer questions.",
76+
}
77+
78+
79+
@pytest.mark.skipif(not LIVE_TEST_MODE, reason="Not in live mode.")
80+
def test_sync_embeddings():
81+
model = AzureEmbeddingModel("text-embedding-ada-002")
82+
83+
result = model.encode(["test"])
84+
85+
assert len(result[0]) == 1536
86+
87+
88+
@pytest.mark.skipif(not LIVE_TEST_MODE, reason="Not in live mode.")
89+
@pytest.mark.asyncio
90+
async def test_async_embeddings():
91+
model = AzureEmbeddingModel("text-embedding-ada-002")
92+
93+
result = await model.encode_async(["test"])
94+
95+
assert len(result[0]) == 1536

0 commit comments

Comments
 (0)