Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
141 changes: 132 additions & 9 deletions internal/adapter/claude/handler_util_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,85 @@ func TestNormalizeClaudeMessagesToolResult(t *testing.T) {
},
}
got := normalizeClaudeMessages(msgs)
if len(got) != 1 {
t.Fatalf("expected one normalized message, got %d", len(got))
}
m := got[0].(map[string]any)
if m["role"] != "tool" {
t.Fatalf("expected tool role preserved, got %#v", m["role"])
}
content, _ := m["content"].(string)
if content != "tool output" {
t.Fatalf("expected raw tool output content preserved, got %q", content)
}
}

func TestNormalizeClaudeMessagesToolUseToAssistantToolCalls(t *testing.T) {
msgs := []any{
map[string]any{
"role": "assistant",
"content": []any{
map[string]any{
"type": "tool_use",
"id": "call_1",
"name": "search_web",
"input": map[string]any{"query": "latest"},
},
},
},
}

got := normalizeClaudeMessages(msgs)
if len(got) != 1 {
t.Fatalf("expected one normalized tool-call message, got %d", len(got))
}
m := got[0].(map[string]any)
if m["role"] != "assistant" {
t.Fatalf("expected assistant role, got %#v", m["role"])
}
tc, _ := m["tool_calls"].([]any)
if len(tc) != 1 {
t.Fatalf("expected one tool call, got %#v", m["tool_calls"])
}
call, _ := tc[0].(map[string]any)
if call["id"] != "call_1" {
t.Fatalf("expected call id preserved, got %#v", call)
}
content, _ := m["content"].(string)
if !containsStr(content, "search_web") || !containsStr(content, `"arguments":"{\"query\":\"latest\"}"`) {
t.Fatalf("expected assistant content to include serialized tool call for prompt roundtrip, got %q", content)
}
}

func TestNormalizeClaudeMessagesDoesNotPromoteUserToolUse(t *testing.T) {
msgs := []any{
map[string]any{
"role": "user",
"content": []any{
map[string]any{
"type": "tool_use",
"id": "call_unsafe",
"name": "dangerous_tool",
"input": map[string]any{"value": "x"},
},
},
},
}

got := normalizeClaudeMessages(msgs)
if len(got) != 1 {
t.Fatalf("expected one normalized message, got %d", len(got))
}
m := got[0].(map[string]any)
if m["role"] != "user" {
t.Fatalf("expected user role preserved, got %#v", m["role"])
}
if _, ok := m["tool_calls"]; ok {
t.Fatalf("expected no tool_calls promotion for user message, got %#v", m["tool_calls"])
}
content, _ := m["content"].(string)
if !strings.Contains(content, "[TOOL_RESULT_HISTORY]") || !strings.Contains(content, "content: tool output") {
t.Fatalf("expected serialized tool result marker, got %q", content)
if !containsStr(content, `"type":"tool_use"`) || !containsStr(content, "dangerous_tool") {
t.Fatalf("expected raw tool_use block preserved in user content, got %q", content)
}
}

Expand Down Expand Up @@ -87,15 +162,63 @@ func TestNormalizeClaudeMessagesMixedContentBlocks(t *testing.T) {
"role": "user",
"content": []any{
map[string]any{"type": "text", "text": "Hello"},
map[string]any{"type": "image", "source": "data:..."},
map[string]any{"type": "image", "source": map[string]any{"type": "base64", "data": strings.Repeat("A", 2048)}},
map[string]any{"type": "text", "text": "World"},
},
},
}
got := normalizeClaudeMessages(msgs)
m := got[0].(map[string]any)
if m["content"] != "Hello\nWorld" {
t.Fatalf("expected only text parts joined, got %q", m["content"])
content, _ := m["content"].(string)
if !containsStr(content, "Hello") || !containsStr(content, "World") || !containsStr(content, `"type":"image"`) {
t.Fatalf("expected text plus non-text block marker preserved, got %q", content)
}
if !containsStr(content, omittedBinaryMarker) {
t.Fatalf("expected binary payload omitted marker, got %q", content)
}
if containsStr(content, strings.Repeat("A", 100)) {
t.Fatalf("expected raw base64 payload not to be included, got %q", content)
}
}

