Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 26 additions & 3 deletions core/src/main/java/com/google/adk/models/Gemini.java
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
Expand Down Expand Up @@ -236,6 +237,7 @@ public Flowable<LlmResponse> generateContent(LlmRequest llmRequest, boolean stre

static Flowable<LlmResponse> processRawResponses(Flowable<GenerateContentResponse> rawResponses) {
final StringBuilder accumulatedText = new StringBuilder();
final StringBuilder accumulatedThoughtText = new StringBuilder();
// Array to bypass final local variable reassignment in lambda.
final GenerateContentResponse[] lastRawResponseHolder = {null};
return rawResponses
Expand All @@ -246,15 +248,26 @@ static Flowable<LlmResponse> processRawResponses(Flowable<GenerateContentRespons

List<LlmResponse> responsesToEmit = new ArrayList<>();
LlmResponse currentProcessedLlmResponse = LlmResponse.create(rawResponse);
String currentTextChunk =
GeminiUtil.getTextFromLlmResponse(currentProcessedLlmResponse);
Optional<Part> part = GeminiUtil.getPart0FromLlmResponse(currentProcessedLlmResponse);
String currentTextChunk = part.flatMap(Part::text).orElse("");

if (!currentTextChunk.isEmpty()) {
accumulatedText.append(currentTextChunk);
if (part.get().thought().orElse(false)) {
accumulatedThoughtText.append(currentTextChunk);
} else {
accumulatedText.append(currentTextChunk);
}
LlmResponse partialResponse =
currentProcessedLlmResponse.toBuilder().partial(true).build();
responsesToEmit.add(partialResponse);
} else {
if (accumulatedThoughtText.length() > 0
&& GeminiUtil.shouldEmitAccumulatedText(currentProcessedLlmResponse)) {
LlmResponse aggregatedThoughtResponse =
thinkingResponseFromText(accumulatedThoughtText.toString());
responsesToEmit.add(aggregatedThoughtResponse);
accumulatedThoughtText.setLength(0);
}
if (accumulatedText.length() > 0
&& GeminiUtil.shouldEmitAccumulatedText(currentProcessedLlmResponse)) {
LlmResponse aggregatedTextResponse = responseFromText(accumulatedText.toString());
Expand Down Expand Up @@ -296,6 +309,16 @@ private static LlmResponse responseFromText(String accumulatedText) {
.build();
}

/**
 * Wraps aggregated thought text in a single-part model response.
 *
 * <p>The one part is created from the given text and has its {@code thought} flag set to
 * {@code true}; the enclosing content uses the {@code "model"} role.
 *
 * @param accumulatedThoughtText the concatenated thought text to wrap.
 * @return a response whose content holds that single thought part.
 */
private static LlmResponse thinkingResponseFromText(String accumulatedThoughtText) {
  // Build the part first so the thought flag is applied before it is attached to the content.
  Part thoughtPart = Part.fromText(accumulatedThoughtText).toBuilder().thought(true).build();
  Content modelContent = Content.builder().role("model").parts(thoughtPart).build();
  return LlmResponse.builder().content(modelContent).build();
}

@Override
public BaseLlmConnection connect(LlmRequest llmRequest) {
if (!apiClient.vertexAI()) {
Expand Down
33 changes: 8 additions & 25 deletions core/src/main/java/com/google/adk/models/GeminiUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import com.google.genai.types.FileData;
import com.google.genai.types.Part;
import java.util.List;
import java.util.Optional;
import java.util.stream.Stream;

/** Request / Response utilities for {@link Gemini}. */
Expand All @@ -41,7 +42,7 @@ private GeminiUtil() {}
* Prepares an {@link LlmRequest} for the GenerateContent API.
*
* <p>This method can optionally sanitize the request and ensures that the last content part is
* from the user to prompt a model response. It also strips out any parts marked as "thoughts".
* from the user to prompt a model response.
*
* @param llmRequest The original {@link LlmRequest}.
* @param sanitize Whether to sanitize the request to be compatible with the Gemini API backend.
Expand All @@ -53,8 +54,7 @@ public static LlmRequest prepareGenenerateContentRequest(
llmRequest = sanitizeRequestForGeminiApi(llmRequest);
}
List<Content> contents = ensureModelResponse(llmRequest.contents());
List<Content> finalContents = stripThoughts(contents);
return llmRequest.toBuilder().contents(finalContents).build();
return llmRequest.toBuilder().contents(contents).build();
}

/**
Expand Down Expand Up @@ -142,19 +142,17 @@ static List<Content> ensureModelResponse(List<Content> contents) {
}

/**
* Extracts text content from the first part of an LlmResponse, if available.
* Extracts the first part of an LlmResponse, if available.
*
* @param llmResponse The LlmResponse to extract text from.
* @return The text content, or an empty string if not found.
* @param llmResponse The LlmResponse to extract the first part from.
* @return The first part, or an empty optional if not found.
*/
public static String getTextFromLlmResponse(LlmResponse llmResponse) {
public static Optional<Part> getPart0FromLlmResponse(LlmResponse llmResponse) {
return llmResponse
.content()
.flatMap(Content::parts)
.filter(parts -> !parts.isEmpty())
.map(parts -> parts.get(0))
.flatMap(Part::text)
.orElse("");
.map(parts -> parts.get(0));
}

/**
Expand All @@ -177,19 +175,4 @@ public static boolean shouldEmitAccumulatedText(LlmResponse currentLlmResponse)
.flatMap(Part::inlineData)
.isEmpty();
}

/**
 * Returns a copy of the given contents in which every {@code Part} whose {@code thought} flag is
 * present and {@code true} has been dropped.
 *
 * <p>Parts with no {@code thought} flag, or with the flag set to {@code false}, are kept in their
 * original order. A content without parts is rebuilt with an empty parts list.
 *
 * @param originalContents the contents to filter; not modified.
 * @return an immutable list with one rebuilt {@code Content} per input element, in input order.
 */
public static List<Content> stripThoughts(List<Content> originalContents) {
  ImmutableList.Builder<Content> strippedContents = ImmutableList.builder();
  for (Content content : originalContents) {
    ImmutableList.Builder<Part> keptParts = ImmutableList.builder();
    for (Part part : content.parts().orElse(ImmutableList.of())) {
      // An absent flag counts as "not a thought", so only an explicit true is filtered out.
      if (!part.thought().orElse(false)) {
        keptParts.add(part);
      }
    }
    strippedContents.add(content.toBuilder().parts(keptParts.build()).build());
  }
  return strippedContents.build();
}
}
9 changes: 6 additions & 3 deletions core/src/test/java/com/google/adk/models/GeminiTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -139,15 +139,17 @@ private void assertLlmResponses(
private static Predicate<LlmResponse> isPartialTextResponse(String expectedText) {
return response -> {
assertThat(response.partial()).hasValue(true);
assertThat(GeminiUtil.getTextFromLlmResponse(response)).isEqualTo(expectedText);
assertThat(GeminiUtil.getPart0FromLlmResponse(response).flatMap(Part::text).orElse(""))
.isEqualTo(expectedText);
return true;
};
}

private static Predicate<LlmResponse> isFinalTextResponse(String expectedText) {
return response -> {
assertThat(response.partial()).isEmpty();
assertThat(GeminiUtil.getTextFromLlmResponse(response)).isEqualTo(expectedText);
assertThat(GeminiUtil.getPart0FromLlmResponse(response).flatMap(Part::text).orElse(""))
.isEqualTo(expectedText);
return true;
};
}
Expand All @@ -162,7 +164,8 @@ private static Predicate<LlmResponse> isFunctionCallResponse() {
private static Predicate<LlmResponse> isEmptyResponse() {
return response -> {
assertThat(response.partial()).isEmpty();
assertThat(GeminiUtil.getTextFromLlmResponse(response)).isEmpty();
assertThat(GeminiUtil.getPart0FromLlmResponse(response).flatMap(Part::text).orElse(""))
.isEmpty();
return true;
};
}
Expand Down
120 changes: 22 additions & 98 deletions core/src/test/java/com/google/adk/models/GeminiUtilTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -39,125 +39,49 @@ public final class GeminiUtilTest {
Content.fromParts(Part.fromText(GeminiUtil.CONTINUE_OUTPUT_MESSAGE));

@Test
public void stripThoughts_emptyList_returnsEmptyList() {
  // No contents in, no contents out.
  List<Content> stripped = GeminiUtil.stripThoughts(ImmutableList.of());

  assertThat(stripped).isEmpty();
}

@Test
public void stripThoughts_contentWithNoParts_returnsContentWithNoParts() {
  // A content built without any parts must come back equal to toContent() called with no parts.
  Content partlessContent = Content.builder().build();
  Content expectedContent = toContent();

  assertThat(GeminiUtil.stripThoughts(ImmutableList.of(partlessContent)))
      .containsExactly(expectedContent);
}

@Test
public void stripThoughts_partsWithoutThought_returnsAllParts() {
  // Plain text parts must all survive, in their original order.
  Part hello = createTextPart("Hello");
  Part world = createTextPart("World");

  List<Content> stripped = GeminiUtil.stripThoughts(ImmutableList.of(toContent(hello, world)));

  assertThat(stripped.get(0).parts().get()).containsExactly(hello, world).inOrder();
}

@Test
public void stripThoughts_partsWithThoughtFalse_returnsAllParts() {
  // An explicit thought=false flag must not cause the part to be removed.
  Part explicitlyNotThought = createThoughtPart("Regular text", false);
  Part plainText = createTextPart("Another text");

  List<Content> stripped =
      GeminiUtil.stripThoughts(ImmutableList.of(toContent(explicitlyNotThought, plainText)));

  assertThat(stripped.get(0).parts().get())
      .containsExactly(explicitlyNotThought, plainText)
      .inOrder();
}

@Test
public void stripThoughts_partsWithThoughtTrue_stripsThoughtParts() {
  // The thought part sits between two visible parts; only the visible ones should remain.
  Part visibleBefore = createTextPart("Visible text");
  Part thought = createThoughtPart("Internal thought", true);
  Part visibleAfter = createTextPart("More visible text");

  List<Content> stripped =
      GeminiUtil.stripThoughts(
          ImmutableList.of(toContent(visibleBefore, thought, visibleAfter)));

  assertThat(stripped.get(0).parts().get())
      .containsExactly(visibleBefore, visibleAfter)
      .inOrder();
}

@Test
public void stripThoughts_mixedParts_stripsOnlyThoughtTrue() {
  // Mix of unflagged, thought=false and thought=true parts: only the thought=true ones go.
  Part text1 = createTextPart("Text 1");
  Part thought1 = createThoughtPart("Thought 1", true);
  Part text2 = createTextPart("Text 2");
  Part flaggedFalse = createThoughtPart("Not a thought", false);
  Part thought2 = createThoughtPart("Thought 2", true);
  Content mixedContent = toContent(text1, thought1, text2, flaggedFalse, thought2);

  List<Content> stripped = GeminiUtil.stripThoughts(ImmutableList.of(mixedContent));

  assertThat(stripped.get(0).parts().get())
      .containsExactly(text1, text2, flaggedFalse)
      .inOrder();
}

@Test
public void stripThoughts_multipleContents_stripsThoughtsFromEach() {
Part partA1 = createTextPart("A1");
Part partA2 = createThoughtPart("A2 Thought", true);
Content contentA = toContent(partA1, partA2);

Part partB1 = createThoughtPart("B1 Thought", true);
Part partB2 = createTextPart("B2");
Part partB3 = createThoughtPart("B3 Not Thought", false);
Content contentB = toContent(partB1, partB2, partB3);

List<Content> result = GeminiUtil.stripThoughts(ImmutableList.of(contentA, contentB));
public void getPart0FromLlmResponse_noContent_returnsEmpty() {
LlmResponse llmResponse = LlmResponse.builder().build();

assertThat(result).hasSize(2);
assertThat(result.get(0).parts().get()).containsExactly(partA1);
assertThat(result.get(1).parts().get()).containsExactly(partB2, partB3).inOrder();
assertThat(GeminiUtil.getPart0FromLlmResponse(llmResponse)).isEmpty();
}

@Test
public void getTextFromLlmResponse_noContent_returnsEmptyString() {
LlmResponse llmResponse = LlmResponse.builder().build();
public void getPart0FromLlmResponse_contentWithNoParts_returnsEmpty() {
LlmResponse llmResponse = toResponse(Content.builder().build());

assertThat(GeminiUtil.getTextFromLlmResponse(llmResponse)).isEmpty();
assertThat(GeminiUtil.getPart0FromLlmResponse(llmResponse)).isEmpty();
}

@Test
public void getTextFromLlmResponse_contentWithNoParts_returnsEmptyString() {
LlmResponse llmResponse = toResponse(Content.builder().build());
public void getPart0FromLlmResponse_contentWithEmptyPartsList_returnsEmpty() {
LlmResponse llmResponse = toResponse(toContent());

assertThat(GeminiUtil.getTextFromLlmResponse(llmResponse)).isEmpty();
assertThat(GeminiUtil.getPart0FromLlmResponse(llmResponse)).isEmpty();
}

@Test
public void getTextFromLlmResponse_firstPartHasNoText_returnsEmptyString() {
Part part1 = Part.builder().inlineData(Blob.builder().mimeType("image/png").build()).build();
LlmResponse llmResponse = toResponse(part1);
public void getPart0FromLlmResponse_contentWithSinglePart_returnsFirstPart() {
Part expectedPart = createTextPart("Hello world");
LlmResponse llmResponse = toResponse(expectedPart);

assertThat(GeminiUtil.getTextFromLlmResponse(llmResponse)).isEmpty();
assertThat(GeminiUtil.getPart0FromLlmResponse(llmResponse)).hasValue(expectedPart);
}

@Test
public void getTextFromLlmResponse_firstPartHasText_returnsText() {
String expectedText = "The quick brown fox.";
Part part1 = createTextPart(expectedText);
LlmResponse llmResponse = toResponse(part1);
public void getPart0FromLlmResponse_contentWithMultipleParts_returnsFirstPart() {
Part firstPart = createTextPart("First part");
Part secondPart = createTextPart("Second part");
LlmResponse llmResponse = toResponse(firstPart, secondPart);

assertThat(GeminiUtil.getTextFromLlmResponse(llmResponse)).isEqualTo(expectedText);
assertThat(GeminiUtil.getPart0FromLlmResponse(llmResponse)).hasValue(firstPart);
}

@Test
public void getTextFromLlmResponse_multipleParts_returnsTextFromFirstPartOnly() {
String expectedText = "First part text.";
Part part1 = createTextPart(expectedText);
Part part2 = createTextPart("Second part text.");
LlmResponse llmResponse = toResponse(part1, part2);
public void getPart0FromLlmResponse_contentWithThoughtPart_returnsFirstPart() {
Part expectedPart = createThoughtPart("I need to think about this", true);
LlmResponse llmResponse = toResponse(expectedPart);

assertThat(GeminiUtil.getTextFromLlmResponse(llmResponse)).isEqualTo(expectedText);
assertThat(GeminiUtil.getPart0FromLlmResponse(llmResponse)).hasValue(expectedPart);
}

@Test
Expand Down