Skip to content

Commit 33ac74f

Browse files
committed
fix: Gemini thoughts not correctly accumulated when streaming enabled
1 parent 391e049 commit 33ac74f

File tree

4 files changed

+62
-129
lines changed

4 files changed

+62
-129
lines changed

core/src/main/java/com/google/adk/models/Gemini.java

Lines changed: 26 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -35,6 +35,7 @@
3535
import java.util.ArrayList;
3636
import java.util.List;
3737
import java.util.Objects;
38+
import java.util.Optional;
3839
import java.util.concurrent.CompletableFuture;
3940
import org.slf4j.Logger;
4041
import org.slf4j.LoggerFactory;
@@ -236,6 +237,7 @@ public Flowable<LlmResponse> generateContent(LlmRequest llmRequest, boolean stre
236237

237238
static Flowable<LlmResponse> processRawResponses(Flowable<GenerateContentResponse> rawResponses) {
238239
final StringBuilder accumulatedText = new StringBuilder();
240+
final StringBuilder accumulatedThoughtText = new StringBuilder();
239241
// Array to bypass final local variable reassignment in lambda.
240242
final GenerateContentResponse[] lastRawResponseHolder = {null};
241243
return rawResponses
@@ -246,15 +248,26 @@ static Flowable<LlmResponse> processRawResponses(Flowable<GenerateContentRespons
246248

247249
List<LlmResponse> responsesToEmit = new ArrayList<>();
248250
LlmResponse currentProcessedLlmResponse = LlmResponse.create(rawResponse);
249-
String currentTextChunk =
250-
GeminiUtil.getTextFromLlmResponse(currentProcessedLlmResponse);
251+
Optional<Part> part = GeminiUtil.getPart0FromLlmResponse(currentProcessedLlmResponse);
252+
String currentTextChunk = part.flatMap(Part::text).orElse("");
251253

252254
if (!currentTextChunk.isEmpty()) {
253-
accumulatedText.append(currentTextChunk);
255+
if (part.get().thought().orElse(false)) {
256+
accumulatedThoughtText.append(currentTextChunk);
257+
} else {
258+
accumulatedText.append(currentTextChunk);
259+
}
254260
LlmResponse partialResponse =
255261
currentProcessedLlmResponse.toBuilder().partial(true).build();
256262
responsesToEmit.add(partialResponse);
257263
} else {
264+
if (accumulatedThoughtText.length() > 0
265+
&& GeminiUtil.shouldEmitAccumulatedText(currentProcessedLlmResponse)) {
266+
LlmResponse aggregatedTextResponse =
267+
thinkingResponseFromText(accumulatedThoughtText.toString());
268+
responsesToEmit.add(aggregatedTextResponse);
269+
accumulatedThoughtText.setLength(0);
270+
}
258271
if (accumulatedText.length() > 0
259272
&& GeminiUtil.shouldEmitAccumulatedText(currentProcessedLlmResponse)) {
260273
LlmResponse aggregatedTextResponse = responseFromText(accumulatedText.toString());
@@ -296,6 +309,16 @@ private static LlmResponse responseFromText(String accumulatedText) {
296309
.build();
297310
}
298311

312+
private static LlmResponse thinkingResponseFromText(String accumulatedThoughtText) {
313+
return LlmResponse.builder()
314+
.content(
315+
Content.builder()
316+
.role("model")
317+
.parts(Part.fromText(accumulatedThoughtText).toBuilder().thought(true).build())
318+
.build())
319+
.build();
320+
}
321+
299322
@Override
300323
public BaseLlmConnection connect(LlmRequest llmRequest) {
301324
if (!apiClient.vertexAI()) {

core/src/main/java/com/google/adk/models/GeminiUtil.java

Lines changed: 8 additions & 25 deletions
Original file line number | Diff line number | Diff line change
@@ -25,6 +25,7 @@
2525
import com.google.genai.types.FileData;
2626
import com.google.genai.types.Part;
2727
import java.util.List;
28+
import java.util.Optional;
2829
import java.util.stream.Stream;
2930

3031
/** Request / Response utilities for {@link Gemini}. */
@@ -40,7 +41,7 @@ private GeminiUtil() {}
4041
* Prepares an {@link LlmRequest} for the GenerateContent API.
4142
*
4243
* <p>This method can optionally sanitize the request and ensures that the last content part is
43-
* from the user to prompt a model response. It also strips out any parts marked as "thoughts".
44+
* from the user to prompt a model response.
4445
*
4546
* @param llmRequest The original {@link LlmRequest}.
4647
* @param sanitize Whether to sanitize the request to be compatible with the Gemini API backend.
@@ -52,8 +53,7 @@ public static LlmRequest prepareGenenerateContentRequest(
5253
llmRequest = sanitizeRequestForGeminiApi(llmRequest);
5354
}
5455
List<Content> contents = ensureModelResponse(llmRequest.contents());
55-
List<Content> finalContents = stripThoughts(contents);
56-
return llmRequest.toBuilder().contents(finalContents).build();
56+
return llmRequest.toBuilder().contents(contents).build();
5757
}
5858

5959
/**
@@ -140,19 +140,17 @@ static List<Content> ensureModelResponse(List<Content> contents) {
140140
}
141141

142142
/**
143-
* Extracts text content from the first part of an LlmResponse, if available.
143+
* Extracts the first part of an LlmResponse, if available.
144144
*
145-
* @param llmResponse The LlmResponse to extract text from.
146-
* @return The text content, or an empty string if not found.
145+
* @param llmResponse The LlmResponse to extract the first part from.
146+
* @return The first part, or an empty optional if not found.
147147
*/
148-
public static String getTextFromLlmResponse(LlmResponse llmResponse) {
148+
public static Optional<Part> getPart0FromLlmResponse(LlmResponse llmResponse) {
149149
return llmResponse
150150
.content()
151151
.flatMap(Content::parts)
152152
.filter(parts -> !parts.isEmpty())
153-
.map(parts -> parts.get(0))
154-
.flatMap(Part::text)
155-
.orElse("");
153+
.map(parts -> parts.get(0));
156154
}
157155

158156
/**
@@ -175,19 +173,4 @@ public static boolean shouldEmitAccumulatedText(LlmResponse currentLlmResponse)
175173
.flatMap(Part::inlineData)
176174
.isEmpty();
177175
}
178-
179-
/** Removes any `Part` that contains only a `thought` from the content list. */
180-
public static List<Content> stripThoughts(List<Content> originalContents) {
181-
return originalContents.stream()
182-
.map(
183-
content -> {
184-
ImmutableList<Part> nonThoughtParts =
185-
content.parts().orElse(ImmutableList.of()).stream()
186-
// Keep if thought is not present OR if thought is present but false
187-
.filter(part -> part.thought().map(isThought -> !isThought).orElse(true))
188-
.collect(toImmutableList());
189-
return content.toBuilder().parts(nonThoughtParts).build();
190-
})
191-
.collect(toImmutableList());
192-
}
193176
}

core/src/test/java/com/google/adk/models/GeminiTest.java

Lines changed: 6 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -139,15 +139,17 @@ private void assertLlmResponses(
139139
private static Predicate<LlmResponse> isPartialTextResponse(String expectedText) {
140140
return response -> {
141141
assertThat(response.partial()).hasValue(true);
142-
assertThat(GeminiUtil.getTextFromLlmResponse(response)).isEqualTo(expectedText);
142+
assertThat(GeminiUtil.getPart0FromLlmResponse(response).flatMap(Part::text).orElse(""))
143+
.isEqualTo(expectedText);
143144
return true;
144145
};
145146
}
146147

147148
private static Predicate<LlmResponse> isFinalTextResponse(String expectedText) {
148149
return response -> {
149150
assertThat(response.partial()).isEmpty();
150-
assertThat(GeminiUtil.getTextFromLlmResponse(response)).isEqualTo(expectedText);
151+
assertThat(GeminiUtil.getPart0FromLlmResponse(response).flatMap(Part::text).orElse(""))
152+
.isEqualTo(expectedText);
151153
return true;
152154
};
153155
}
@@ -162,7 +164,8 @@ private static Predicate<LlmResponse> isFunctionCallResponse() {
162164
private static Predicate<LlmResponse> isEmptyResponse() {
163165
return response -> {
164166
assertThat(response.partial()).isEmpty();
165-
assertThat(GeminiUtil.getTextFromLlmResponse(response)).isEmpty();
167+
assertThat(GeminiUtil.getPart0FromLlmResponse(response).flatMap(Part::text).orElse(""))
168+
.isEmpty();
166169
return true;
167170
};
168171
}

core/src/test/java/com/google/adk/models/GeminiUtilTest.java

Lines changed: 22 additions & 98 deletions
Original file line number | Diff line number | Diff line change
@@ -39,125 +39,49 @@ public final class GeminiUtilTest {
3939
Content.fromParts(Part.fromText(GeminiUtil.CONTINUE_OUTPUT_MESSAGE));
4040

4141
@Test
42-
public void stripThoughts_emptyList_returnsEmptyList() {
43-
assertThat(GeminiUtil.stripThoughts(ImmutableList.of())).isEmpty();
44-
}
45-
46-
@Test
47-
public void stripThoughts_contentWithNoParts_returnsContentWithNoParts() {
48-
Content content = Content.builder().build();
49-
Content expected = toContent();
50-
51-
List<Content> result = GeminiUtil.stripThoughts(ImmutableList.of(content));
52-
53-
assertThat(result).containsExactly(expected);
54-
}
55-
56-
@Test
57-
public void stripThoughts_partsWithoutThought_returnsAllParts() {
58-
Part part1 = createTextPart("Hello");
59-
Part part2 = createTextPart("World");
60-
Content content = toContent(part1, part2);
61-
62-
List<Content> result = GeminiUtil.stripThoughts(ImmutableList.of(content));
63-
64-
assertThat(result.get(0).parts().get()).containsExactly(part1, part2).inOrder();
65-
}
66-
67-
@Test
68-
public void stripThoughts_partsWithThoughtFalse_returnsAllParts() {
69-
Part part1 = createThoughtPart("Regular text", false);
70-
Part part2 = createTextPart("Another text");
71-
Content content = toContent(part1, part2);
72-
73-
List<Content> result = GeminiUtil.stripThoughts(ImmutableList.of(content));
74-
75-
assertThat(result.get(0).parts().get()).containsExactly(part1, part2).inOrder();
76-
}
77-
78-
@Test
79-
public void stripThoughts_partsWithThoughtTrue_stripsThoughtParts() {
80-
Part part1 = createTextPart("Visible text");
81-
Part part2 = createThoughtPart("Internal thought", true);
82-
Part part3 = createTextPart("More visible text");
83-
Content content = toContent(part1, part2, part3);
84-
85-
List<Content> result = GeminiUtil.stripThoughts(ImmutableList.of(content));
86-
87-
assertThat(result.get(0).parts().get()).containsExactly(part1, part3).inOrder();
88-
}
89-
90-
@Test
91-
public void stripThoughts_mixedParts_stripsOnlyThoughtTrue() {
92-
Part part1 = createTextPart("Text 1");
93-
Part part2 = createThoughtPart("Thought 1", true);
94-
Part part3 = createTextPart("Text 2");
95-
Part part4 = createThoughtPart("Not a thought", false);
96-
Part part5 = createThoughtPart("Thought 2", true);
97-
Content content = toContent(part1, part2, part3, part4, part5);
98-
99-
List<Content> result = GeminiUtil.stripThoughts(ImmutableList.of(content));
100-
101-
assertThat(result.get(0).parts().get()).containsExactly(part1, part3, part4).inOrder();
102-
}
103-
104-
@Test
105-
public void stripThoughts_multipleContents_stripsThoughtsFromEach() {
106-
Part partA1 = createTextPart("A1");
107-
Part partA2 = createThoughtPart("A2 Thought", true);
108-
Content contentA = toContent(partA1, partA2);
109-
110-
Part partB1 = createThoughtPart("B1 Thought", true);
111-
Part partB2 = createTextPart("B2");
112-
Part partB3 = createThoughtPart("B3 Not Thought", false);
113-
Content contentB = toContent(partB1, partB2, partB3);
114-
115-
List<Content> result = GeminiUtil.stripThoughts(ImmutableList.of(contentA, contentB));
42+
public void getPart0FromLlmResponse_noContent_returnsEmpty() {
43+
LlmResponse llmResponse = LlmResponse.builder().build();
11644

117-
assertThat(result).hasSize(2);
118-
assertThat(result.get(0).parts().get()).containsExactly(partA1);
119-
assertThat(result.get(1).parts().get()).containsExactly(partB2, partB3).inOrder();
45+
assertThat(GeminiUtil.getPart0FromLlmResponse(llmResponse)).isEmpty();
12046
}
12147

12248
@Test
123-
public void getTextFromLlmResponse_noContent_returnsEmptyString() {
124-
LlmResponse llmResponse = LlmResponse.builder().build();
49+
public void getPart0FromLlmResponse_contentWithNoParts_returnsEmpty() {
50+
LlmResponse llmResponse = toResponse(Content.builder().build());
12551

126-
assertThat(GeminiUtil.getTextFromLlmResponse(llmResponse)).isEmpty();
52+
assertThat(GeminiUtil.getPart0FromLlmResponse(llmResponse)).isEmpty();
12753
}
12854

12955
@Test
130-
public void getTextFromLlmResponse_contentWithNoParts_returnsEmptyString() {
131-
LlmResponse llmResponse = toResponse(Content.builder().build());
56+
public void getPart0FromLlmResponse_contentWithEmptyPartsList_returnsEmpty() {
57+
LlmResponse llmResponse = toResponse(toContent());
13258

133-
assertThat(GeminiUtil.getTextFromLlmResponse(llmResponse)).isEmpty();
59+
assertThat(GeminiUtil.getPart0FromLlmResponse(llmResponse)).isEmpty();
13460
}
13561

13662
@Test
137-
public void getTextFromLlmResponse_firstPartHasNoText_returnsEmptyString() {
138-
Part part1 = Part.builder().inlineData(Blob.builder().mimeType("image/png").build()).build();
139-
LlmResponse llmResponse = toResponse(part1);
63+
public void getPart0FromLlmResponse_contentWithSinglePart_returnsFirstPart() {
64+
Part expectedPart = createTextPart("Hello world");
65+
LlmResponse llmResponse = toResponse(expectedPart);
14066

141-
assertThat(GeminiUtil.getTextFromLlmResponse(llmResponse)).isEmpty();
67+
assertThat(GeminiUtil.getPart0FromLlmResponse(llmResponse)).hasValue(expectedPart);
14268
}
14369

14470
@Test
145-
public void getTextFromLlmResponse_firstPartHasText_returnsText() {
146-
String expectedText = "The quick brown fox.";
147-
Part part1 = createTextPart(expectedText);
148-
LlmResponse llmResponse = toResponse(part1);
71+
public void getPart0FromLlmResponse_contentWithMultipleParts_returnsFirstPart() {
72+
Part firstPart = createTextPart("First part");
73+
Part secondPart = createTextPart("Second part");
74+
LlmResponse llmResponse = toResponse(firstPart, secondPart);
14975

150-
assertThat(GeminiUtil.getTextFromLlmResponse(llmResponse)).isEqualTo(expectedText);
76+
assertThat(GeminiUtil.getPart0FromLlmResponse(llmResponse)).hasValue(firstPart);
15177
}
15278

15379
@Test
154-
public void getTextFromLlmResponse_multipleParts_returnsTextFromFirstPartOnly() {
155-
String expectedText = "First part text.";
156-
Part part1 = createTextPart(expectedText);
157-
Part part2 = createTextPart("Second part text.");
158-
LlmResponse llmResponse = toResponse(part1, part2);
80+
public void getPart0FromLlmResponse_contentWithThoughtPart_returnsFirstPart() {
81+
Part expectedPart = createThoughtPart("I need to think about this", true);
82+
LlmResponse llmResponse = toResponse(expectedPart);
15983

160-
assertThat(GeminiUtil.getTextFromLlmResponse(llmResponse)).isEqualTo(expectedText);
84+
assertThat(GeminiUtil.getPart0FromLlmResponse(llmResponse)).hasValue(expectedPart);
16185
}
16286

16387
@Test

0 commit comments

Comments
 (0)