Skip to content

Commit 916a5c0

Browse files
committed
feat: Integrate @RequiresConsent with manual tool execution mode
- Modified DefaultToolCallingManager to check for ConsentAwareToolCallback instances - Updated ConsentAwareToolCallback to use ToolCallback interface correctly - Added tests to verify consent works in manual execution mode This ensures @RequiresConsent annotation is respected in both automatic and manual modes when internalToolExecutionEnabled=false. Signed-off-by: Hyunjoon Park <[email protected]>
1 parent aeb3b9d commit 916a5c0

File tree

3 files changed

+296
-14
lines changed

3 files changed

+296
-14
lines changed

spring-ai-model/src/main/java/org/springframework/ai/model/tool/DefaultToolCallingManager.java

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
import org.springframework.ai.chat.model.ToolContext;
3535
import org.springframework.ai.chat.prompt.Prompt;
3636
import org.springframework.ai.tool.ToolCallback;
37+
import org.springframework.ai.tool.consent.ConsentAwareToolCallback;
3738
import org.springframework.ai.tool.definition.ToolDefinition;
3839
import org.springframework.ai.tool.execution.DefaultToolExecutionExceptionProcessor;
3940
import org.springframework.ai.tool.execution.ToolExecutionException;
@@ -217,7 +218,12 @@ private InternalToolExecutionResult executeToolCall(Prompt prompt, AssistantMess
217218
.observe(() -> {
218219
String toolResult;
219220
try {
220-
toolResult = toolCallback.call(toolInputArguments, toolContext);
221+
if (toolCallback instanceof ConsentAwareToolCallback consentAwareCallback) {
222+
toolResult = consentAwareCallback.call(toolInputArguments, toolContext);
223+
}
224+
else {
225+
toolResult = toolCallback.call(toolInputArguments, toolContext);
226+
}
221227
}
222228
catch (ToolExecutionException ex) {
223229
toolResult = this.toolExecutionExceptionProcessor.process(ex);

spring-ai-model/src/main/java/org/springframework/ai/tool/consent/ConsentAwareToolCallback.java

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,14 @@
2020
import java.util.regex.Matcher;
2121
import java.util.regex.Pattern;
2222

23+
import com.fasterxml.jackson.core.type.TypeReference;
24+
import com.fasterxml.jackson.databind.ObjectMapper;
25+
2326
import org.springframework.ai.tool.ToolCallback;
2427
import org.springframework.ai.tool.annotation.RequiresConsent;
2528
import org.springframework.ai.tool.consent.exception.ConsentDeniedException;
2629
import org.springframework.ai.tool.definition.ToolDefinition;
30+
import org.springframework.ai.tool.metadata.ToolMetadata;
2731
import org.springframework.util.Assert;
2832

2933
/**
@@ -59,14 +63,25 @@ public ConsentAwareToolCallback(ToolCallback delegate, ConsentManager consentMan
5963
this.requiresConsent = requiresConsent;
6064
}
6165

66+
private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();
67+
6268
@Override
63-
public Object call(Map<String, Object> parameters) {
64-
String toolName = getName();
69+
public String call(String toolInput) {
70+
// Parse input JSON to get parameters
71+
Map<String, Object> parameters;
72+
try {
73+
parameters = OBJECT_MAPPER.readValue(toolInput, new TypeReference<Map<String, Object>>() {
74+
});
75+
}
76+
catch (Exception e) {
77+
parameters = Map.of();
78+
}
79+
String toolName = this.delegate.getToolDefinition().name();
6580

6681
// Check if consent was already granted based on consent level
6782
if (this.consentManager.hasValidConsent(toolName, this.requiresConsent.level(),
6883
this.requiresConsent.categories())) {
69-
return this.delegate.call(parameters);
84+
return this.delegate.call(toolInput);
7085
}
7186

7287
// Prepare consent message with parameter substitution
@@ -81,22 +96,17 @@ public Object call(Map<String, Object> parameters) {
8196
}
8297

8398
// Execute the tool if consent was granted
84-
return this.delegate.call(parameters);
99+
return this.delegate.call(toolInput);
85100
}
86101

87102
@Override
88-
public String getName() {
89-
return this.delegate.getName();
90-
}
91-
92-
@Override
93-
public String getDescription() {
94-
return this.delegate.getDescription();
103+
public ToolDefinition getToolDefinition() {
104+
return this.delegate.getToolDefinition();
95105
}
96106

97107
@Override
98-
public ToolDefinition getToolDefinition() {
99-
return this.delegate.getToolDefinition();
108+
public ToolMetadata getToolMetadata() {
109+
return this.delegate.getToolMetadata();
100110
}
101111

102112
/**
Lines changed: 266 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,266 @@
1+
/*
2+
* Copyright 2023-2025 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.ai.model.tool;
18+
19+
import java.util.List;
20+
import java.util.Map;
21+
22+
import io.micrometer.observation.ObservationRegistry;
23+
import org.junit.jupiter.api.BeforeEach;
24+
import org.junit.jupiter.api.Test;
25+
import org.mockito.Mockito;
26+
27+
import org.springframework.ai.chat.messages.AssistantMessage;
28+
import org.springframework.ai.chat.messages.UserMessage;
29+
import org.springframework.ai.chat.model.ChatResponse;
30+
import org.springframework.ai.chat.model.Generation;
31+
import org.springframework.ai.chat.prompt.Prompt;
32+
import org.springframework.ai.tool.ToolCallback;
33+
import org.springframework.ai.tool.annotation.RequiresConsent;
34+
import org.springframework.ai.tool.annotation.RequiresConsent.ConsentLevel;
35+
import org.springframework.ai.tool.consent.ConsentAwareToolCallback;
36+
import org.springframework.ai.tool.consent.ConsentManager;
37+
import org.springframework.ai.tool.consent.exception.ConsentDeniedException;
38+
import org.springframework.ai.tool.definition.DefaultToolDefinition;
39+
import org.springframework.ai.tool.definition.ToolDefinition;
40+
import org.springframework.ai.tool.execution.DefaultToolExecutionExceptionProcessor;
41+
import org.springframework.ai.tool.metadata.DefaultToolMetadata;
42+
import org.springframework.ai.tool.metadata.ToolMetadata;
43+
import org.springframework.ai.tool.resolution.ToolCallbackResolver;
44+
45+
import static org.assertj.core.api.Assertions.assertThat;
46+
import static org.assertj.core.api.Assertions.assertThatThrownBy;
47+
import static org.mockito.ArgumentMatchers.any;
48+
import static org.mockito.ArgumentMatchers.anyString;
49+
import static org.mockito.Mockito.times;
50+
import static org.mockito.Mockito.verify;
51+
import static org.mockito.Mockito.when;
52+
53+
/**
54+
* Unit tests for {@link DefaultToolCallingManager} with consent management.
55+
*
56+
* @author Hyunjoon Park
57+
* @since 1.0.0
58+
*/
59+
class DefaultToolCallingManagerConsentTests {
60+
61+
private DefaultToolCallingManager toolCallingManager;
62+
63+
private ConsentManager consentManager;
64+
65+
private ToolCallback mockToolCallback;
66+
67+
private ConsentAwareToolCallback consentAwareToolCallback;
68+
69+
@BeforeEach
70+
void setUp() {
71+
ObservationRegistry observationRegistry = ObservationRegistry.create();
72+
ToolCallbackResolver toolCallbackResolver = Mockito.mock(ToolCallbackResolver.class);
73+
DefaultToolExecutionExceptionProcessor exceptionProcessor = DefaultToolExecutionExceptionProcessor.builder()
74+
.build();
75+
76+
this.toolCallingManager = new DefaultToolCallingManager(observationRegistry, toolCallbackResolver,
77+
exceptionProcessor);
78+
79+
// Set up mock tool callback
80+
this.mockToolCallback = Mockito.mock(ToolCallback.class);
81+
ToolDefinition toolDefinition = DefaultToolDefinition.builder()
82+
.name("deleteBook")
83+
.description("Delete a book")
84+
.inputSchema("{\"type\":\"object\",\"properties\":{\"bookId\":{\"type\":\"string\"}}}")
85+
.build();
86+
ToolMetadata toolMetadata = DefaultToolMetadata.builder().build();
87+
when(this.mockToolCallback.getToolDefinition()).thenReturn(toolDefinition);
88+
when(this.mockToolCallback.getToolMetadata()).thenReturn(toolMetadata);
89+
90+
// Set up consent manager
91+
this.consentManager = Mockito.mock(ConsentManager.class);
92+
93+
// Create mock RequiresConsent annotation
94+
RequiresConsent requiresConsent = Mockito.mock(RequiresConsent.class);
95+
when(requiresConsent.message()).thenReturn("Delete book {bookId}?");
96+
when(requiresConsent.level()).thenReturn(ConsentLevel.EVERY_TIME);
97+
when(requiresConsent.categories()).thenReturn(new String[0]);
98+
99+
// Create consent-aware wrapper
100+
this.consentAwareToolCallback = new ConsentAwareToolCallback(this.mockToolCallback, this.consentManager,
101+
requiresConsent);
102+
}
103+
104+
@Test
105+
void testManualExecutionWithConsentGranted() {
106+
// Given
107+
// ConsentAwareToolCallback will first check hasValidConsent, then call
108+
// requestConsent if needed
109+
// For this test, we'll make hasValidConsent return false to trigger
110+
// requestConsent
111+
when(this.consentManager.hasValidConsent(anyString(), any(ConsentLevel.class), any(String[].class)))
112+
.thenReturn(false);
113+
when(this.consentManager.requestConsent(anyString(), anyString(), any(ConsentLevel.class), any(String[].class),
114+
any(Map.class)))
115+
.thenReturn(true);
116+
when(this.mockToolCallback.call(anyString(), any())).thenReturn("Book deleted");
117+
118+
List<ToolCallback> toolCallbacks = List.of(this.consentAwareToolCallback);
119+
ToolCallingChatOptions chatOptions = ToolCallingChatOptions.builder().toolCallbacks(toolCallbacks).build();
120+
121+
UserMessage userMessage = new UserMessage("Delete book with ID 123");
122+
Prompt prompt = new Prompt(List.of(userMessage), chatOptions);
123+
124+
AssistantMessage.ToolCall toolCall = new AssistantMessage.ToolCall("1", "tool-call", "deleteBook",
125+
"{\"bookId\":\"123\"}");
126+
AssistantMessage assistantMessage = new AssistantMessage("I'll delete the book.", Map.of(), List.of(toolCall));
127+
128+
Generation generation = new Generation(assistantMessage);
129+
ChatResponse chatResponse = new ChatResponse(List.of(generation));
130+
131+
// When
132+
ToolExecutionResult result = this.toolCallingManager.executeToolCalls(prompt, chatResponse);
133+
134+
// Then
135+
assertThat(result).isNotNull();
136+
assertThat(result.conversationHistory()).hasSize(3); // user, assistant, tool
137+
// response
138+
// Verify consent was requested
139+
verify(this.consentManager, times(1)).hasValidConsent(anyString(), any(ConsentLevel.class),
140+
any(String[].class));
141+
verify(this.consentManager, times(1)).requestConsent(anyString(), anyString(), any(ConsentLevel.class),
142+
any(String[].class), any(Map.class));
143+
verify(this.mockToolCallback, times(1)).call(anyString(), any());
144+
}
145+
146+
@Test
147+
void testManualExecutionWithConsentDenied() {
148+
// Given
149+
// ConsentAwareToolCallback will first check hasValidConsent, then call
150+
// requestConsent if needed
151+
when(this.consentManager.hasValidConsent(anyString(), any(ConsentLevel.class), any(String[].class)))
152+
.thenReturn(false);
153+
when(this.consentManager.requestConsent(anyString(), anyString(), any(ConsentLevel.class), any(String[].class),
154+
any(Map.class)))
155+
.thenReturn(false);
156+
157+
List<ToolCallback> toolCallbacks = List.of(this.consentAwareToolCallback);
158+
ToolCallingChatOptions chatOptions = ToolCallingChatOptions.builder().toolCallbacks(toolCallbacks).build();
159+
160+
UserMessage userMessage = new UserMessage("Delete book with ID 123");
161+
Prompt prompt = new Prompt(List.of(userMessage), chatOptions);
162+
163+
AssistantMessage.ToolCall toolCall = new AssistantMessage.ToolCall("1", "tool-call", "deleteBook",
164+
"{\"bookId\":\"123\"}");
165+
AssistantMessage assistantMessage = new AssistantMessage("I'll delete the book.", Map.of(), List.of(toolCall));
166+
167+
Generation generation = new Generation(assistantMessage);
168+
ChatResponse chatResponse = new ChatResponse(List.of(generation));
169+
170+
// When & Then
171+
assertThatThrownBy(() -> this.toolCallingManager.executeToolCalls(prompt, chatResponse))
172+
.isInstanceOf(ConsentDeniedException.class)
173+
.hasMessageContaining("User denied consent for tool");
174+
175+
// Verify consent was requested but denied
176+
verify(this.consentManager, times(1)).hasValidConsent(anyString(), any(ConsentLevel.class),
177+
any(String[].class));
178+
verify(this.consentManager, times(1)).requestConsent(anyString(), anyString(), any(ConsentLevel.class),
179+
any(String[].class), any(Map.class));
180+
verify(this.mockToolCallback, times(0)).call(anyString(), any());
181+
}
182+
183+
@Test
184+
void testManualExecutionWithNonConsentAwareToolCallback() {
185+
// Given
186+
when(this.mockToolCallback.call(anyString(), any())).thenReturn("Book deleted");
187+
188+
List<ToolCallback> toolCallbacks = List.of(this.mockToolCallback); // Regular
189+
// callback,
190+
// not
191+
// consent-aware
192+
ToolCallingChatOptions chatOptions = ToolCallingChatOptions.builder().toolCallbacks(toolCallbacks).build();
193+
194+
UserMessage userMessage = new UserMessage("Delete book with ID 123");
195+
Prompt prompt = new Prompt(List.of(userMessage), chatOptions);
196+
197+
AssistantMessage.ToolCall toolCall = new AssistantMessage.ToolCall("1", "tool-call", "deleteBook",
198+
"{\"bookId\":\"123\"}");
199+
AssistantMessage assistantMessage = new AssistantMessage("I'll delete the book.", Map.of(), List.of(toolCall));
200+
201+
Generation generation = new Generation(assistantMessage);
202+
ChatResponse chatResponse = new ChatResponse(List.of(generation));
203+
204+
// When
205+
ToolExecutionResult result = this.toolCallingManager.executeToolCalls(prompt, chatResponse);
206+
207+
// Then
208+
assertThat(result).isNotNull();
209+
assertThat(result.conversationHistory()).hasSize(3);
210+
verify(this.mockToolCallback, times(1)).call(anyString(), any());
211+
// ConsentManager should not be called for non-consent-aware callbacks
212+
}
213+
214+
@Test
215+
void testManualExecutionWithMixedToolCallbacks() {
216+
// Given
217+
ToolCallback regularCallback = Mockito.mock(ToolCallback.class);
218+
ToolDefinition regularToolDef = DefaultToolDefinition.builder()
219+
.name("getBook")
220+
.description("Get a book")
221+
.inputSchema("{\"type\":\"object\",\"properties\":{\"bookId\":{\"type\":\"string\"}}}")
222+
.build();
223+
when(regularCallback.getToolDefinition()).thenReturn(regularToolDef);
224+
when(regularCallback.getToolMetadata()).thenReturn(DefaultToolMetadata.builder().build());
225+
when(regularCallback.call(anyString(), any())).thenReturn("Book found");
226+
227+
// For the consent-aware callback
228+
when(this.consentManager.hasValidConsent(anyString(), any(ConsentLevel.class), any(String[].class)))
229+
.thenReturn(false);
230+
when(this.consentManager.requestConsent(anyString(), anyString(), any(ConsentLevel.class), any(String[].class),
231+
any(Map.class)))
232+
.thenReturn(true);
233+
when(this.mockToolCallback.call(anyString(), any())).thenReturn("Book deleted");
234+
235+
List<ToolCallback> toolCallbacks = List.of(regularCallback, this.consentAwareToolCallback);
236+
ToolCallingChatOptions chatOptions = ToolCallingChatOptions.builder().toolCallbacks(toolCallbacks).build();
237+
238+
UserMessage userMessage = new UserMessage("Get and delete book with ID 123");
239+
Prompt prompt = new Prompt(List.of(userMessage), chatOptions);
240+
241+
AssistantMessage.ToolCall getCall = new AssistantMessage.ToolCall("1", "tool-call", "getBook",
242+
"{\"bookId\":\"123\"}");
243+
AssistantMessage.ToolCall deleteCall = new AssistantMessage.ToolCall("2", "tool-call", "deleteBook",
244+
"{\"bookId\":\"123\"}");
245+
AssistantMessage assistantMessage = new AssistantMessage("I'll get and delete the book.", Map.of(),
246+
List.of(getCall, deleteCall));
247+
248+
Generation generation = new Generation(assistantMessage);
249+
ChatResponse chatResponse = new ChatResponse(List.of(generation));
250+
251+
// When
252+
ToolExecutionResult result = this.toolCallingManager.executeToolCalls(prompt, chatResponse);
253+
254+
// Then
255+
assertThat(result).isNotNull();
256+
assertThat(result.conversationHistory()).hasSize(3);
257+
verify(regularCallback, times(1)).call(anyString(), any());
258+
// Verify consent was requested for the consent-aware callback only
259+
verify(this.consentManager, times(1)).hasValidConsent(anyString(), any(ConsentLevel.class),
260+
any(String[].class));
261+
verify(this.consentManager, times(1)).requestConsent(anyString(), anyString(), any(ConsentLevel.class),
262+
any(String[].class), any(Map.class));
263+
verify(this.mockToolCallback, times(1)).call(anyString(), any());
264+
}
265+
266+
}

0 commit comments

Comments
 (0)