Skip to content

feat(anthropic): Add support for streaming thinking events #2800

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1226,8 +1226,11 @@ public record ContentBlockStartEvent(

@JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.EXISTING_PROPERTY, property = "type",
visible = true)
@JsonSubTypes({ @JsonSubTypes.Type(value = ContentBlockToolUse.class, name = "tool_use"),
@JsonSubTypes.Type(value = ContentBlockText.class, name = "text") })
@JsonSubTypes({
@JsonSubTypes.Type(value = ContentBlockToolUse.class, name = "tool_use"),
@JsonSubTypes.Type(value = ContentBlockText.class, name = "text"),
@JsonSubTypes.Type(value = ContentBlockThinking.class, name = "thinking")
})
public interface ContentBlockBody {
String type();
}
Expand Down Expand Up @@ -1257,6 +1260,19 @@ public record ContentBlockText(
@JsonProperty("type") String type,
@JsonProperty("text") String text) implements ContentBlockBody {
}

/**
* Thinking content block.
* @param type The content block type.
* @param thinking The thinking content.
*/
@JsonInclude(Include.NON_NULL)
public record ContentBlockThinking(
@JsonProperty("type") String type,
@JsonProperty("thinking") String thinking,
@JsonProperty("signature") String signature) implements ContentBlockBody {
}

}
// @formatter:on

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,29 +26,32 @@
import org.springframework.ai.anthropic.api.AnthropicApi.ContentBlockDeltaEvent;
import org.springframework.ai.anthropic.api.AnthropicApi.ContentBlockDeltaEvent.ContentBlockDeltaJson;
import org.springframework.ai.anthropic.api.AnthropicApi.ContentBlockDeltaEvent.ContentBlockDeltaText;
import org.springframework.ai.anthropic.api.AnthropicApi.ContentBlockDeltaEvent.ContentBlockDeltaThinking;
import org.springframework.ai.anthropic.api.AnthropicApi.ContentBlockDeltaEvent.ContentBlockDeltaSignature;
import org.springframework.ai.anthropic.api.AnthropicApi.ContentBlockStartEvent;
import org.springframework.ai.anthropic.api.AnthropicApi.ContentBlockStartEvent.ContentBlockText;
import org.springframework.ai.anthropic.api.AnthropicApi.ContentBlockStartEvent.ContentBlockToolUse;
import org.springframework.ai.anthropic.api.AnthropicApi.ContentBlockStartEvent.ContentBlockThinking;
import org.springframework.ai.anthropic.api.AnthropicApi.EventType;
import org.springframework.ai.anthropic.api.AnthropicApi.MessageDeltaEvent;
import org.springframework.ai.anthropic.api.AnthropicApi.MessageStartEvent;
import org.springframework.ai.anthropic.api.AnthropicApi.Role;
import org.springframework.ai.anthropic.api.AnthropicApi.StreamEvent;
import org.springframework.ai.anthropic.api.AnthropicApi.ToolUseAggregationEvent;
import org.springframework.ai.anthropic.api.AnthropicApi.Usage;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;

