Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 29 additions & 10 deletions examples/sglang/multiturn_generate_sglang.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import os

import sglang as sgl
from renderers.configs import Qwen35RendererConfig
from renderers.gpt_oss import GptOssRenderer
from renderers.qwen35 import Qwen35Renderer
from transformers import AutoTokenizer
Expand Down Expand Up @@ -52,7 +53,9 @@
def make_renderer(model: str, enable_thinking: bool | None):
tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=False)
if model.startswith("Qwen/Qwen3.5-"):
return Qwen35Renderer(tokenizer, enable_thinking=enable_thinking)
return Qwen35Renderer(
tokenizer, Qwen35RendererConfig(enable_thinking=enable_thinking)
)
if model == "openai/gpt-oss-20b":
return GptOssRenderer(tokenizer)
raise ValueError(f"unsupported demo model: {model}")
Expand All @@ -62,8 +65,9 @@ def print_parsed(label: str, turn: str, parsed) -> None:
print(f"\n[{label}] {turn}")
if parsed.reasoning_content:
print(f"reasoning: {parsed.reasoning_content[:240]}")
if parsed.tool_calls:
print(f"tool_calls: {json.dumps(parsed.tool_calls, ensure_ascii=False)}")
for tc in parsed.tool_calls:
# ``parse_response`` returns ``ParsedToolCall`` dataclasses, not dicts.
print(f"tool_call: {tc.name}({tc.arguments}) [{tc.status.value}]")
if parsed.content:
print(f"content: {parsed.content}")

Expand Down Expand Up @@ -141,21 +145,33 @@ def main() -> None:
if parsed1.reasoning_content:
assistant["reasoning_content"] = parsed1.reasoning_content
if parsed1.tool_calls:
assistant["tool_calls"] = parsed1.tool_calls
# Convert the parsed dataclasses back to OpenAI-format tool_calls.
assistant["tool_calls"] = [
{
"id": tc.id or f"call_{idx}",
"type": "function",
"function": {
"name": tc.name,
"arguments": tc.arguments
if isinstance(tc.arguments, str)
else json.dumps(tc.arguments),
},
}
for idx, tc in enumerate(parsed1.tool_calls)
]
messages.append(assistant)

