22 changes: 18 additions & 4 deletions nemoguardrails/actions/llm/generation.py
@@ -591,7 +591,7 @@ async def generate_user_intent(

if tool_calls:
output_events.append(
new_event_dict("BotToolCall", tool_calls=tool_calls)
new_event_dict("BotToolCalls", tool_calls=tool_calls)
)
else:
output_events.append(new_event_dict("BotMessage", text=text))
@@ -905,9 +905,23 @@ async def generate_bot_message(
LLMCallInfo(task=Task.GENERATE_BOT_MESSAGE.value)
)

# We use the potentially updated $user_message. This means that even
# in passthrough mode, input rails can still alter the input.
prompt = context.get("user_message")
# In passthrough mode, we should use the full conversation history
# instead of just the last user message to preserve tool message context
raw_prompt = raw_llm_request.get()

if raw_prompt is not None and isinstance(raw_prompt, list):
# Use the full conversation including tool messages
prompt = raw_prompt.copy()

# Update the last user message if it was altered by input rails
user_message = context.get("user_message")
if user_message and prompt:
for i in reversed(range(len(prompt))):
if prompt[i]["role"] == "user":
prompt[i]["content"] = user_message
break
else:
prompt = context.get("user_message")

generation_options: GenerationOptions = generation_options_var.get()
with llm_params(
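To make the new passthrough behavior concrete, here is a minimal, self-contained Python sketch of the same idea: keep the full conversation (including tool messages) as the prompt and only swap in the rail-updated user message. The values below are illustrative sample data, not the module's actual context variables.

raw_prompt = [
    {"role": "user", "content": "What's the weather in Paris?"},
    {"role": "assistant", "content": "", "tool_calls": [
        {"id": "call_1", "name": "get_weather", "args": {"city": "Paris"}}]},
    {"role": "tool", "content": "18°C, sunny", "tool_call_id": "call_1"},
]
user_message = "What's the weather in Paris, France?"  # possibly rewritten by input rails

prompt = raw_prompt.copy()
# Walk backwards so only the most recent user turn is replaced.
for i in reversed(range(len(prompt))):
    if prompt[i]["role"] == "user":
        prompt[i]["content"] = user_message
        break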
25 changes: 16 additions & 9 deletions nemoguardrails/actions/llm/utils.py
@@ -153,16 +153,23 @@ def _convert_messages_to_langchain_format(prompt: List[dict]) -> List:
if msg_type == "user":
messages.append(HumanMessage(content=msg["content"]))
elif msg_type in ["bot", "assistant"]:
messages.append(AIMessage(content=msg["content"]))
tool_calls = msg.get("tool_calls")
if tool_calls:
messages.append(
AIMessage(content=msg["content"], tool_calls=tool_calls)
)
else:
messages.append(AIMessage(content=msg["content"]))
elif msg_type == "system":
messages.append(SystemMessage(content=msg["content"]))
elif msg_type == "tool":
messages.append(
ToolMessage(
content=msg["content"],
tool_call_id=msg.get("tool_call_id", ""),
)
tool_message = ToolMessage(
content=msg["content"],
tool_call_id=msg.get("tool_call_id", ""),
)
if msg.get("name"):
tool_message.name = msg["name"]
messages.append(tool_message)
else:
raise ValueError(f"Unknown message type {msg_type}")
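As a quick illustration of the two new branches, here is a small, hedged sketch of the LangChain messages they construct; the tool-call dictionary follows langchain_core's ToolCall shape, and the sample values are made up for the example.

from langchain_core.messages import AIMessage, ToolMessage

# Assistant turn that carries tool calls (sample data).
ai_msg = AIMessage(
    content="",
    tool_calls=[{"name": "get_weather", "args": {"city": "Paris"}, "id": "call_1", "type": "tool_call"}],
)

# Tool result that keeps its originating tool name, mirroring the optional
# name propagation in the branch above.
tool_msg = ToolMessage(content="18°C, sunny", tool_call_id="call_1")
tool_msg.name = "get_weather"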

@@ -674,16 +681,16 @@ def get_and_clear_tool_calls_contextvar() -> Optional[list]:


def extract_tool_calls_from_events(events: list) -> Optional[list]:
"""Extract tool_calls from BotToolCall events.
"""Extract tool_calls from BotToolCalls events.
Args:
events: List of events to search through
Returns:
tool_calls if found in BotToolCall event, None otherwise
tool_calls if found in BotToolCalls event, None otherwise
"""
for event in events:
if event.get("type") == "BotToolCall":
if event.get("type") == "BotToolCalls":
return event.get("tool_calls")
return None
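A brief usage sketch; the event list is illustrative sample data, and the function is imported from this module (nemoguardrails/actions/llm/utils.py), where the hunk above defines it.

from nemoguardrails.actions.llm.utils import extract_tool_calls_from_events

events = [
    {"type": "UserMessage", "text": "What's the weather in Paris?"},
    {"type": "BotToolCalls", "tool_calls": [
        {"name": "get_weather", "args": {"city": "Paris"}, "id": "call_1"}]},
]

tool_calls = extract_tool_calls_from_events(events)
assert tool_calls is not None and tool_calls[0]["name"] == "get_weather"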

15 changes: 14 additions & 1 deletion nemoguardrails/integrations/langchain/runnable_rails.py
@@ -26,6 +26,7 @@
BaseMessage,
HumanMessage,
SystemMessage,
ToolMessage,
)
from langchain_core.prompt_values import ChatPromptValue, StringPromptValue
from langchain_core.runnables import Runnable, RunnableConfig, RunnableSerializable
@@ -231,11 +232,23 @@ def _create_passthrough_messages(self, _input) -> List[Dict[str, Any]]:
def _message_to_dict(self, msg: BaseMessage) -> Dict[str, Any]:
"""Convert a BaseMessage to dictionary format."""
if isinstance(msg, AIMessage):
return {"role": "assistant", "content": msg.content}
result = {"role": "assistant", "content": msg.content}
if hasattr(msg, "tool_calls") and msg.tool_calls:
result["tool_calls"] = msg.tool_calls
return result
elif isinstance(msg, HumanMessage):
return {"role": "user", "content": msg.content}
elif isinstance(msg, SystemMessage):
return {"role": "system", "content": msg.content}
elif isinstance(msg, ToolMessage):
result = {
"role": "tool",
"content": msg.content,
"tool_call_id": msg.tool_call_id,
}
if hasattr(msg, "name") and msg.name:
result["name"] = msg.name
return result
else: # Handle other message types
role = getattr(msg, "type", "user")
return {"role": role, "content": msg.content}
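For reference, a hedged sketch of the dictionaries the new branches are meant to produce; the messages below are sample data, and the expected dicts mirror the code rather than output captured from a real run.

from langchain_core.messages import AIMessage, ToolMessage

ai_msg = AIMessage(
    content="",
    tool_calls=[{"name": "get_weather", "args": {"city": "Paris"}, "id": "call_1", "type": "tool_call"}],
)
tool_msg = ToolMessage(content="18°C, sunny", tool_call_id="call_1", name="get_weather")

# Expected results of _message_to_dict (illustrative):
# {"role": "assistant", "content": "", "tool_calls": [...]}          for ai_msg
# {"role": "tool", "content": "18°C, sunny",
#  "tool_call_id": "call_1", "name": "get_weather"}                  for tool_msg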
57 changes: 56 additions & 1 deletion nemoguardrails/rails/llm/llm_flows.co
@@ -106,7 +106,7 @@ define parallel extension flow process bot tool call
"""Processes tool calls from the bot."""
priority 100

event BotToolCall
event BotToolCalls

$tool_calls = $event.tool_calls

@@ -130,6 +130,40 @@ define parallel extension flow process bot tool call
create event StartToolCallBotAction(tool_calls=$tool_calls)


define parallel flow process user tool messages
"""Run all the tool input rails on the tool messages."""
priority 200
event UserToolMessages

$tool_messages = $event["tool_messages"]

# If we have tool input rails, we run them, otherwise we just create the user message event
if $config.rails.tool_input.flows
# If we have generation options, we make sure the tool input rails are enabled.
$tool_input_enabled = True
if $generation_options is not None
if $generation_options.rails.tool_input == False
$tool_input_enabled = False

if $tool_input_enabled:
create event StartToolInputRails
event StartToolInputRails

$i = 0
while $i < len($tool_messages)
$tool_message = $tool_messages[$i].content
$tool_name = $tool_messages[$i].name
if "tool_call_id" in $tool_messages[$i]
$tool_call_id = $tool_messages[$i].tool_call_id
else
$tool_call_id = ""

do run tool input rails
$i = $i + 1

create event ToolInputRailsFinished
event ToolInputRailsFinished

define parallel extension flow process bot message
"""Runs the output rails on a bot message."""
priority 100
@@ -214,3 +248,24 @@ define subflow run tool output rails

# If all went smooth, we remove it.
$triggered_tool_output_rail = None

define subflow run tool input rails
"""Runs all the tool input rails in a sequential order."""
$tool_input_flows = $config.rails.tool_input.flows

$i = 0
while $i < len($tool_input_flows)
# We set the current rail as being triggered.
$triggered_tool_input_rail = $tool_input_flows[$i]

create event StartToolInputRail(flow_id=$triggered_tool_input_rail)
event StartToolInputRail

do $tool_input_flows[$i]
$i = $i + 1

create event ToolInputRailFinished(flow_id=$triggered_tool_input_rail)
event ToolInputRailFinished

# If all went smooth, we remove it.
$triggered_tool_input_rail = None
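The flows above read the rail list from $config.rails.tool_input.flows. As a rough, unverified sketch of how that might be wired up from Python, the YAML keys and the "self check tool input" flow name below are assumptions inferred from the Colang, not confirmed configuration.

from nemoguardrails import RailsConfig

# Hypothetical configuration sketch: the "tool_input" rails section and the
# flow name are assumptions inferred from the flows above.
yaml_content = """
models:
  - type: main
    engine: openai
    model: gpt-4o

rails:
  tool_input:
    flows:
      - self check tool input
"""

config = RailsConfig.from_content(yaml_content=yaml_content)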
74 changes: 61 additions & 13 deletions nemoguardrails/rails/llm/llmrails.py
@@ -747,26 +747,74 @@ def _get_events_for_messages(self, messages: List[dict], state: Any):
)

elif msg["role"] == "assistant":
action_uid = new_uuid()
start_event = new_event_dict(
"StartUtteranceBotAction",
script=msg["content"],
action_uid=action_uid,
)
finished_event = new_event_dict(
"UtteranceBotActionFinished",
final_script=msg["content"],
is_success=True,
action_uid=action_uid,
)
events.extend([start_event, finished_event])
if msg.get("tool_calls"):
events.append(
{"type": "BotToolCalls", "tool_calls": msg["tool_calls"]}
)
else:
action_uid = new_uuid()
start_event = new_event_dict(
"StartUtteranceBotAction",
script=msg["content"],
action_uid=action_uid,
)
finished_event = new_event_dict(
"UtteranceBotActionFinished",
final_script=msg["content"],
is_success=True,
action_uid=action_uid,
)
events.extend([start_event, finished_event])
elif msg["role"] == "context":
events.append({"type": "ContextUpdate", "data": msg["content"]})
elif msg["role"] == "event":
events.append(msg["event"])
elif msg["role"] == "system":
# Handle system messages - convert them to SystemMessage events
events.append({"type": "SystemMessage", "content": msg["content"]})
elif msg["role"] == "tool":
# For the last tool message, create grouped tool event and synthetic UserMessage
if idx == len(messages) - 1:
# Find the original user message for response generation
user_message = None
for prev_msg in reversed(messages[:idx]):
if prev_msg["role"] == "user":
user_message = prev_msg["content"]
break

if user_message:
# If tool input rails are configured, group all tool messages
if self.config.rails.tool_input.flows:
# Collect all tool messages for grouped processing
tool_messages = []
for tool_idx in range(len(messages)):
if messages[tool_idx]["role"] == "tool":
tool_messages.append(
{
"content": messages[tool_idx][
"content"
],
"name": messages[tool_idx].get(
"name", "unknown"
),
"tool_call_id": messages[tool_idx].get(
"tool_call_id", ""
),
}
)

events.append(
{
"type": "UserToolMessages",
"tool_messages": tool_messages,
}
)

else:
events.append(
{"type": "UserMessage", "text": user_message}
)

else:
for idx in range(len(messages)):
msg = messages[idx]
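A small sketch of the events the tool-message branch above is intended to produce for a conversation that ends with tool results; the messages are sample data, and the event dicts simply mirror the code.

messages = [
    {"role": "user", "content": "What's the weather in Paris?"},
    {"role": "assistant", "content": "", "tool_calls": [
        {"name": "get_weather", "args": {"city": "Paris"}, "id": "call_1"}]},
    {"role": "tool", "content": "18°C, sunny", "name": "get_weather", "tool_call_id": "call_1"},
]

# With tool input rails configured, the final tool message yields one grouped event:
# {"type": "UserToolMessages", "tool_messages": [
#     {"content": "18°C, sunny", "name": "get_weather", "tool_call_id": "call_1"}]}
#
# Without tool input rails, the original user turn is replayed instead:
# {"type": "UserMessage", "text": "What's the weather in Paris?"}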