Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 29 additions & 19 deletions litellm/proxy/guardrails/guardrail_hooks/pangea/pangea.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ async def _call_pangea_ai_guard(
should act on.

Args:
api (str): Which API to use (text/guard or v1beta/guard)
api (str): Which API to use (text/guard or v1/guard)
payload (dict): The request payload.
request_data (dict): Original request data (used for logging/headers).
hook_name (str): Name of the hook calling this function (for logging).
Expand Down Expand Up @@ -163,7 +163,7 @@ async def _async_pre_call_hook(
user_api_key_dict: UserAPIKeyAuth,
cache: DualCache,
data: dict,
call_type: str
call_type: str,
):
transformer = None
messages: Any = None
Expand All @@ -173,19 +173,23 @@ async def _async_pre_call_hook(
else:
messages = data.get("messages")

input_dict = {
"messages": messages, # type: ignore
}
if data.get("tools"):
input_dict["tools"] = data.get("tools")

ai_guard_payload = {
"debug": False,
"input": {
"messages": messages, # type: ignore
"tools": data.get("tools")
},
"input": input_dict,
"event_type": "input",
}

if self.pangea_input_recipe:
ai_guard_payload["recipe"] = self.pangea_input_recipe

ai_guard_response = await self._call_pangea_ai_guard(
"v1beta/guard", ai_guard_payload, "async_pre_call_hook"
"v1/guard", ai_guard_payload, "async_pre_call_hook"
)
add_guardrail_to_applied_guardrails_header(
request_data=data, guardrail_name=self.guardrail_name
Expand All @@ -196,12 +200,11 @@ async def _async_pre_call_hook(

output = ai_guard_response.get("result", {}).get("output", {})
if call_type == "text_completion" or call_type == "atext_completion":
data = transformer.update_original_body(output["messages"]) # type: ignore
data = transformer.update_original_body(output["messages"]) # type: ignore
else:
data["messages"] = output["messages"]
return data


@log_guardrail_information
async def async_pre_call_hook(
self,
Expand All @@ -218,7 +221,9 @@ async def async_pre_call_hook(
return data

try:
return await self._async_pre_call_hook(user_api_key_dict, cache, data, call_type)
return await self._async_pre_call_hook(
user_api_key_dict, cache, data, call_type
)
except HTTPException:
raise
except Exception as e:
Expand All @@ -228,7 +233,7 @@ async def async_pre_call_hook(
"error": "Error in Pangea Guardrail",
"guardrail_name": self.guardrail_name,
"exceptions": str(e),
}
},
) from e

async def _async_post_call_success_hook(
Expand Down Expand Up @@ -259,21 +264,24 @@ async def _async_post_call_success_hook(
serialized_choices.append(c)
choices = serialized_choices

input_dict = {
"messages": input_messages,
"choices": choices,
}
if data.get("tools"):
input_dict["tools"] = data.get("tools")

ai_guard_payload = {
"debug": False,
"input": {
"messages": input_messages,
"tools": data.get("tools"),
"choices": choices,
},
"input": input_dict,
"event_type": "output",
}

if self.pangea_output_recipe:
ai_guard_payload["recipe"] = self.pangea_output_recipe

ai_guard_response = await self._call_pangea_ai_guard(
"v1beta/guard", ai_guard_payload, "async_pre_call_hook"
"v1/guard", ai_guard_payload, "async_pre_call_hook"
)
add_guardrail_to_applied_guardrails_header(
request_data=data, guardrail_name=self.guardrail_name
Expand Down Expand Up @@ -309,7 +317,9 @@ async def async_post_call_success_hook(
)
return data
try:
return await self._async_post_call_success_hook(data, user_api_key_dict, response)
return await self._async_post_call_success_hook(
data, user_api_key_dict, response
)
except HTTPException:
raise
except Exception as e:
Expand All @@ -319,7 +329,7 @@ async def async_post_call_success_hook(
"error": "Error in Pangea Guardrail",
"guardrail_name": self.guardrail_name,
"exceptions": str(e),
}
},
) from e

@staticmethod
Expand Down
Loading