
Add Strands SDK integration for RAG agent training#359

Open

JunjieAraoXiong wants to merge 2 commits into rllm-org:main from JunjieAraoXiong:strands-sdk-integration

Conversation

@JunjieAraoXiong JunjieAraoXiong commented Dec 31, 2025

Aim

This PR adds Strands SDK support to rLLM as an alternative to LangGraph for training RAG agents. Strands uses a simpler @tool decorator model rather than LangGraph’s graph-based orchestration, which makes the agent definition more lightweight and easier to reason about. The implementation is based on the existing LangGraph RAG example in examples/sdk/langgraph/ and the documentation at https://rllm-project.readthedocs.io/en/latest/examples/sdk_langgraph_rag/. It also references the Strands tools repository (https://github.com/strands-agents/tools) for compatibility and design alignment.
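For readers unfamiliar with the Strands style, the pattern can be illustrated with a minimal stdlib-only sketch of a decorator-based tool registry. All names here are hypothetical stand-ins for illustration; the real `@tool` decorator comes from the `strands` package and builds a richer tool spec.

```python
# Stdlib-only sketch of a decorator-based tool registry, illustrating the
# pattern Strands' @tool follows. Hypothetical names for illustration only;
# the real decorator lives in the strands package.
import inspect

TOOL_REGISTRY = {}

def tool(fn):
    """Register a plain function as a tool, using its signature and docstring as the spec."""
    TOOL_REGISTRY[fn.__name__] = {
        "callable": fn,
        "description": inspect.getdoc(fn) or "",
        "parameters": list(inspect.signature(fn).parameters),
    }
    return fn

@tool
def retrieve(query: str, top_k: int = 3) -> list:
    """Return the top_k passages matching the query."""
    # Placeholder corpus lookup; the real tool would call the RAG server.
    corpus = ["passage about Strands", "passage about LangGraph"]
    return [p for p in corpus if query.lower() in p.lower()][:top_k]
```

The agent definition stays a set of decorated plain functions, versus LangGraph's explicit node/edge graph wiring.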

Changes

This PR introduces a new `examples/strands/` directory containing a full Strands-based RAG training pipeline:

- `search_agent_strands.py` — the main agent implementation; uses a custom `NonStreamingOpenAIModel` wrapper so LiteLLM trace capture works correctly, enforces tool-turn budgeting, and handles streaming event conversion.
- `retrieve_tool.py` — the retrieval layer; implements async RAG with httpx connection pooling.
- `train_strands_agent.py` — the training entry point; integrates HotpotQA with `RewardSearchFn`.
- `train_strands_agent.sh` — a Hydra config for 8xGPU RLOO training in fp16.
- `rag/rag_server.py` — an auto-batching FastAPI server with multi-GPU FAISS support, launched via `rag/launch_rag.sh`.
- A README covering setup and usage.
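The auto-batching idea behind the RAG server can be sketched with a stdlib-only `asyncio` fragment: concurrent retrieval requests are queued, and a background worker drains the queue and processes them as one batch (one embedding + FAISS call instead of N). Everything below is an illustrative simplification, not the actual `rag_server.py` code.

```python
# Stdlib-only sketch of request auto-batching, the idea behind
# rag/rag_server.py: concurrent requests are queued, and a background
# worker drains the queue and runs them as a single batch. Names and
# parameters are illustrative, not the actual server implementation.
import asyncio

class AutoBatcher:
    def __init__(self, batch_fn, max_batch=32, max_wait=0.01):
        self.batch_fn = batch_fn    # processes a list of queries at once
        self.max_batch = max_batch
        self.max_wait = max_wait    # seconds to wait for more requests
        self.queue = asyncio.Queue()

    async def submit(self, query):
        # Each caller gets a future resolved once its batch is processed.
        fut = asyncio.get_running_loop().create_future()
        await self.queue.put((query, fut))
        return await fut

    async def worker(self):
        while True:
            batch = [await self.queue.get()]
            # Drain whatever else arrives within max_wait, up to max_batch.
            deadline = asyncio.get_running_loop().time() + self.max_wait
            while len(batch) < self.max_batch:
                timeout = deadline - asyncio.get_running_loop().time()
                if timeout <= 0:
                    break
                try:
                    batch.append(await asyncio.wait_for(self.queue.get(), timeout))
                except asyncio.TimeoutError:
                    break
            results = self.batch_fn([q for q, _ in batch])
            for (_, fut), res in zip(batch, results):
                fut.set_result(res)

async def demo():
    # Stand-in batch_fn; the real server would embed + FAISS-search the batch.
    batcher = AutoBatcher(lambda qs: [q.upper() for q in qs])
    task = asyncio.create_task(batcher.worker())
    out = await asyncio.gather(*(batcher.submit(q) for q in ["a", "b", "c"]))
    task.cancel()
    return out
```

During concurrent rollouts this turns many small embedding/search calls into a few large ones, which is where the throughput gain comes from.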

Legacy Strands files (`run_strands.py`, `strands_workflow.py`, `gsearch_tool_wrapped.py`, `.env.example`, and the `eval/` directory) were removed and replaced by this cleaner implementation.

Bug Fixes

This PR also fixes a Qwen3 tool-calling issue in `rllm/integrations/strands.py` and filters out Strands-specific kwargs inside `rllm/engine/rollout/openai_engine.py` to prevent unintended argument propagation.
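The kwargs filter amounts to dropping an allow-listed set of Strands-only keys before the payload reaches the OpenAI-compatible engine. A minimal sketch follows; the key names here are assumptions for illustration, and the actual set filtered in `openai_engine.py` may differ.

```python
# Illustrative sketch of dropping Strands-specific kwargs before the
# request reaches the OpenAI-compatible rollout engine. The key names
# below are assumed for illustration; the real set in
# rllm/engine/rollout/openai_engine.py may differ.
STRANDS_ONLY_KWARGS = {"agent", "tool_specs", "system_prompt_override"}

def filter_engine_kwargs(kwargs: dict) -> dict:
    """Return a copy of kwargs with Strands-specific keys removed."""
    return {k: v for k, v in kwargs.items() if k not in STRANDS_ONLY_KWARGS}
```

Without the filter, these extra keys would propagate into the underlying completion call and either error out or silently change behavior.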

Design Decisions

Strands hardcodes stream=True, but LiteLLM’s async_post_call_success_hook only fires for non-streaming requests. To preserve tracing compatibility, a NonStreamingOpenAIModel subclass forces stream=False and converts ChatCompletion responses into Strands StreamEvent format. Tool-turn budgeting is enforced using num_tool_turns because Strands consumes messageStop events internally before they reach the event loop. The RAG server uses request auto-batching for embeddings and FAISS search to improve throughput during concurrent rollouts.
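The conversion step amounts to synthesizing, from one non-streaming `ChatCompletion`, the small sequence of stream events the Strands event loop expects: a message start, the full text as a single delta, and a message stop. The event dicts below are simplified stand-ins for illustration, not the exact Strands `StreamEvent` schema.

```python
# Illustrative sketch of converting a single non-streaming chat completion
# into stream-style events, as the NonStreamingOpenAIModel wrapper does.
# Event shapes are simplified stand-ins, not the exact StreamEvent schema.
def completion_to_events(completion: dict):
    choice = completion["choices"][0]
    yield {"messageStart": {"role": choice["message"]["role"]}}
    content = choice["message"].get("content") or ""
    if content:
        # Emit the whole text as one delta, since there is no real stream.
        yield {"contentBlockDelta": {"delta": {"text": content}}}
    yield {"messageStop": {"stopReason": choice.get("finish_reason", "stop")}}

# Hypothetical non-streaming response for demonstration.
fake_completion = {
    "choices": [{
        "message": {"role": "assistant", "content": "Paris"},
        "finish_reason": "stop",
    }]
}
events = list(completion_to_events(fake_completion))
```

Because Strands consumes the `messageStop`-style events internally, the wrapper cannot count turns from them, which is why the separate `num_tool_turns` budget is tracked outside the event loop.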

Results

Training was conducted using RLOO on Qwen3-4B with 8x H100 GPUs on the HotpotQA dataset. The Strands implementation achieved a +15pp improvement in pass@1 compared to the LangGraph baseline.

Testing

Local testing with the RAG server works as expected. Trajectories are saved correctly, and multi-GPU training completed successfully with positive results.

@JunjieAraoXiong JunjieAraoXiong marked this pull request as ready for review December 31, 2025 07:54
@JunjieAraoXiong JunjieAraoXiong marked this pull request as draft December 31, 2025 07:54
@JunjieAraoXiong JunjieAraoXiong marked this pull request as ready for review February 27, 2026 08:14
