
Commit e0b5aca

Implement tool calling for Ollama and include examples
Parent: 4d21e8a

File tree: 5 files changed, +425 / -1 lines changed


CHANGELOG.md

Lines changed: 1 addition & 0 deletions

@@ -6,6 +6,7 @@
 - Added automatic rate limiting with retry logic and exponential backoff for all Embedding providers using tenacity. The `RateLimitHandler` interface allows for custom rate limiting strategies, including the ability to disable rate limiting entirely.
 - JSON response returned to `SchemaFromTextExtractor` is cleansed of any markdown code blocks before being loaded.
+- Tool calling support for Ollama in LLMInterface.

 ## 1.10.0

examples/README.md

Lines changed: 1 addition & 0 deletions

@@ -81,6 +81,7 @@ are listed in [the last section of this file](#customize).
 - [Tool Calling with OpenAI](./customize/llms/openai_tool_calls.py)
 - [Tool Calling with VertexAI](./customize/llms/vertexai_tool_calls.py)
+- [Tool Calling with Ollama](./customize/llms/ollama_tool_calls.py)

 ### Prompts
examples/customize/llms/ollama_tool_calls.py (new file)

Lines changed: 99 additions & 0 deletions

+"""
+Example showing how to use Ollama tool calls with parameter extraction.
+Both synchronous and asynchronous examples are provided.
+
+To run this example:
+1. Make sure you have `ollama serve` running
+2. Run: python examples/customize/llms/ollama_tool_calls.py
+"""
+
+import asyncio
+import json
+from typing import Dict, Any
+
+from neo4j_graphrag.llm import OllamaLLM
+from neo4j_graphrag.llm.types import ToolCallResponse
+from neo4j_graphrag.tool import (
+    Tool,
+    ObjectParameter,
+    StringParameter,
+    IntegerParameter,
+)
+
+
+# Create a custom Tool implementation for person info extraction
+parameters = ObjectParameter(
+    description="Parameters for extracting person information",
+    properties={
+        "name": StringParameter(description="The person's full name"),
+        "age": IntegerParameter(description="The person's age"),
+        "occupation": StringParameter(description="The person's occupation"),
+    },
+    required_properties=["name"],
+    additional_properties=False,
+)
+person_info_tool = Tool(
+    name="extract_person_info",
+    description="Extract information about a person from text",
+    parameters=parameters,
+    execute_func=lambda **kwargs: kwargs,
+)
+
+# The list of tools offered to the LLM
+TOOLS = [person_info_tool]
+
+
+def process_tool_calls(response: ToolCallResponse) -> Dict[str, Any]:
+    """Process all tool calls in the response and return the extracted parameters."""
+    if not response.tool_calls:
+        raise ValueError("No tool calls found in response")
+
+    print(f"\nNumber of tool calls: {len(response.tool_calls)}")
+    print(f"Additional content: {response.content or 'None'}")
+
+    results = []
+    for i, tool_call in enumerate(response.tool_calls):
+        print(f"\nTool call #{i + 1}: {tool_call.name}")
+        print(f"Arguments: {tool_call.arguments}")
+        results.append(tool_call.arguments)
+
+    # For backward compatibility, return the first tool call's arguments
+    return results[0] if results else {}
+
+
+async def main() -> None:
+    # Initialize the Ollama LLM
+    llm = OllamaLLM(
+        model_name="mistral:latest",
+        model_params={"temperature": 0},
+    )
+
+    # Example text containing information about a person
+    text = "Stella Hane is a 35-year-old software engineer who loves coding."
+
+    print("\n=== Synchronous Tool Call ===")
+    # Make a synchronous tool call
+    sync_response = llm.invoke_with_tools(
+        input=f"Extract information about the person from this text: {text}",
+        tools=TOOLS,
+    )
+    sync_result = process_tool_calls(sync_response)
+    print("\n=== Synchronous Tool Call Result ===")
+    print(json.dumps(sync_result, indent=2))
+
+    print("\n=== Asynchronous Tool Call ===")
+    # Make an asynchronous tool call with a different text
+    text2 = "Molly Hane, 32, works as a data scientist and enjoys machine learning."
+    async_response = await llm.ainvoke_with_tools(
+        input=f"Extract information about the person from this text: {text2}",
+        tools=TOOLS,
+    )
+    async_result = process_tool_calls(async_response)
+    print("\n=== Asynchronous Tool Call Result ===")
+    print(json.dumps(async_result, indent=2))
+
+
+if __name__ == "__main__":
+    # Run the async main function
+    asyncio.run(main())
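The example prints the extracted arguments but never runs the tool itself. A minimal sketch of dispatching a returned tool call to a handler follows; the `HANDLERS` table and `run_tool_call` helper are illustrative additions, not part of this commit.

from typing import Any, Dict

from neo4j_graphrag.llm.types import ToolCall

# Hypothetical dispatch table mapping tool names to callables.
# The handler here just echoes its arguments, mirroring the
# execute_func passed to person_info_tool above.
HANDLERS = {
    "extract_person_info": lambda **kwargs: kwargs,
}

def run_tool_call(tool_call: ToolCall) -> Dict[str, Any]:
    # Look up the handler named by the LLM and call it with
    # the arguments extracted from the model response.
    handler = HANDLERS.get(tool_call.name)
    if handler is None:
        raise ValueError(f"Unknown tool: {tool_call.name}")
    return handler(**tool_call.arguments)

# Usage: results = [run_tool_call(tc) for tc in sync_response.tool_calls]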

src/neo4j_graphrag/llm/ollama_llm.py

Lines changed: 157 additions & 1 deletion

@@ -15,7 +15,17 @@
 from __future__ import annotations

 import warnings
-from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Sequence, Union, cast
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Dict,
+    Iterable,
+    List,
+    Optional,
+    Sequence,
+    Union,
+    cast,
+)

 from pydantic import ValidationError

@@ -33,9 +43,12 @@
     BaseMessage,
     LLMResponse,
     MessageList,
+    ToolCall,
+    ToolCallResponse,
     SystemMessage,
     UserMessage,
 )
+from neo4j_graphrag.tool import Tool

 if TYPE_CHECKING:
     from ollama import Message

@@ -163,3 +176,146 @@ async def ainvoke(
             return LLMResponse(content=content)
         except self.ollama.ResponseError as e:
             raise LLMGenerationError(e)
+
+    @rate_limit_handler
+    def invoke_with_tools(
+        self,
+        input: str,
+        tools: Sequence[Tool],  # tool definitions as a sequence of Tool objects
+        message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None,
+        system_instruction: Optional[str] = None,
+    ) -> ToolCallResponse:
+        """Sends a text input to the LLM with tool definitions
+        and retrieves a tool call response.
+
+        Args:
+            input (str): Text sent to the LLM.
+            tools (Sequence[Tool]): Tools for the LLM to choose from.
+            message_history (Optional[Union[List[LLMMessage], MessageHistory]]): A collection of previous messages,
+                with each message having a specific role assigned.
+            system_instruction (Optional[str]): An option to override the LLM system message for this invocation.
+
+        Returns:
+            ToolCallResponse: The response from the LLM containing tool calls.
+
+        Raises:
+            LLMGenerationError: If anything goes wrong.
+        """
+        try:
+            if isinstance(message_history, MessageHistory):
+                message_history = message_history.messages
+
+            # Convert tools to Ollama's expected format
+            ollama_tools = []
+            for tool in tools:
+                ollama_tool_format = self._convert_tool_to_ollama_format(tool)
+                ollama_tools.append(ollama_tool_format)
+
+            response = self.client.chat(
+                model=self.model_name,
+                messages=self.get_messages(input, message_history, system_instruction),
+                tools=ollama_tools,
+                **self.model_params,
+            )
+            message = response.message
+
+            # If there's no tool call, return the content as a regular response
+            if not message.tool_calls:
+                return ToolCallResponse(
+                    tool_calls=[],
+                    content=message.content,
+                )
+
+            # Process all tool calls
+            tool_calls = []
+            for tool_call in message.tool_calls:
+                args = tool_call.function.arguments
+                tool_calls.append(
+                    ToolCall(name=tool_call.function.name, arguments=args)
+                )
+
+            return ToolCallResponse(tool_calls=tool_calls, content=message.content)
+        except self.ollama.ResponseError as e:
+            raise LLMGenerationError(e)
+
+    @async_rate_limit_handler
+    async def ainvoke_with_tools(
+        self,
+        input: str,
+        tools: Sequence[Tool],  # tool definitions as a sequence of Tool objects
+        message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None,
+        system_instruction: Optional[str] = None,
+    ) -> ToolCallResponse:
+        """Asynchronously sends a text input to the LLM with tool definitions
+        and retrieves a tool call response.
+
+        Args:
+            input (str): Text sent to the LLM.
+            tools (Sequence[Tool]): Tools for the LLM to choose from.
+            message_history (Optional[Union[List[LLMMessage], MessageHistory]]): A collection of previous messages,
+                with each message having a specific role assigned.
+            system_instruction (Optional[str]): An option to override the LLM system message for this invocation.
+
+        Returns:
+            ToolCallResponse: The response from the LLM containing tool calls.
+
+        Raises:
+            LLMGenerationError: If anything goes wrong.
+        """
+        try:
+            if isinstance(message_history, MessageHistory):
+                message_history = message_history.messages
+
+            # Convert tools to Ollama's expected format
+            ollama_tools = []
+            for tool in tools:
+                ollama_tool_format = self._convert_tool_to_ollama_format(tool)
+                ollama_tools.append(ollama_tool_format)
+
+            response = await self.async_client.chat(
+                model=self.model_name,
+                messages=self.get_messages(input, message_history, system_instruction),
+                tools=ollama_tools,
+                **self.model_params,
+            )
+            message = response.message
+
+            # If there's no tool call, return the content as a regular response
+            if not message.tool_calls:
+                return ToolCallResponse(
+                    tool_calls=[],
+                    content=message.content,
+                )
+
+            # Process all tool calls
+            tool_calls = []
+            for tool_call in message.tool_calls:
+                args = tool_call.function.arguments
+                tool_calls.append(
+                    ToolCall(name=tool_call.function.name, arguments=args)
+                )
+
+            return ToolCallResponse(tool_calls=tool_calls, content=message.content)
+        except self.ollama.ResponseError as e:
+            raise LLMGenerationError(e)
+
+    def _convert_tool_to_ollama_format(self, tool: Tool) -> Dict[str, Any]:
+        """Convert a Tool object to Ollama's expected format.
+
+        Args:
+            tool: A Tool object to convert to Ollama's format.
+
+        Returns:
+            A dictionary in Ollama's tool format.
+        """
+        try:
+            return {
+                "type": "function",
+                "function": {
+                    "name": tool.get_name(),
+                    "description": tool.get_description(),
+                    "parameters": tool.get_parameters(),
+                },
+            }
+        except AttributeError:
+            raise LLMGenerationError(f"Tool {tool} is not a valid Tool object")
