Skip to content

Commit 214c10f

Browse files
committed
test_completion_cost_databricks_embedding
1 parent d475557 commit 214c10f

File tree

3 files changed

+52
-58
lines changed

3 files changed

+52
-58
lines changed

tests/local_testing/test_async_fn.py

Lines changed: 0 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -91,35 +91,6 @@ async def test_get_response():
9191
# test_async_response_openai()
9292

9393

94-
def test_async_response_azure():
95-
import asyncio
96-
97-
litellm.set_verbose = True
98-
99-
async def test_get_response():
100-
user_message = "What do you know?"
101-
messages = [{"content": user_message, "role": "user"}]
102-
try:
103-
response = await acompletion(
104-
model="azure/gpt-turbo",
105-
messages=messages,
106-
base_url=os.getenv("CLOUDFLARE_AZURE_BASE_URL"),
107-
api_key=os.getenv("AZURE_FRANCE_API_KEY"),
108-
)
109-
print(f"response: {response}")
110-
except litellm.Timeout as e:
111-
pass
112-
except litellm.InternalServerError:
113-
pass
114-
except Exception as e:
115-
pytest.fail(f"An exception occurred: {e}")
116-
117-
asyncio.run(test_get_response())
118-
119-
120-
# test_async_response_azure()
121-
122-
12394
@pytest.mark.skip(reason="anyscale stopped serving public api endpoints")
12495
def test_async_anyscale_response():
12596
import asyncio

tests/local_testing/test_completion.py

Lines changed: 0 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -3334,31 +3334,6 @@ def test_completion_anyscale_api():
33343334
pytest.fail(f"Error occurred: {e}")
33353335

33363336

3337-
# test_completion_anyscale_api()
3338-
def test_azure_cloudflare_api():
3339-
litellm.set_verbose = True
3340-
try:
3341-
messages = [
3342-
{
3343-
"role": "user",
3344-
"content": "How do I output all files in a directory using Python?",
3345-
},
3346-
]
3347-
response = completion(
3348-
model="azure/gpt-turbo",
3349-
messages=messages,
3350-
base_url=os.getenv("CLOUDFLARE_AZURE_BASE_URL"),
3351-
api_key=os.getenv("AZURE_FRANCE_API_KEY"),
3352-
)
3353-
print(f"response: {response}")
3354-
except Exception as e:
3355-
pytest.fail(f"Error occurred: {e}")
3356-
traceback.print_exc()
3357-
pass
3358-
3359-
3360-
# test_azure_cloudflare_api()
3361-
33623337

33633338
@pytest.mark.skip(reason="anyscale stopped serving public api endpoints")
33643339
def test_completion_anyscale_2():

tests/local_testing/test_completion_cost.py

Lines changed: 52 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,9 @@
2424
model_cost,
2525
open_ai_chat_completion_models,
2626
)
27+
from litellm.llms.custom_httpx.http_handler import HTTPHandler
28+
import json
29+
import httpx
2730
from litellm.types.utils import PromptTokensDetails
2831
from litellm.litellm_core_utils.litellm_logging import CustomLogger
2932

@@ -1148,13 +1151,58 @@ def test_completion_cost_databricks(model):
11481151
"databricks/databricks-gte-large-en",
11491152
],
11501153
)
1151-
def test_completion_cost_databricks_embedding(model):
1154+
def test_completion_cost_databricks_embedding(model, monkeypatch):
1155+
"""
1156+
Test completion cost calculation for Databricks embedding models using mocked HTTP responses.
1157+
"""
1158+
base_url = "https://my.workspace.cloud.databricks.com/serving-endpoints"
1159+
api_key = "dapimykey"
1160+
monkeypatch.setenv("DATABRICKS_API_BASE", base_url)
1161+
monkeypatch.setenv("DATABRICKS_API_KEY", api_key)
1162+
11521163
os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"
11531164
litellm.model_cost = litellm.get_model_cost_map(url="")
1154-
resp = litellm.embedding(model=model, input=["hey, how's it going?"]) # works fine
1165+
1166+
mock_response_data = {
1167+
"object": "list",
1168+
"model": model.split("/")[1],
1169+
"data": [
1170+
{
1171+
"index": 0,
1172+
"object": "embedding",
1173+
"embedding": [
1174+
0.06768798828125,
1175+
-0.01291656494140625,
1176+
-0.0501708984375,
1177+
0.0245361328125,
1178+
-0.030364990234375,
1179+
],
1180+
}
1181+
],
1182+
"usage": {
1183+
"prompt_tokens": 8,
1184+
"total_tokens": 8,
1185+
"completion_tokens": 0,
1186+
"completion_tokens_details": None,
1187+
"prompt_tokens_details": None,
1188+
},
1189+
}
1190+
1191+
mock_response = MagicMock(spec=httpx.Response)
1192+
mock_response.status_code = 200
1193+
mock_response.json.return_value = mock_response_data
1194+
1195+
sync_handler = HTTPHandler()
1196+
1197+
with patch.object(HTTPHandler, "post", return_value=mock_response):
1198+
resp = litellm.embedding(
1199+
model=model,
1200+
input=["hey, how's it going?"],
1201+
client=sync_handler
1202+
)
11551203

1156-
print(resp)
1157-
cost = completion_cost(completion_response=resp)
1204+
print(resp)
1205+
cost = completion_cost(completion_response=resp)
11581206

11591207

11601208
from litellm.llms.fireworks_ai.cost_calculator import get_base_model_for_pricing

0 commit comments

Comments (0)