if parsed1.tool_calls:
new_messages = []
for idx, tool_call in enumerate(parsed1.tool_calls):
fn = tool_call.get("function") or tool_call
tool_args = fn.get("arguments") or {}
tool_args = tool_call.arguments or {}
if isinstance(tool_args, str):
tool_args = json.loads(tool_args)
new_messages.append(
{
"role": "tool",
"tool_call_id": tool_call.get("id", f"call_{idx}"),
"name": fn.get("name", "multiply"),
"tool_call_id": tool_call.id or f"call_{idx}",
"name": tool_call.name or "multiply",
"content": json.dumps(
{"result": int(tool_args["a"]) * int(tool_args["b"])}
),
Expand All @@ -167,11 +183,14 @@ def main() -> None:
]

# Turn 2: bridge extends prompt_ids + completion1 exactly.
bridged_ids = renderer.bridge_to_next_turn(
# ``bridge_to_next_turn`` returns a ``RenderedTokens`` (or None); the
# extended id stream is on ``.token_ids``.
bridged = renderer.bridge_to_next_turn(
prompt_ids, completion1, new_messages, tools=TOOLS
)
if bridged_ids is None:
if bridged is None:
raise RuntimeError("bridge_to_next_turn returned None")
bridged_ids = bridged.token_ids
assert bridged_ids[: len(prompt_ids) + len(completion1)] == (
prompt_ids + completion1
)
Expand Down
39 changes: 29 additions & 10 deletions examples/sglang/online_multiturn_sglang.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@

import httpx
from renderers.base import Renderer
from renderers.configs import Qwen35RendererConfig
from renderers.gpt_oss import GptOssRenderer
from renderers.qwen35 import Qwen35Renderer
from transformers import AutoTokenizer
Expand Down Expand Up @@ -71,7 +72,9 @@
def make_renderer(model: str, enable_thinking: bool | None) -> Renderer:
tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=False)
if model.startswith("Qwen/Qwen3.5-"):
return Qwen35Renderer(tokenizer, enable_thinking=enable_thinking)
return Qwen35Renderer(
tokenizer, Qwen35RendererConfig(enable_thinking=enable_thinking)
)
if model == "openai/gpt-oss-20b":
return GptOssRenderer(tokenizer)
raise ValueError(f"unsupported demo model: {model}")
Expand Down Expand Up @@ -116,8 +119,9 @@ def print_parsed(label: str, turn: str, parsed) -> None:
print(f"\n[{label}] {turn}")
if parsed.reasoning_content:
print(f"reasoning: {parsed.reasoning_content[:240]}")
if parsed.tool_calls:
print(f"tool_calls: {json.dumps(parsed.tool_calls, ensure_ascii=False)}")
for tc in parsed.tool_calls:
# ``parse_response`` returns ``ParsedToolCall`` dataclasses, not dicts.
print(f"tool_call: {tc.name}({tc.arguments}) [{tc.status.value}]")
if parsed.content:
print(f"content: {parsed.content}")

Expand Down Expand Up @@ -164,21 +168,33 @@ async def run_one(
if parsed1.reasoning_content:
assistant["reasoning_content"] = parsed1.reasoning_content
if parsed1.tool_calls:
assistant["tool_calls"] = parsed1.tool_calls
# Convert the parsed dataclasses back to OpenAI-format tool_calls.
assistant["tool_calls"] = [
{
"id": tc.id or f"call_{idx}",
"type": "function",
"function": {
"name": tc.name,
"arguments": tc.arguments
if isinstance(tc.arguments, str)
else json.dumps(tc.arguments),
},
}
for idx, tc in enumerate(parsed1.tool_calls)
]
messages.append(assistant)

if parsed1.tool_calls:
new_messages: list[dict[str, Any]] = []
for idx, tool_call in enumerate(parsed1.tool_calls):
fn = tool_call.get("function") or tool_call
tool_args = fn.get("arguments") or {}
tool_args = tool_call.arguments or {}
if isinstance(tool_args, str):
tool_args = json.loads(tool_args)
new_messages.append(
{
"role": "tool",
"tool_call_id": tool_call.get("id", f"call_{idx}"),
"name": fn.get("name", "multiply"),
"tool_call_id": tool_call.id or f"call_{idx}",
"name": tool_call.name or "multiply",
"content": json.dumps(
{"result": int(tool_args["a"]) * int(tool_args["b"])}
),
Expand All @@ -190,11 +206,14 @@ async def run_one(
]

# Turn 2: bridge extends prompt_ids + completion1 exactly.
bridged_ids = renderer.bridge_to_next_turn(
# ``bridge_to_next_turn`` returns a ``RenderedTokens`` (or None); the
# extended id stream is on ``.token_ids``.
bridged = renderer.bridge_to_next_turn(
prompt_ids, completion1, new_messages, tools=TOOLS
)
if bridged_ids is None:
if bridged is None:
raise RuntimeError("bridge_to_next_turn returned None")
bridged_ids = bridged.token_ids
assert bridged_ids[: len(prompt_ids) + len(completion1)] == (
prompt_ids + completion1
)
Expand Down
39 changes: 29 additions & 10 deletions examples/tinker/multiturn_generate_tinker.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import os

import tinker
from renderers.configs import Qwen35RendererConfig
from renderers.gpt_oss import GptOssRenderer
from renderers.qwen35 import Qwen35Renderer
from tinker import types
Expand Down Expand Up @@ -53,7 +54,9 @@
def make_renderer(model: str, enable_thinking: bool | None):
tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=False)
if model.startswith("Qwen/Qwen3.5-"):
return Qwen35Renderer(tokenizer, enable_thinking=enable_thinking)
return Qwen35Renderer(
tokenizer, Qwen35RendererConfig(enable_thinking=enable_thinking)
)
if model == "openai/gpt-oss-20b":
return GptOssRenderer(tokenizer)
raise ValueError(f"unsupported demo model: {model}")
Expand All @@ -63,8 +66,9 @@ def print_parsed(label: str, turn: str, parsed) -> None:
print(f"\n[{label}] {turn}")
if parsed.reasoning_content:
print(f"reasoning: {parsed.reasoning_content[:240]}")
if parsed.tool_calls:
print(f"tool_calls: {json.dumps(parsed.tool_calls, ensure_ascii=False)}")
for tc in parsed.tool_calls:
# ``parse_response`` returns ``ParsedToolCall`` dataclasses, not dicts.
print(f"tool_call: {tc.name}({tc.arguments}) [{tc.status.value}]")
if parsed.content:
print(f"content: {parsed.content}")

Expand Down Expand Up @@ -131,21 +135,33 @@ async def main() -> None:
if parsed1.reasoning_content:
assistant["reasoning_content"] = parsed1.reasoning_content
if parsed1.tool_calls:
assistant["tool_calls"] = parsed1.tool_calls
# Convert the parsed dataclasses back to OpenAI-format tool_calls.
assistant["tool_calls"] = [
{
"id": tc.id or f"call_{idx}",
"type": "function",
"function": {
"name": tc.name,
"arguments": tc.arguments
if isinstance(tc.arguments, str)
else json.dumps(tc.arguments),
},
}
for idx, tc in enumerate(parsed1.tool_calls)
]
messages.append(assistant)

if parsed1.tool_calls:
new_messages = []
for idx, tool_call in enumerate(parsed1.tool_calls):
fn = tool_call.get("function") or tool_call
tool_args = fn.get("arguments") or {}
tool_args = tool_call.arguments or {}
if isinstance(tool_args, str):
tool_args = json.loads(tool_args)
new_messages.append(
{
"role": "tool",
"tool_call_id": tool_call.get("id", f"call_{idx}"),
"name": fn.get("name", "multiply"),
"tool_call_id": tool_call.id or f"call_{idx}",
"name": tool_call.name or "multiply",
"content": json.dumps(
{"result": int(tool_args["a"]) * int(tool_args["b"])}
),
Expand All @@ -157,11 +173,14 @@ async def main() -> None:
]

# Turn 2: bridge extends prompt_ids + completion1 exactly.
bridged_ids = renderer.bridge_to_next_turn(
# ``bridge_to_next_turn`` returns a ``RenderedTokens`` (or None); the
# extended id stream is on ``.token_ids``.
bridged = renderer.bridge_to_next_turn(
prompt_ids, completion1, new_messages, tools=TOOLS
)
if bridged_ids is None:
if bridged is None:
raise RuntimeError("bridge_to_next_turn returned None")
bridged_ids = bridged.token_ids
assert bridged_ids[: len(prompt_ids) + len(completion1)] == (
prompt_ids + completion1
)
Expand Down
38 changes: 28 additions & 10 deletions examples/transformers/multiturn_generate_transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from renderers.configs import Qwen35RendererConfig
from renderers.gpt_oss import GptOssRenderer
from renderers.qwen35 import Qwen35Renderer

Expand Down Expand Up @@ -55,7 +56,8 @@
def make_renderer(model: str, enable_thinking: bool | None):
tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=False)
if model.startswith("Qwen/Qwen3.5-"):
return Qwen35Renderer(tokenizer, enable_thinking=enable_thinking), tokenizer
config = Qwen35RendererConfig(enable_thinking=enable_thinking)
return Qwen35Renderer(tokenizer, config), tokenizer
if model == "openai/gpt-oss-20b":
return GptOssRenderer(tokenizer), tokenizer
raise ValueError(f"unsupported demo model: {model}")
Expand All @@ -65,8 +67,9 @@ def print_parsed(label: str, turn: str, parsed) -> None:
print(f"\n[{label}] {turn}")
if parsed.reasoning_content:
print(f"reasoning: {parsed.reasoning_content[:240]}")
if parsed.tool_calls:
print(f"tool_calls: {json.dumps(parsed.tool_calls, ensure_ascii=False)}")
for tc in parsed.tool_calls:
# ``parse_response`` returns ``ParsedToolCall`` dataclasses, not dicts.
print(f"tool_call: {tc.name}({tc.arguments}) [{tc.status.value}]")
if parsed.content:
print(f"content: {parsed.content}")

Expand Down Expand Up @@ -139,21 +142,33 @@ def main() -> None:
if parsed1.reasoning_content:
assistant["reasoning_content"] = parsed1.reasoning_content
if parsed1.tool_calls:
assistant["tool_calls"] = parsed1.tool_calls
# Convert the parsed dataclasses back to OpenAI-format tool_calls.
assistant["tool_calls"] = [
{
"id": tc.id or f"call_{idx}",
"type": "function",
"function": {
"name": tc.name,
"arguments": tc.arguments
if isinstance(tc.arguments, str)
else json.dumps(tc.arguments),
},
}
for idx, tc in enumerate(parsed1.tool_calls)
]
messages.append(assistant)

if parsed1.tool_calls:
new_messages = []
for idx, tool_call in enumerate(parsed1.tool_calls):
fn = tool_call.get("function") or tool_call
tool_args = fn.get("arguments") or {}
tool_args = tool_call.arguments or {}
if isinstance(tool_args, str):
tool_args = json.loads(tool_args)
new_messages.append(
{
"role": "tool",
"tool_call_id": tool_call.get("id", f"call_{idx}"),
"name": fn.get("name", "multiply"),
"tool_call_id": tool_call.id or f"call_{idx}",
"name": tool_call.name or "multiply",
"content": json.dumps(
{"result": int(tool_args["a"]) * int(tool_args["b"])}
),
Expand All @@ -165,11 +180,14 @@ def main() -> None:
]

# Turn 2: bridge extends prompt_ids + completion1 exactly.
bridged_ids = renderer.bridge_to_next_turn(
# ``bridge_to_next_turn`` returns a ``RenderedTokens`` (or None); the
# extended id stream is on ``.token_ids``.
bridged = renderer.bridge_to_next_turn(
prompt_ids, completion1, new_messages, tools=TOOLS
)
if bridged_ids is None:
if bridged is None:
raise RuntimeError("bridge_to_next_turn returned None")
bridged_ids = bridged.token_ids
assert bridged_ids[: len(prompt_ids) + len(completion1)] == (
prompt_ids + completion1
)
Expand Down
Loading
Loading