Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions internal/adapter/claude/handler_stream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,17 @@ func TestHandleClaudeStreamRealtimeToolSafety(t *testing.T) {
if !foundToolUse {
t.Fatalf("expected tool_use block in stream, body=%s", rec.Body.String())
}
foundInputDelta := false
for _, f := range findClaudeFrames(frames, "content_block_delta") {
delta, _ := f.Payload["delta"].(map[string]any)
if delta["type"] == "input_json_delta" && strings.Contains(asString(delta["partial_json"]), `"q":"go"`) {
foundInputDelta = true
break
}
}
if !foundInputDelta {
t.Fatalf("expected input_json_delta with tool arguments, body=%s", rec.Body.String())
}

foundToolUseStop := false
for _, f := range findClaudeFrames(frames, "message_delta") {
Expand Down
3 changes: 3 additions & 0 deletions internal/adapter/claude/standard_request.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ func normalizeClaudeRequest(store ConfigReader, req map[string]any) (claudeNorma
}
finalPrompt := deepseek.MessagesPrepare(toMessageMaps(dsPayload["messages"]))
toolNames := extractClaudeToolNames(toolsRequested)
if len(toolNames) == 0 && len(toolsRequested) > 0 {
toolNames = []string{"__any_tool__"}
}

return claudeNormalizedRequest{
Standard: util.StandardRequest{
Expand Down
10 changes: 0 additions & 10 deletions internal/adapter/claude/stream_runtime_core.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (

"ds2api/internal/sse"
streamengine "ds2api/internal/stream"
"ds2api/internal/util"
)

type claudeStreamRuntime struct {
Expand Down Expand Up @@ -120,15 +119,6 @@ func (s *claudeStreamRuntime) onParsed(parsed sse.LineResult) streamengine.Parse
if hasUnclosedCodeFence(s.text.String()) {
continue
}
detected := util.ParseToolCalls(s.text.String(), s.toolNames)
if len(detected) > 0 {
s.finalize("tool_use")
return streamengine.ParsedDecision{
ContentSeen: true,
Stop: true,
StopReason: streamengine.StopReason("tool_use_detected"),
}
}
continue
}
s.closeThinkingBlock()
Expand Down
12 changes: 11 additions & 1 deletion internal/adapter/claude/stream_runtime_finalize.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package claude

import (
"encoding/json"
"fmt"
"time"

Expand Down Expand Up @@ -53,14 +54,23 @@ func (s *claudeStreamRuntime) finalize(stopReason string) {
stopReason = "tool_use"
for i, tc := range detected {
idx := s.nextBlockIndex + i
inputJSON, _ := json.Marshal(tc.Input)
s.send("content_block_start", map[string]any{
"type": "content_block_start",
"index": idx,
"content_block": map[string]any{
"type": "tool_use",
"id": fmt.Sprintf("toolu_%d_%d", time.Now().Unix(), idx),
"name": tc.Name,
"input": tc.Input,
"input": map[string]any{},
},
})
s.send("content_block_delta", map[string]any{
"type": "content_block_delta",
"index": idx,
"delta": map[string]any{
"type": "input_json_delta",
"partial_json": string(inputJSON),
},
})
s.send("content_block_stop", map[string]any{
Expand Down
21 changes: 7 additions & 14 deletions internal/adapter/openai/handler_toolcall_format.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,28 +111,21 @@ func filterIncrementalToolCallDeltasByAllowed(deltas []toolCallDelta, allowedNam
if len(deltas) == 0 {
return nil
}
allowed := namesToSet(allowedNames)
if len(allowed) == 0 {
for _, d := range deltas {
if d.Name != "" {
seenNames[d.Index] = "__blocked__"
}
}
return nil
}
out := make([]toolCallDelta, 0, len(deltas))
for _, d := range deltas {
if d.Name != "" {
if _, ok := allowed[d.Name]; !ok {
seenNames[d.Index] = "__blocked__"
continue
if seenNames != nil {
seenNames[d.Index] = d.Name
}
seenNames[d.Index] = d.Name
out = append(out, d)
continue
}
if seenNames == nil {
out = append(out, d)
continue
}
name := strings.TrimSpace(seenNames[d.Index])
if name == "" || name == "__blocked__" {
if name == "" {
continue
}
out = append(out, d)
Expand Down
35 changes: 16 additions & 19 deletions internal/adapter/openai/handler_toolcall_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ func TestHandleNonStreamToolCallInterceptsReasonerModel(t *testing.T) {
}
}

func TestHandleNonStreamUnknownToolNotIntercepted(t *testing.T) {
func TestHandleNonStreamUnknownToolIntercepted(t *testing.T) {
h := &Handler{}
resp := makeSSEHTTPResponse(
`data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"not_in_schema\",\"input\":{\"q\":\"go\"}}]}"}`,
Expand All @@ -198,16 +198,13 @@ func TestHandleNonStreamUnknownToolNotIntercepted(t *testing.T) {
out := decodeJSONBody(t, rec.Body.String())
choices, _ := out["choices"].([]any)
choice, _ := choices[0].(map[string]any)
if choice["finish_reason"] != "stop" {
t.Fatalf("expected finish_reason=stop, got %#v", choice["finish_reason"])
if choice["finish_reason"] != "tool_calls" {
t.Fatalf("expected finish_reason=tool_calls, got %#v", choice["finish_reason"])
}
msg, _ := choice["message"].(map[string]any)
if _, ok := msg["tool_calls"]; ok {
t.Fatalf("did not expect tool_calls for unknown schema name, got %#v", msg["tool_calls"])
}
content, _ := msg["content"].(string)
if !strings.Contains(content, `"tool_calls"`) {
t.Fatalf("expected unknown tool json to pass through as text, got %#v", content)
toolCalls, _ := msg["tool_calls"].([]any)
if len(toolCalls) != 1 {
t.Fatalf("expected tool_calls for unknown schema name, got %#v", msg["tool_calls"])
}
}

Expand Down Expand Up @@ -413,7 +410,7 @@ func TestHandleStreamReasonerToolCallInterceptsWithoutRawContentLeak(t *testing.
}
}

func TestHandleStreamUnknownToolDoesNotLeakRawPayload(t *testing.T) {
func TestHandleStreamUnknownToolEmitsToolCall(t *testing.T) {
h := &Handler{}
resp := makeSSEHTTPResponse(
`data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"not_in_schema\",\"input\":{\"q\":\"go\"}}]}"}`,
Expand All @@ -428,18 +425,18 @@ func TestHandleStreamUnknownToolDoesNotLeakRawPayload(t *testing.T) {
if !done {
t.Fatalf("expected [DONE], body=%s", rec.Body.String())
}
if streamHasToolCallsDelta(frames) {
t.Fatalf("did not expect tool_calls delta for unknown schema name, body=%s", rec.Body.String())
if !streamHasToolCallsDelta(frames) {
t.Fatalf("expected tool_calls delta for unknown schema name, body=%s", rec.Body.String())
}
if streamHasRawToolJSONContent(frames) {
t.Fatalf("did not expect raw tool_calls json leak for unknown schema name: %s", rec.Body.String())
}
if streamFinishReason(frames) != "stop" {
t.Fatalf("expected finish_reason=stop, body=%s", rec.Body.String())
if streamFinishReason(frames) != "tool_calls" {
t.Fatalf("expected finish_reason=tool_calls, body=%s", rec.Body.String())
}
}

func TestHandleStreamUnknownToolNoArgsDoesNotLeakRawPayload(t *testing.T) {
func TestHandleStreamUnknownToolNoArgsEmitsToolCall(t *testing.T) {
h := &Handler{}
resp := makeSSEHTTPResponse(
`data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"not_in_schema\"}]}"}`,
Expand All @@ -454,14 +451,14 @@ func TestHandleStreamUnknownToolNoArgsDoesNotLeakRawPayload(t *testing.T) {
if !done {
t.Fatalf("expected [DONE], body=%s", rec.Body.String())
}
if streamHasToolCallsDelta(frames) {
t.Fatalf("did not expect tool_calls delta for unknown schema name (no args), body=%s", rec.Body.String())
if !streamHasToolCallsDelta(frames) {
t.Fatalf("expected tool_calls delta for unknown schema name (no args), body=%s", rec.Body.String())
}
if streamHasRawToolJSONContent(frames) {
t.Fatalf("did not expect raw tool_calls json leak for unknown schema name (no args): %s", rec.Body.String())
}
if streamFinishReason(frames) != "stop" {
t.Fatalf("expected finish_reason=stop, body=%s", rec.Body.String())
if streamFinishReason(frames) != "tool_calls" {
t.Fatalf("expected finish_reason=tool_calls, body=%s", rec.Body.String())
}
}

Expand Down
22 changes: 13 additions & 9 deletions internal/adapter/openai/responses_stream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,7 @@ func TestHandleResponsesStreamThinkingAndMixedToolExampleEmitsFunctionCall(t *te
}
}

func TestHandleResponsesStreamToolChoiceNoneRejectsFunctionCall(t *testing.T) {
func TestHandleResponsesStreamToolChoiceNoneStillAllowsFunctionCall(t *testing.T) {
h := &Handler{}
req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
rec := httptest.NewRecorder()
Expand All @@ -376,8 +376,8 @@ func TestHandleResponsesStreamToolChoiceNoneRejectsFunctionCall(t *testing.T) {

h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, nil, policy, "")
body := rec.Body.String()
if strings.Contains(body, "event: response.function_call_arguments.done") {
t.Fatalf("did not expect function_call events for tool_choice=none, body=%s", body)
if !strings.Contains(body, "event: response.function_call_arguments.done") {
t.Fatalf("expected function_call events for tool_choice=none, body=%s", body)
}
}

Expand Down Expand Up @@ -518,7 +518,7 @@ func TestHandleResponsesStreamRequiredMalformedToolPayloadFails(t *testing.T) {
}
}

func TestHandleResponsesStreamRejectsUnknownToolName(t *testing.T) {
func TestHandleResponsesStreamAllowsUnknownToolName(t *testing.T) {
h := &Handler{}
req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
rec := httptest.NewRecorder()
Expand All @@ -539,8 +539,8 @@ func TestHandleResponsesStreamRejectsUnknownToolName(t *testing.T) {

h.handleResponsesStream(rec, req, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, false, []string{"read_file"}, util.DefaultToolChoicePolicy(), "")
body := rec.Body.String()
if strings.Contains(body, "event: response.function_call_arguments.done") {
t.Fatalf("did not expect function_call events for unknown tool, body=%s", body)
if !strings.Contains(body, "event: response.function_call_arguments.done") {
t.Fatalf("expected function_call events for unknown tool, body=%s", body)
}
}

Expand Down Expand Up @@ -597,7 +597,7 @@ func TestHandleResponsesNonStreamRequiredToolChoiceIgnoresThinkingToolPayload(t
}
}

func TestHandleResponsesNonStreamToolChoiceNoneRejectsFunctionCall(t *testing.T) {
func TestHandleResponsesNonStreamToolChoiceNoneStillAllowsFunctionCall(t *testing.T) {
h := &Handler{}
rec := httptest.NewRecorder()
resp := &http.Response{
Expand All @@ -611,16 +611,20 @@ func TestHandleResponsesNonStreamToolChoiceNoneRejectsFunctionCall(t *testing.T)

h.handleResponsesNonStream(rec, resp, "owner-a", "resp_test", "deepseek-chat", "prompt", false, nil, policy, "")
if rec.Code != http.StatusOK {
t.Fatalf("expected 200 for tool_choice=none passthrough text, got %d body=%s", rec.Code, rec.Body.String())
t.Fatalf("expected 200 for tool_choice=none handling, got %d body=%s", rec.Code, rec.Body.String())
}
out := decodeJSONBody(t, rec.Body.String())
output, _ := out["output"].([]any)
foundFunctionCall := false
for _, item := range output {
m, _ := item.(map[string]any)
if m != nil && m["type"] == "function_call" {
t.Fatalf("did not expect function_call output item for tool_choice=none, got %#v", output)
foundFunctionCall = true
}
}
if !foundFunctionCall {
t.Fatalf("expected function_call output item for tool_choice=none, got %#v", output)
}
}

func extractSSEEventPayload(body, targetEvent string) (map[string]any, bool) {
Expand Down
21 changes: 17 additions & 4 deletions internal/adapter/openai/standard_request.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ func normalizeOpenAIChatRequest(store ConfigReader, req map[string]any, traceID
}
toolPolicy := util.DefaultToolChoicePolicy()
finalPrompt, toolNames := buildOpenAIFinalPromptWithPolicy(messagesRaw, req["tools"], traceID, toolPolicy)
toolNames = ensureToolDetectionEnabled(toolNames, req["tools"])
passThrough := collectOpenAIChatPassThrough(req)

return util.StandardRequest{
Expand Down Expand Up @@ -74,10 +75,8 @@ func normalizeOpenAIResponsesRequest(store ConfigReader, req map[string]any, tra
return util.StandardRequest{}, err
}
finalPrompt, toolNames := buildOpenAIFinalPromptWithPolicy(messagesRaw, req["tools"], traceID, toolPolicy)
if toolPolicy.IsNone() {
toolNames = nil
toolPolicy.Allowed = nil
} else {
toolNames = ensureToolDetectionEnabled(toolNames, req["tools"])
if !toolPolicy.IsNone() {
toolPolicy.Allowed = namesToSet(toolNames)
}
passThrough := collectOpenAIChatPassThrough(req)
Expand All @@ -98,6 +97,20 @@ func normalizeOpenAIResponsesRequest(store ConfigReader, req map[string]any, tra
}, nil
}

func ensureToolDetectionEnabled(toolNames []string, toolsRaw any) []string {
if len(toolNames) > 0 {
return toolNames
}
tools, _ := toolsRaw.([]any)
if len(tools) == 0 {
return toolNames
}
// Keep stream sieve/tool buffering enabled even when client tool schemas
// are malformed or lack explicit names; parsed tool payload names are no
// longer filtered by this list.
return []string{"__any_tool__"}
}

func collectOpenAIChatPassThrough(req map[string]any) map[string]any {
out := map[string]any{}
for _, k := range []string{
Expand Down
6 changes: 3 additions & 3 deletions internal/adapter/openai/standard_request_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ func TestNormalizeOpenAIResponsesRequestToolChoiceForcedUndeclaredFails(t *testi
}
}

func TestNormalizeOpenAIResponsesRequestToolChoiceNoneDisablesTools(t *testing.T) {
func TestNormalizeOpenAIResponsesRequestToolChoiceNoneKeepsToolDetectionEnabled(t *testing.T) {
store := newEmptyStoreForNormalizeTest(t)
req := map[string]any{
"model": "gpt-4o",
Expand All @@ -174,7 +174,7 @@ func TestNormalizeOpenAIResponsesRequestToolChoiceNoneDisablesTools(t *testing.T
if n.ToolChoice.Mode != util.ToolChoiceNone {
t.Fatalf("expected tool choice mode none, got %q", n.ToolChoice.Mode)
}
if len(n.ToolNames) != 0 {
t.Fatalf("expected no tool names when tool_choice=none, got %#v", n.ToolNames)
if len(n.ToolNames) == 0 {
t.Fatalf("expected tool detection sentinel when tool_choice=none, got %#v", n.ToolNames)
}
}
22 changes: 5 additions & 17 deletions internal/js/chat-stream/toolcall_policy.js
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,10 @@ const {

function resolveToolcallPolicy(prepBody, payloadTools) {
const preparedToolNames = normalizePreparedToolNames(prepBody && prepBody.tool_names);
const toolNames = preparedToolNames.length > 0 ? preparedToolNames : extractToolNames(payloadTools);
let toolNames = preparedToolNames.length > 0 ? preparedToolNames : extractToolNames(payloadTools);
if (toolNames.length === 0 && Array.isArray(payloadTools) && payloadTools.length > 0) {
toolNames = ['__any_tool__'];
}
const featureMatchEnabled = boolDefaultTrue(prepBody && prepBody.toolcall_feature_match);
const emitEarlyToolDeltas = featureMatchEnabled && boolDefaultTrue(prepBody && prepBody.toolcall_early_emit_high);
return {
Expand Down Expand Up @@ -76,17 +79,6 @@ function filterIncrementalToolCallDeltasByAllowed(deltas, allowedNames, seenName
return [];
}
const seen = seenNames instanceof Map ? seenNames : new Map();
const allowed = new Set((allowedNames || []).filter((name) => asString(name) !== ''));
if (allowed.size === 0) {
for (const d of deltas) {
if (d && typeof d === 'object' && asString(d.name)) {
const index = Number.isInteger(d.index) ? d.index : 0;
seen.set(index, '__blocked__');
}
}
return [];
}

const out = [];
for (const d of deltas) {
if (!d || typeof d !== 'object') {
Expand All @@ -95,16 +87,12 @@ function filterIncrementalToolCallDeltasByAllowed(deltas, allowedNames, seenName
const index = Number.isInteger(d.index) ? d.index : 0;
const name = asString(d.name);
if (name) {
if (!allowed.has(name)) {
seen.set(index, '__blocked__');
continue;
}
seen.set(index, name);
out.push(d);
continue;
}
const existing = asString(seen.get(index));
if (!existing || existing === '__blocked__') {
if (!existing) {
continue;
}
out.push(d);
Expand Down
Loading
Loading