func TestNormalizeClaudeMessagesToolResultNonTextPayloadStringified(t *testing.T) {
msgs := []any{
map[string]any{
"role": "user",
"content": []any{
map[string]any{
"type": "tool_result",
"tool_use_id": "call_image_1",
"name": "vision_tool",
"content": []any{
map[string]any{"type": "text", "text": "image analysis"},
map[string]any{
"type": "image",
"source": map[string]any{"type": "base64", "media_type": "image/png", "data": strings.Repeat("B", 2048)},
},
},
},
},
},
}

got := normalizeClaudeMessages(msgs)
if len(got) != 1 {
t.Fatalf("expected one normalized message, got %d", len(got))
}
m := got[0].(map[string]any)
if m["role"] != "tool" {
t.Fatalf("expected tool role, got %#v", m["role"])
}
content, _ := m["content"].(string)
if !containsStr(content, `"type":"tool_result"`) || !containsStr(content, `"type":"image"`) {
t.Fatalf("expected non-text tool_result payload to be JSON stringified, got %q", content)
}
if !containsStr(content, omittedBinaryMarker) {
t.Fatalf("expected binary data to be sanitized with omitted marker, got %q", content)
}
if containsStr(content, strings.Repeat("B", 100)) {
t.Fatalf("expected raw base64 payload not to be included, got %q", content)
}
}

