
Add tool calling to the LLM base class, implement in OpenAI #322


Merged 15 commits on Apr 24, 2025
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -4,6 +4,7 @@

### Added

- Added tool calling functionality to the LLM base class with OpenAI implementation, enabling structured parameter extraction and function calling.
- Added support for multi-vector collection in Qdrant driver.
- Added a `Pipeline.stream` method to stream pipeline progress.
- Added a new semantic match resolver to the KG Builder for entity resolution based on spaCy embeddings and cosine similarities so that nodes with similar textual properties get merged.
2 changes: 2 additions & 0 deletions examples/README.md
@@ -78,6 +78,8 @@ are listed in [the last section of this file](#customize).
- [Message history with Neo4j](./customize/llms/llm_with_neo4j_message_history.py)
- [System Instruction](./customize/llms/llm_with_system_instructions.py)

- [Tool Calling with OpenAI](./customize/llms/openai_tool_calls.py)


### Prompts

101 changes: 101 additions & 0 deletions examples/customize/llms/openai_tool_calls.py
@@ -0,0 +1,101 @@
"""
Example showing how to use OpenAI tool calls with parameter extraction.
Both synchronous and asynchronous examples are provided.

To run this example:
1. Make sure you have the OpenAI API key in your .env file:
OPENAI_API_KEY=your-api-key
2. Run: python examples/customize/llms/openai_tool_calls.py
"""

import asyncio
import json
import os
from typing import Dict, Any

from dotenv import load_dotenv

from neo4j_graphrag.llm import OpenAILLM
from neo4j_graphrag.llm.types import ToolCallResponse
from neo4j_graphrag.tool import Tool, ObjectParameter, StringParameter, IntegerParameter

# Load environment variables from .env file (OPENAI_API_KEY required for this example)
load_dotenv()


# Define a Tool that extracts person information from text
parameters = ObjectParameter(
description="Parameters for extracting person information",
properties={
"name": StringParameter(description="The person's full name"),
"age": IntegerParameter(description="The person's age"),
"occupation": StringParameter(description="The person's occupation"),
},
required_properties=["name"],
additional_properties=False,
)
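# These parameter objects describe a JSON-schema-style signature; per the
# LLMInterface docstring, each LLM implementation converts the Tool into
# its provider-specific function-calling format.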
person_info_tool = Tool(
name="extract_person_info",
description="Extract information about a person from text",
parameters=parameters,
execute_func=lambda **kwargs: kwargs,
)

# Collect the tools to pass to the LLM
TOOLS = [person_info_tool]


def process_tool_calls(response: ToolCallResponse) -> Dict[str, Any]:
"""Process all tool calls in the response and return the extracted parameters."""
if not response.tool_calls:
raise ValueError("No tool calls found in response")

print(f"\nNumber of tool calls: {len(response.tool_calls)}")
print(f"Additional content: {response.content or 'None'}")

results = []
for i, tool_call in enumerate(response.tool_calls):
print(f"\nTool call #{i + 1}: {tool_call.name}")
print(f"Arguments: {tool_call.arguments}")
results.append(tool_call.arguments)

# For backward compatibility, return the first tool call's arguments
return results[0] if results else {}


async def main() -> None:
# Initialize the OpenAI LLM
llm = OpenAILLM(
api_key=os.getenv("OPENAI_API_KEY"),
model_name="gpt-4o",
model_params={"temperature": 0},
)

# Example text containing information about a person
text = "Stella Hane is a 35-year-old software engineer who loves coding."

print("\n=== Synchronous Tool Call ===")
# Make a synchronous tool call
sync_response = llm.invoke_with_tools(
input=f"Extract information about the person from this text: {text}",
tools=TOOLS,
)
sync_result = process_tool_calls(sync_response)
print("\n=== Synchronous Tool Call Result ===")
print(json.dumps(sync_result, indent=2))

print("\n=== Asynchronous Tool Call ===")
# Make an asynchronous tool call with a different text
text2 = "Molly Hane, 32, works as a data scientist and enjoys machine learning."
async_response = await llm.ainvoke_with_tools(
input=f"Extract information about the person from this text: {text2}",
tools=TOOLS,
)
async_result = process_tool_calls(async_response)
print("\n=== Asynchronous Tool Call Result ===")
print(json.dumps(async_result, indent=2))


if __name__ == "__main__":
# Run the async main function
asyncio.run(main())
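
The example above only extracts arguments and never runs the tool. Since each `Tool` is constructed with an `execute_func`, the extracted arguments can also be dispatched. A minimal sketch continuing the example (the `execute_tool_calls` helper and its dispatch table are illustrative, not part of the library):

```python
# Illustrative helper, not part of neo4j_graphrag: dispatch each tool call
# to the function registered for that tool name in the example above.
from typing import Any, Dict, List

from neo4j_graphrag.llm.types import ToolCallResponse


def execute_tool_calls(response: ToolCallResponse) -> List[Dict[str, Any]]:
    # Dispatch table keyed by the tool names registered above; the
    # extract_person_info tool simply echoes its keyword arguments back.
    executors = {"extract_person_info": lambda **kwargs: kwargs}
    results = []
    for tool_call in response.tool_calls:
        executor = executors.get(tool_call.name)
        if executor is None:
            raise ValueError(f"Unknown tool requested: {tool_call.name}")
        results.append(executor(**tool_call.arguments))
    return results
```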
60 changes: 58 additions & 2 deletions src/neo4j_graphrag/llm/base.py
@@ -15,12 +15,14 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Any, List, Optional, Sequence, Union

from neo4j_graphrag.message_history import MessageHistory
from neo4j_graphrag.types import LLMMessage

from .types import LLMResponse, ToolCallResponse

from neo4j_graphrag.tool import Tool


class LLMInterface(ABC):
@@ -84,3 +86,57 @@ async def ainvoke(
Raises:
LLMGenerationError: If anything goes wrong.
"""

def invoke_with_tools(
self,
input: str,
tools: Sequence[Tool],
message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None,
system_instruction: Optional[str] = None,
) -> ToolCallResponse:
"""Sends a text input to the LLM with tool definitions and retrieves a tool call response.

This is a default implementation that should be overridden by LLM providers that support tool/function calling.

Args:
input (str): Text sent to the LLM.
tools (Sequence[Tool]): Sequence of Tools for the LLM to choose from. Each LLM implementation should handle the conversion to its specific format.
message_history (Optional[Union[List[LLMMessage], MessageHistory]]): A collection of previous messages,
with each message having a specific role assigned.
system_instruction (Optional[str]): An option to override the LLM system message for this invocation.

Returns:
ToolCallResponse: The response from the LLM containing a tool call.

Raises:
LLMGenerationError: If anything goes wrong.
NotImplementedError: If the LLM provider does not support tool calling.
"""
raise NotImplementedError("This LLM provider does not support tool calling.")

async def ainvoke_with_tools(
self,
input: str,
tools: Sequence[Tool],
message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None,
system_instruction: Optional[str] = None,
) -> ToolCallResponse:
"""Asynchronously sends a text input to the LLM with tool definitions and retrieves a tool call response.

This is a default implementation that should be overridden by LLM providers that support tool/function calling.

Args:
input (str): Text sent to the LLM.
tools (Sequence[Tool]): Sequence of Tools for the LLM to choose from. Each LLM implementation should handle the conversion to its specific format.
message_history (Optional[Union[List[LLMMessage], MessageHistory]]): A collection of previous messages,
with each message having a specific role assigned.
system_instruction (Optional[str]): An option to override the LLM system message for this invocation.

Returns:
ToolCallResponse: The response from the LLM containing a tool call.

Raises:
LLMGenerationError: If anything goes wrong.
NotImplementedError: If the LLM provider does not support tool calling.
"""
raise NotImplementedError("This LLM provider does not support tool calling.")
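
Both defaults raise `NotImplementedError`, so tool calling is strictly opt-in per provider. A minimal sketch of what an override might look like (`MyProviderLLM` and its method body are assumptions for illustration; the shipped OpenAI implementation is not shown in this excerpt, and the abstract `invoke`/`ainvoke` overrides are omitted):

```python
# Hypothetical subclass sketch: how a provider could satisfy the contract
# declared in LLMInterface.invoke_with_tools above. Only LLMInterface, Tool,
# and ToolCallResponse come from the library as shown in this diff.
from typing import List, Optional, Sequence, Union

from neo4j_graphrag.llm.base import LLMInterface
from neo4j_graphrag.llm.types import ToolCallResponse
from neo4j_graphrag.message_history import MessageHistory
from neo4j_graphrag.tool import Tool
from neo4j_graphrag.types import LLMMessage


class MyProviderLLM(LLMInterface):  # hypothetical provider binding
    def invoke_with_tools(
        self,
        input: str,
        tools: Sequence[Tool],
        message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None,
        system_instruction: Optional[str] = None,
    ) -> ToolCallResponse:
        # 1. Convert each Tool to the provider's function/tool schema
        #    (each implementation owns this conversion, per the docstring).
        # 2. Send `input`, history, and system instruction to the provider.
        # 3. Map the provider's tool-call payload into a ToolCallResponse.
        raise NotImplementedError  # sketch only; no provider wired up here
```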