/**
* Helper class to support streaming function calling.
* Helper class to support streaming function calling and thinking events.
* <p>
* It can merge the streamed {@link StreamEvent} chunks in case of function calling
* message.
* message. It passes through other events like text, thinking, and signature deltas.
*
* @author Mariusz Bernacki
* @author Christian Tzolov
* @author Jihoon Kim
* @author Alexandros Pappas
* @since 1.0.0
*/
public class StreamHelper {
Expand All @@ -61,13 +64,16 @@ public boolean isToolUseStart(StreamEvent event) {
}

public boolean isToolUseFinish(StreamEvent event) {

if (event == null || event.type() == null || event.type() != EventType.CONTENT_BLOCK_STOP) {
return false;
}
return true;
// Tool use streaming sequence ends with a CONTENT_BLOCK_STOP event.
// The logic relies on the state machine (isInsideTool flag) managed in
// chatCompletionStream to know if this stop event corresponds to a tool use.
return event != null && event.type() != null && event.type() == EventType.CONTENT_BLOCK_STOP;
}

/**
* Merge the tool‑use related streaming events into one aggregate event so that the
* upper layers see a single ContentBlock with the full JSON input.
*/
public StreamEvent mergeToolUseEvents(StreamEvent previousEvent, StreamEvent event) {

ToolUseAggregationEvent eventAggregator = (ToolUseAggregationEvent) previousEvent;
Expand All @@ -76,8 +82,7 @@ public StreamEvent mergeToolUseEvents(StreamEvent previousEvent, StreamEvent eve
ContentBlockStartEvent contentBlockStart = (ContentBlockStartEvent) event;

if (ContentBlock.Type.TOOL_USE.getValue().equals(contentBlockStart.contentBlock().type())) {
ContentBlockStartEvent.ContentBlockToolUse cbToolUse = (ContentBlockToolUse) contentBlockStart
.contentBlock();
ContentBlockToolUse cbToolUse = (ContentBlockToolUse) contentBlockStart.contentBlock();

return eventAggregator.withIndex(contentBlockStart.index())
.withId(cbToolUse.id())
Expand All @@ -102,6 +107,14 @@ else if (event.type() == EventType.CONTENT_BLOCK_STOP) {
return event;
}

/**
* Converts a raw {@link StreamEvent} potentially containing tool use aggregates or
* other block types (text, thinking) into a {@link ChatCompletionResponse} chunk.
* @param event The incoming StreamEvent.
* @param contentBlockReference Holds the state of the response being built across
* multiple events.
* @return A ChatCompletionResponse representing the processed chunk.
*/
public ChatCompletionResponse eventToChatCompletionResponse(StreamEvent event,
AtomicReference<ChatCompletionResponseBuilder> contentBlockReference) {

Expand Down Expand Up @@ -135,28 +148,41 @@ else if (event.type().equals(EventType.TOOL_USE_AGGREGATE)) {
else if (event.type().equals(EventType.CONTENT_BLOCK_START)) {
ContentBlockStartEvent contentBlockStartEvent = (ContentBlockStartEvent) event;

Assert.isTrue(contentBlockStartEvent.contentBlock().type().equals("text"),
"The json content block should have been aggregated. Unsupported content block type: "
+ contentBlockStartEvent.contentBlock().type());

ContentBlockText contentBlockText = (ContentBlockText) contentBlockStartEvent.contentBlock();
ContentBlock contentBlock = new ContentBlock(Type.TEXT, null, contentBlockText.text(),
contentBlockStartEvent.index());
contentBlockReference.get().withType(event.type().name()).withContent(List.of(contentBlock));
if (contentBlockStartEvent.contentBlock() instanceof ContentBlockText textBlock) {
ContentBlock cb = new ContentBlock(Type.TEXT, null, textBlock.text(), contentBlockStartEvent.index());
contentBlockReference.get().withType(event.type().name()).withContent(List.of(cb));
}
else if (contentBlockStartEvent.contentBlock() instanceof ContentBlockThinking thinkingBlock) {
ContentBlock cb = new ContentBlock(Type.THINKING, null, null, contentBlockStartEvent.index(), null,
null, null, null, null, null, thinkingBlock.thinking(), null);
contentBlockReference.get().withType(event.type().name()).withContent(List.of(cb));
}
else {
throw new IllegalArgumentException(
"Unsupported content block type: " + contentBlockStartEvent.contentBlock().type());
}
}
else if (event.type().equals(EventType.CONTENT_BLOCK_DELTA)) {

ContentBlockDeltaEvent contentBlockDeltaEvent = (ContentBlockDeltaEvent) event;

Assert.isTrue(contentBlockDeltaEvent.delta().type().equals("text_delta"),
"The json content block delta should have been aggregated. Unsupported content block type: "
+ contentBlockDeltaEvent.delta().type());

ContentBlockDeltaText deltaTxt = (ContentBlockDeltaText) contentBlockDeltaEvent.delta();

var contentBlock = new ContentBlock(Type.TEXT_DELTA, null, deltaTxt.text(), contentBlockDeltaEvent.index());

contentBlockReference.get().withType(event.type().name()).withContent(List.of(contentBlock));
if (contentBlockDeltaEvent.delta() instanceof ContentBlockDeltaText txt) {
ContentBlock cb = new ContentBlock(Type.TEXT_DELTA, null, txt.text(), contentBlockDeltaEvent.index());
contentBlockReference.get().withType(event.type().name()).withContent(List.of(cb));
}
else if (contentBlockDeltaEvent.delta() instanceof ContentBlockDeltaThinking thinking) {
ContentBlock cb = new ContentBlock(Type.THINKING_DELTA, null, null, contentBlockDeltaEvent.index(),
null, null, null, null, null, null, thinking.thinking(), null);
contentBlockReference.get().withType(event.type().name()).withContent(List.of(cb));
}
else if (contentBlockDeltaEvent.delta() instanceof ContentBlockDeltaSignature sig) {
ContentBlock cb = new ContentBlock(Type.SIGNATURE_DELTA, null, null, contentBlockDeltaEvent.index(),
null, null, null, null, null, sig.signature(), null, null);
contentBlockReference.get().withType(event.type().name()).withContent(List.of(cb));
}
else {
throw new IllegalArgumentException(
"Unsupported content block delta type: " + contentBlockDeltaEvent.delta().type());
}
}
else if (event.type().equals(EventType.MESSAGE_DELTA)) {

Expand All @@ -173,21 +199,26 @@ else if (event.type().equals(EventType.MESSAGE_DELTA)) {
}

if (messageDeltaEvent.usage() != null) {
var totalUsage = new Usage(contentBlockReference.get().usage.inputTokens(),
Usage totalUsage = new Usage(contentBlockReference.get().usage.inputTokens(),
messageDeltaEvent.usage().outputTokens());
contentBlockReference.get().withUsage(totalUsage);
}
}
else if (event.type().equals(EventType.MESSAGE_STOP)) {
// pass through
// pass through as‑is
}
else {
// Any other event types that should propagate upwards without content
contentBlockReference.get().withType(event.type().name()).withContent(List.of());
}

return contentBlockReference.get().build();
}

/**
* Builder for {@link ChatCompletionResponse}. Used internally by {@link StreamHelper}
* to aggregate stream events.
*/
public static class ChatCompletionResponseBuilder {

private String type;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ void functionCallTest() {
assertThat(generation.getOutput().getText()).contains("30", "10", "15");
assertThat(response.getMetadata()).isNotNull();
assertThat(response.getMetadata().getUsage()).isNotNull();
assertThat(response.getMetadata().getUsage().getTotalTokens()).isLessThan(4000).isGreaterThan(1800);
assertThat(response.getMetadata().getUsage().getTotalTokens()).isLessThan(4000).isGreaterThan(100);
}

@Test
Expand Down Expand Up @@ -415,6 +415,38 @@ else if (message.getMetadata().containsKey("data")) { // redacted thinking
}
}

@Test
void thinkingWithStreamingTest() {
UserMessage userMessage = new UserMessage(
"Are there an infinite number of prime numbers such that n mod 4 == 3?");

var promptOptions = AnthropicChatOptions.builder()
.model(AnthropicApi.ChatModel.CLAUDE_3_7_SONNET.getName())
.temperature(1.0) // Temperature should be set to 1 when thinking is enabled
.maxTokens(8192)
.thinking(AnthropicApi.ThinkingType.ENABLED, 2048) // Must be ≥1024 && <
// max_tokens
.build();

Flux<ChatResponse> responseFlux = this.streamingChatModel
.stream(new Prompt(List.of(userMessage), promptOptions));

String content = responseFlux.collectList()
.block()
.stream()
.map(ChatResponse::getResults)
.flatMap(List::stream)
.map(Generation::getOutput)
.map(AssistantMessage::getText)
.filter(text -> text != null && !text.isBlank())
.collect(Collectors.joining());

logger.info("Response: {}", content);

assertThat(content).isNotBlank();
assertThat(content).contains("prime numbers");
}

@Test
void testToolUseContentBlock() {
UserMessage userMessage = new UserMessage(
Expand Down
Loading