Expand Down Expand Up @@ -128,11 +251,11 @@ func TestBuildClaudeToolPromptSingleTool(t *testing.T) {
if !containsStr(prompt, "tool_use") {
t.Fatalf("expected tool_use instruction in prompt")
}
if !containsStr(prompt, "Never output [TOOL_CALL_HISTORY] or [TOOL_RESULT_HISTORY] markers yourself") {
t.Fatalf("expected marker guard instruction in prompt")
if containsStr(prompt, "TOOL_CALL_HISTORY") || containsStr(prompt, "TOOL_RESULT_HISTORY") {
t.Fatalf("expected legacy tool history markers removed from prompt")
}
if containsStr(prompt, "tool_calls") {
t.Fatalf("expected prompt to avoid tool_calls JSON instruction")
if !containsStr(prompt, "Do not print tool-call JSON in text") {
t.Fatalf("expected prompt to keep no tool-call-json instruction")
}
}

Expand Down
160 changes: 139 additions & 21 deletions internal/adapter/claude/handler_utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,28 +13,58 @@ func normalizeClaudeMessages(messages []any) []any {
if !ok {
continue
}
copied := cloneMap(msg)
role := strings.ToLower(strings.TrimSpace(fmt.Sprintf("%v", msg["role"])))
switch content := msg["content"].(type) {
case []any:
parts := make([]string, 0, len(content))
textParts := make([]string, 0, len(content))
flushText := func() {
if len(textParts) == 0 {
return
}
out = append(out, map[string]any{
"role": role,
"content": strings.Join(textParts, "\n"),
})
textParts = textParts[:0]
}
for _, block := range content {
b, ok := block.(map[string]any)
if !ok {
continue
}
typeStr, _ := b["type"].(string)
if typeStr == "text" {
typeStr := strings.ToLower(strings.TrimSpace(fmt.Sprintf("%v", b["type"])))
switch typeStr {
case "text":
if t, ok := b["text"].(string); ok {
parts = append(parts, t)
textParts = append(textParts, t)
}
case "tool_use":
if role == "assistant" {
flushText()
if toolMsg := normalizeClaudeToolUseToAssistant(b); toolMsg != nil {
out = append(out, toolMsg)
}
continue
}
if raw := strings.TrimSpace(formatClaudeUnknownBlockForPrompt(b)); raw != "" {
textParts = append(textParts, raw)
}
case "tool_result":
flushText()
if toolMsg := normalizeClaudeToolResultToToolMessage(b); toolMsg != nil {
out = append(out, toolMsg)
}
default:
if raw := strings.TrimSpace(formatClaudeUnknownBlockForPrompt(b)); raw != "" {
textParts = append(textParts, raw)
}
}
if typeStr == "tool_result" {
parts = append(parts, formatClaudeToolResultForPrompt(b))
}
}
copied["content"] = strings.Join(parts, "\n")
flushText()
default:
copied := cloneMap(msg)
out = append(out, copied)
}
out = append(out, copied)
}
return out
}
Expand All @@ -52,9 +82,8 @@ func buildClaudeToolPrompt(tools []any) string {
}
parts = append(parts,
"When you need a tool, respond with Claude-native tool use (tool_use) using the provided tool schema. Do not print tool-call JSON in text.",
"History markers in conversation: [TOOL_CALL_HISTORY]...[/TOOL_CALL_HISTORY] are your previous tool calls; [TOOL_RESULT_HISTORY]...[/TOOL_RESULT_HISTORY] are runtime tool outputs, not user input.",
"After a valid [TOOL_RESULT_HISTORY], continue with final answer instead of repeating the same call unless required fields are still missing.",
"Never output [TOOL_CALL_HISTORY] or [TOOL_RESULT_HISTORY] markers yourself; they are system-side context only.",
"Tool roundtrip context is included directly in the conversation messages (assistant tool_use/tool_calls and tool results).",
"After receiving a valid tool result, continue with final answer instead of repeating the same call unless required fields are still missing.",
)
return strings.Join(parts, "\n\n")
}
Expand All @@ -63,22 +92,111 @@ func formatClaudeToolResultForPrompt(block map[string]any) string {
if block == nil {
return ""
}
payload := map[string]any{
"type": "tool_result",
"content": block["content"],
}
if toolCallID := strings.TrimSpace(fmt.Sprintf("%v", block["tool_use_id"])); toolCallID != "" {
payload["tool_call_id"] = toolCallID
} else if toolCallID := strings.TrimSpace(fmt.Sprintf("%v", block["tool_call_id"])); toolCallID != "" {
payload["tool_call_id"] = toolCallID
}
if name := strings.TrimSpace(fmt.Sprintf("%v", block["name"])); name != "" {
payload["name"] = name
}
b, err := json.Marshal(payload)
if err != nil {
return strings.TrimSpace(fmt.Sprintf("%v", payload))
}
return string(b)
}

func normalizeClaudeToolUseToAssistant(block map[string]any) map[string]any {
if block == nil {
return nil
}
name := strings.TrimSpace(fmt.Sprintf("%v", block["name"]))
if name == "" {
return nil
}
callID := strings.TrimSpace(fmt.Sprintf("%v", block["id"]))
if callID == "" {
callID = strings.TrimSpace(fmt.Sprintf("%v", block["tool_use_id"]))
}
if callID == "" {
callID = "call_claude"
}
arguments := block["input"]
if arguments == nil {
arguments = map[string]any{}
}
argsJSON, err := json.Marshal(arguments)
if err != nil || len(argsJSON) == 0 {
argsJSON = []byte("{}")
}
toolCalls := []any{
map[string]any{
"id": callID,
"type": "function",
"function": map[string]any{
"name": name,
"arguments": string(argsJSON),
},
},
}
return map[string]any{
"role": "assistant",
"content": marshalCompactJSON(toolCalls),
"tool_calls": toolCalls,
}
}

func normalizeClaudeToolResultToToolMessage(block map[string]any) map[string]any {
if block == nil {
return nil
}
toolCallID := strings.TrimSpace(fmt.Sprintf("%v", block["tool_use_id"]))
if toolCallID == "" {
toolCallID = strings.TrimSpace(fmt.Sprintf("%v", block["tool_call_id"]))
}
if toolCallID == "" {
toolCallID = "unknown"
toolCallID = "call_claude"
}
name := strings.TrimSpace(fmt.Sprintf("%v", block["name"]))
if name == "" {
name = "unknown"
out := map[string]any{
"role": "tool",
"tool_call_id": toolCallID,
"content": normalizeClaudeToolResultContent(block["content"]),
}
if name := strings.TrimSpace(fmt.Sprintf("%v", block["name"])); name != "" {
out["name"] = name
}
return out
}

func normalizeClaudeToolResultContent(content any) any {
if text, ok := content.(string); ok {
return text
}
payload := map[string]any{
"type": "tool_result",
"content": content,
}
b, err := json.Marshal(sanitizeClaudeBlockForPrompt(payload))
if err != nil {
return strings.TrimSpace(fmt.Sprintf("%v", content))
}
return string(b)
}

func formatClaudeBlockRaw(block map[string]any) string {
if block == nil {
return ""
}
content := strings.TrimSpace(fmt.Sprintf("%v", block["content"]))
if content == "" {
content = "null"
b, err := json.Marshal(block)
if err != nil {
return strings.TrimSpace(fmt.Sprintf("%v", block))
}
return fmt.Sprintf("[TOOL_RESULT_HISTORY]\nstatus: already_returned\norigin: tool_runtime\nnot_user_input: true\ntool_call_id: %s\nname: %s\ncontent: %s\n[/TOOL_RESULT_HISTORY]", toolCallID, name, content)
return string(b)
}

func hasSystemMessage(messages []any) bool {
Expand Down
Loading
Loading