Skip to content

Commit a5d184d

Browse files
committed
Add IT with Docker Model Runner
Docker Desktop 4.40 has released a Docker Model Runner, which is OpenAI compatible. Signed-off-by: Eddú Meléndez <[email protected]>
1 parent 8f20aab commit a5d184d

File tree

2 files changed

+364
-0
lines changed

2 files changed

+364
-0
lines changed

models/spring-ai-openai/pom.xml

+6
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,12 @@
132132
<scope>test</scope>
133133
</dependency>
134134

135+
<dependency>
136+
<groupId>io.rest-assured</groupId>
137+
<artifactId>rest-assured</artifactId>
138+
<scope>test</scope>
139+
</dependency>
140+
135141

136142
</dependencies>
137143

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,358 @@
1+
/*
2+
* Copyright 2023-2025 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.ai.openai.chat.proxy;
18+
19+
import java.io.IOException;
20+
import java.util.ArrayList;
21+
import java.util.Arrays;
22+
import java.util.List;
23+
import java.util.Map;
24+
import java.util.stream.Collectors;
25+
26+
import io.restassured.RestAssured;
27+
import org.junit.jupiter.api.BeforeAll;
28+
import org.junit.jupiter.api.Disabled;
29+
import org.junit.jupiter.api.Test;
30+
import org.slf4j.Logger;
31+
import org.slf4j.LoggerFactory;
32+
import org.testcontainers.containers.SocatContainer;
33+
import org.testcontainers.junit.jupiter.Container;
34+
import org.testcontainers.junit.jupiter.Testcontainers;
35+
import reactor.core.publisher.Flux;
36+
37+
import org.springframework.ai.chat.client.ChatClient;
38+
import org.springframework.ai.chat.messages.AssistantMessage;
39+
import org.springframework.ai.chat.messages.Message;
40+
import org.springframework.ai.chat.messages.UserMessage;
41+
import org.springframework.ai.chat.model.ChatResponse;
42+
import org.springframework.ai.chat.model.Generation;
43+
import org.springframework.ai.chat.prompt.Prompt;
44+
import org.springframework.ai.chat.prompt.PromptTemplate;
45+
import org.springframework.ai.chat.prompt.SystemPromptTemplate;
46+
import org.springframework.ai.converter.BeanOutputConverter;
47+
import org.springframework.ai.converter.ListOutputConverter;
48+
import org.springframework.ai.converter.MapOutputConverter;
49+
import org.springframework.ai.model.function.FunctionCallback;
50+
import org.springframework.ai.openai.OpenAiChatModel;
51+
import org.springframework.ai.openai.OpenAiChatOptions;
52+
import org.springframework.ai.openai.api.OpenAiApi;
53+
import org.springframework.ai.openai.api.tool.MockWeatherService;
54+
import org.springframework.ai.openai.chat.ActorsFilms;
55+
import org.springframework.beans.factory.annotation.Autowired;
56+
import org.springframework.beans.factory.annotation.Value;
57+
import org.springframework.boot.SpringBootConfiguration;
58+
import org.springframework.boot.test.context.SpringBootTest;
59+
import org.springframework.context.annotation.Bean;
60+
import org.springframework.core.convert.support.DefaultConversionService;
61+
import org.springframework.core.io.Resource;
62+
63+
import static org.assertj.core.api.Assertions.assertThat;
64+
65+
/**
66+
* @author Christian Tzolov
67+
* @author Eddú Meléndez
68+
* @since 1.0.0
69+
*/
70+
@Testcontainers
71+
@SpringBootTest(classes = DockerModelRunnerWithOpenAiChatModelIT.Config.class)
72+
@Disabled("Requires Docker Model Runner enabled. See https://docs.docker.com/desktop/features/model-runner/")
73+
class DockerModelRunnerWithOpenAiChatModelIT {
74+
75+
private static final Logger logger = LoggerFactory.getLogger(DockerModelRunnerWithOpenAiChatModelIT.class);
76+
77+
private static final String DEFAULT_MODEL = "ai/gemma3:4B-F16";
78+
79+
@Container
80+
private static final SocatContainer socat = new SocatContainer().withTarget(80, "model-runner.docker.internal");
81+
82+
@Value("classpath:/prompts/system-message.st")
83+
private Resource systemResource;
84+
85+
@Autowired
86+
private OpenAiChatModel chatModel;
87+
88+
@BeforeAll
89+
public static void beforeAll() throws IOException, InterruptedException {
90+
logger.info("Start pulling the '" + DEFAULT_MODEL + "' generative ... would take several minutes ...");
91+
92+
String baseUrl = "http://%s:%d".formatted(socat.getHost(), socat.getMappedPort(80));
93+
94+
RestAssured.given().baseUri(baseUrl).body("""
95+
{
96+
"from": "%s"
97+
}
98+
""".formatted(DEFAULT_MODEL)).post("/models/create").prettyPeek().then().statusCode(200);
99+
100+
logger.info(DEFAULT_MODEL + " pulling competed!");
101+
}
102+
103+
@Test
104+
void roleTest() {
105+
UserMessage userMessage = new UserMessage(
106+
"Tell me about 3 famous pirates from the Golden Age of Piracy and what they did.");
107+
SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(this.systemResource);
108+
Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", "Bob", "voice", "pirate"));
109+
Prompt prompt = new Prompt(List.of(userMessage, systemMessage));
110+
ChatResponse response = this.chatModel.call(prompt);
111+
assertThat(response.getResults()).hasSize(1);
112+
assertThat(response.getResults().get(0).getOutput().getText()).contains("Blackbeard");
113+
}
114+
115+
@Test
116+
void streamRoleTest() {
117+
UserMessage userMessage = new UserMessage(
118+
"Tell me about 3 famous pirates from the Golden Age of Piracy and what they did.");
119+
SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(this.systemResource);
120+
Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", "Bob", "voice", "pirate"));
121+
Prompt prompt = new Prompt(List.of(userMessage, systemMessage));
122+
Flux<ChatResponse> flux = this.chatModel.stream(prompt);
123+
124+
List<ChatResponse> responses = flux.collectList().block();
125+
assertThat(responses.size()).isGreaterThan(1);
126+
127+
String stitchedResponseContent = responses.stream()
128+
.map(ChatResponse::getResults)
129+
.flatMap(List::stream)
130+
.map(Generation::getOutput)
131+
.map(AssistantMessage::getText)
132+
.collect(Collectors.joining());
133+
134+
assertThat(stitchedResponseContent).contains("Blackbeard");
135+
}
136+
137+
@Test
138+
void streamingWithTokenUsage() {
139+
var promptOptions = OpenAiChatOptions.builder().streamUsage(true).seed(1).build();
140+
141+
var prompt = new Prompt("List two colors of the Polish flag. Be brief.", promptOptions);
142+
143+
var streamingTokenUsage = this.chatModel.stream(prompt).blockLast().getMetadata().getUsage();
144+
var referenceTokenUsage = this.chatModel.call(prompt).getMetadata().getUsage();
145+
146+
assertThat(streamingTokenUsage.getPromptTokens()).isGreaterThan(0);
147+
assertThat(streamingTokenUsage.getCompletionTokens()).isGreaterThan(0);
148+
assertThat(streamingTokenUsage.getTotalTokens()).isGreaterThan(0);
149+
150+
assertThat(streamingTokenUsage.getPromptTokens()).isEqualTo(referenceTokenUsage.getPromptTokens());
151+
assertThat(streamingTokenUsage.getCompletionTokens()).isEqualTo(referenceTokenUsage.getCompletionTokens());
152+
assertThat(streamingTokenUsage.getTotalTokens()).isEqualTo(referenceTokenUsage.getTotalTokens());
153+
154+
}
155+
156+
@Test
157+
void listOutputConverter() {
158+
DefaultConversionService conversionService = new DefaultConversionService();
159+
ListOutputConverter outputConverter = new ListOutputConverter(conversionService);
160+
161+
String format = outputConverter.getFormat();
162+
String template = """
163+
List five {subject}
164+
{format}
165+
""";
166+
PromptTemplate promptTemplate = new PromptTemplate(template,
167+
Map.of("subject", "ice cream flavors", "format", format));
168+
Prompt prompt = new Prompt(promptTemplate.createMessage());
169+
Generation generation = this.chatModel.call(prompt).getResult();
170+
171+
List<String> list = outputConverter.convert(generation.getOutput().getText());
172+
assertThat(list).hasSize(5);
173+
174+
}
175+
176+
@Test
177+
void mapOutputConverter() {
178+
MapOutputConverter outputConverter = new MapOutputConverter();
179+
180+
String format = outputConverter.getFormat();
181+
String template = """
182+
Provide me a List of {subject}
183+
{format}
184+
""";
185+
PromptTemplate promptTemplate = new PromptTemplate(template,
186+
Map.of("subject", "numbers from 1 to 9 under they key name 'numbers'", "format", format));
187+
Prompt prompt = new Prompt(promptTemplate.createMessage());
188+
Generation generation = this.chatModel.call(prompt).getResult();
189+
190+
Map<String, Object> result = outputConverter.convert(generation.getOutput().getText());
191+
assertThat(result.get("numbers")).isEqualTo(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9));
192+
193+
}
194+
195+
@Test
196+
void beanOutputConverter() {
197+
198+
BeanOutputConverter<ActorsFilms> outputConverter = new BeanOutputConverter<>(ActorsFilms.class);
199+
200+
String format = outputConverter.getFormat();
201+
String template = """
202+
Generate the filmography for a random actor.
203+
{format}
204+
""";
205+
PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format));
206+
Prompt prompt = new Prompt(promptTemplate.createMessage());
207+
Generation generation = this.chatModel.call(prompt).getResult();
208+
209+
ActorsFilms actorsFilms = outputConverter.convert(generation.getOutput().getText());
210+
assertThat(actorsFilms.getActor()).isNotEmpty();
211+
}
212+
213+
@Test
214+
void beanOutputConverterRecords() {
215+
216+
BeanOutputConverter<ActorsFilmsRecord> outputConverter = new BeanOutputConverter<>(ActorsFilmsRecord.class);
217+
218+
String format = outputConverter.getFormat();
219+
String template = """
220+
Generate the filmography of 5 movies for Tom Hanks.
221+
{format}
222+
""";
223+
PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format));
224+
Prompt prompt = new Prompt(promptTemplate.createMessage());
225+
Generation generation = this.chatModel.call(prompt).getResult();
226+
227+
ActorsFilmsRecord actorsFilms = outputConverter.convert(generation.getOutput().getText());
228+
logger.info("" + actorsFilms);
229+
assertThat(actorsFilms.actor()).isEqualTo("Tom Hanks");
230+
assertThat(actorsFilms.movies()).hasSize(5);
231+
}
232+
233+
@Test
234+
void beanStreamOutputConverterRecords() {
235+
236+
BeanOutputConverter<ActorsFilmsRecord> outputConverter = new BeanOutputConverter<>(ActorsFilmsRecord.class);
237+
238+
String format = outputConverter.getFormat();
239+
String template = """
240+
Generate the filmography of 5 movies for Tom Hanks.
241+
{format}
242+
""";
243+
PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format));
244+
Prompt prompt = new Prompt(promptTemplate.createMessage());
245+
246+
String generationTextFromStream = this.chatModel.stream(prompt)
247+
.collectList()
248+
.block()
249+
.stream()
250+
.map(ChatResponse::getResults)
251+
.flatMap(List::stream)
252+
.map(Generation::getOutput)
253+
.map(AssistantMessage::getText)
254+
.filter(c -> c != null)
255+
.collect(Collectors.joining());
256+
257+
ActorsFilmsRecord actorsFilms = outputConverter.convert(generationTextFromStream);
258+
logger.info("" + actorsFilms);
259+
assertThat(actorsFilms.actor()).isEqualTo("Tom Hanks");
260+
assertThat(actorsFilms.movies()).hasSize(5);
261+
}
262+
263+
@Test
264+
void functionCallTest() {
265+
266+
UserMessage userMessage = new UserMessage("What's the weather like in San Francisco, Tokyo, and Paris?");
267+
268+
List<Message> messages = new ArrayList<>(List.of(userMessage));
269+
270+
var promptOptions = OpenAiChatOptions.builder()
271+
.functionCallbacks(List.of(FunctionCallback.builder()
272+
.function("getCurrentWeather", new MockWeatherService())
273+
.description("Get the weather in location")
274+
.inputType(MockWeatherService.Request.class)
275+
.build()))
276+
.build();
277+
278+
ChatResponse response = this.chatModel.call(new Prompt(messages, promptOptions));
279+
280+
logger.info("Response: {}", response);
281+
282+
assertThat(response.getResult().getOutput().getText()).contains("30", "10", "15");
283+
}
284+
285+
@Test
286+
@Disabled("stream function call not supported yet")
287+
void streamFunctionCallTest() {
288+
289+
UserMessage userMessage = new UserMessage(
290+
"What's the weather like in San Francisco, Tokyo, and Paris? Return the temperature in Celsius.");
291+
292+
List<Message> messages = new ArrayList<>(List.of(userMessage));
293+
294+
var promptOptions = OpenAiChatOptions.builder()
295+
.functionCallbacks(List.of(FunctionCallback.builder()
296+
.function("getCurrentWeather", new MockWeatherService())
297+
.description("Get the weather in location")
298+
.inputType(MockWeatherService.Request.class)
299+
.build()))
300+
.build();
301+
302+
Flux<ChatResponse> response = this.chatModel.stream(new Prompt(messages, promptOptions));
303+
304+
String content = response.collectList()
305+
.block()
306+
.stream()
307+
.map(ChatResponse::getResults)
308+
.flatMap(List::stream)
309+
.map(Generation::getOutput)
310+
.map(AssistantMessage::getText)
311+
.collect(Collectors.joining());
312+
logger.info("Response: {}", content);
313+
314+
assertThat(content).contains("30", "10", "15");
315+
}
316+
317+
@Test
318+
void validateCallResponseMetadata() {
319+
// @formatter:off
320+
ChatResponse response = ChatClient.create(this.chatModel).prompt()
321+
.options(OpenAiChatOptions.builder().model(DEFAULT_MODEL).build())
322+
.user("Tell me about 3 famous pirates from the Golden Age of Piracy and what they did")
323+
.call()
324+
.chatResponse();
325+
// @formatter:on
326+
327+
logger.info(response.toString());
328+
assertThat(response.getMetadata().getId()).isNotEmpty();
329+
assertThat(response.getMetadata().getModel()).containsIgnoringCase(DEFAULT_MODEL);
330+
assertThat(response.getMetadata().getUsage().getPromptTokens()).isPositive();
331+
assertThat(response.getMetadata().getUsage().getCompletionTokens()).isPositive();
332+
assertThat(response.getMetadata().getUsage().getTotalTokens()).isPositive();
333+
}
334+
335+
/**
 * Conversion target for the record-based output converter tests.
 *
 * @param actor name of the actor
 * @param movies movie titles attributed to the actor
 */
record ActorsFilmsRecord(String actor, List<String> movies) {
}
338+
339+
@SpringBootConfiguration
340+
static class Config {
341+
342+
@Bean
343+
public OpenAiApi chatCompletionApi() {
344+
var baseUrl = "http://%s:%d/engines".formatted(socat.getHost(), socat.getMappedPort(80));
345+
return OpenAiApi.builder().baseUrl(baseUrl).apiKey("test").build();
346+
}
347+
348+
@Bean
349+
public OpenAiChatModel openAiClient(OpenAiApi openAiApi) {
350+
return OpenAiChatModel.builder()
351+
.openAiApi(openAiApi)
352+
.defaultOptions(OpenAiChatOptions.builder().maxTokens(2048).model(DEFAULT_MODEL).build())
353+
.build();
354+
}
355+
356+
}
357+
358+
}

0 commit comments

Comments
 (0)