Skip to content

Refactor tool #5

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 2 commits into
base: feature/vertexai-tool-invocation
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

### Added

- Added tool calling functionality to the LLM base class with OpenAI implementation, enabling structured parameter extraction and function calling.
- Added tool calling functionality to the LLM base class with OpenAI and VertexAI implementations, enabling structured parameter extraction and function calling.
- Added support for multi-vector collection in Qdrant driver.
- Added a `Pipeline.stream` method to stream pipeline progress.
- Added a new semantic match resolver to the KG Builder for entity resolution based on spaCy embeddings and cosine similarities so that nodes with similar textual properties get merged.
Expand All @@ -13,7 +13,7 @@
### Changed

- Improved log output readability in Retrievers and GraphRAG and added embedded vector to retriever result metadata for debugging.
- Switched from pygraphviz to neo4j-viz
- Switched from pygraphviz to neo4j-viz
- Renders interactive graph now on HTML instead of PNG
- Removed `get_pygraphviz_graph` method

Expand Down
199 changes: 44 additions & 155 deletions src/neo4j_graphrag/tool.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from abc import ABC
from enum import Enum
from typing import Any, Dict, List, Callable, Optional, Union, ClassVar
from pydantic import BaseModel, Field, model_validator
from typing import Any, Dict, List, Callable, Optional, Union, Literal, Annotated
from pydantic import BaseModel, Field, ConfigDict, AliasGenerator
from pydantic.alias_generators import to_camel, to_snake


class ParameterType(str, Enum):
Expand All @@ -19,193 +20,81 @@ class ToolParameter(BaseModel):
"""Base class for all tool parameters using Pydantic."""

description: str
required: bool = False
type: ClassVar[ParameterType]

def model_dump_tool(self) -> Dict[str, Any]:
"""Convert the parameter to a dictionary format for tool usage."""
result: Dict[str, Any] = {"type": self.type, "description": self.description}
if self.required:
result["required"] = True
return result

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "ToolParameter":
"""Create a parameter from a dictionary."""
param_type = data.get("type")
if not param_type:
raise ValueError("Parameter type is required")

# Find the appropriate class based on the type
param_classes = {
ParameterType.STRING: StringParameter,
ParameterType.INTEGER: IntegerParameter,
ParameterType.NUMBER: NumberParameter,
ParameterType.BOOLEAN: BooleanParameter,
ParameterType.OBJECT: ObjectParameter,
ParameterType.ARRAY: ArrayParameter,
}

param_class = param_classes.get(param_type)
if not param_class:
raise ValueError(f"Unknown parameter type: {param_type}")

# Use type ignore since mypy doesn't understand dynamic class instantiation
return param_class.model_validate(data) # type: ignore
type: ParameterType


class StringParameter(ToolParameter):
"""String parameter for tools."""

type: ClassVar[ParameterType] = ParameterType.STRING
type: Literal[ParameterType.STRING] = ParameterType.STRING
enum: Optional[List[str]] = None

def model_dump_tool(self) -> Dict[str, Any]:
result = super().model_dump_tool()
if self.enum:
result["enum"] = self.enum
return result


class IntegerParameter(ToolParameter):
"""Integer parameter for tools."""

type: ClassVar[ParameterType] = ParameterType.INTEGER
type: Literal[ParameterType.INTEGER] = ParameterType.INTEGER
minimum: Optional[int] = None
maximum: Optional[int] = None

def model_dump_tool(self) -> Dict[str, Any]:
result = super().model_dump_tool()
if self.minimum is not None:
result["minimum"] = self.minimum
if self.maximum is not None:
result["maximum"] = self.maximum
return result


class NumberParameter(ToolParameter):
"""Number parameter for tools."""

type: ClassVar[ParameterType] = ParameterType.NUMBER
type: Literal[ParameterType.NUMBER] = ParameterType.NUMBER
minimum: Optional[float] = None
maximum: Optional[float] = None

def model_dump_tool(self) -> Dict[str, Any]:
result = super().model_dump_tool()
if self.minimum is not None:
result["minimum"] = self.minimum
if self.maximum is not None:
result["maximum"] = self.maximum
return result


class BooleanParameter(ToolParameter):
"""Boolean parameter for tools."""

type: ClassVar[ParameterType] = ParameterType.BOOLEAN
type: Literal[ParameterType.BOOLEAN] = ParameterType.BOOLEAN


class ArrayParameter(ToolParameter):
"""Array parameter for tools."""

type: ClassVar[ParameterType] = ParameterType.ARRAY
items: "ToolParameter"
type: Literal[ParameterType.ARRAY] = ParameterType.ARRAY
items: "AnyToolParameter"
min_items: Optional[int] = None
max_items: Optional[int] = None

@model_validator(mode="before")
@classmethod
def _preprocess_items(cls, values: dict[str, Any]) -> dict[str, Any]:
# Convert items from dict to ToolParameter if needed
items = values.get("items")
if isinstance(items, dict):
values["items"] = ToolParameter.from_dict(items)
return values

def model_dump_tool(self) -> Dict[str, Any]:
result = super().model_dump_tool()
result["items"] = self.items.model_dump_tool()
if self.min_items is not None:
result["minItems"] = self.min_items
if self.max_items is not None:
result["maxItems"] = self.max_items
return result

@model_validator(mode="after")
def validate_items(self) -> "ArrayParameter":
if not isinstance(self.items, ToolParameter):
if isinstance(self.items, dict):
self.items = ToolParameter.from_dict(self.items)
else:
raise ValueError(
f"Items must be a ToolParameter or dict, got {type(self.items)}"
)
elif type(self.items) is ToolParameter:
# Promote base ToolParameter to correct subclass if possible
self.items = ToolParameter.from_dict(self.items.model_dump())
return self
model_config = ConfigDict(
alias_generator=AliasGenerator(
validation_alias=to_snake,
serialization_alias=to_camel,
)
)


class ObjectParameter(ToolParameter):
"""Object parameter for tools."""

type: ClassVar[ParameterType] = ParameterType.OBJECT
properties: Dict[str, ToolParameter]
required_properties: List[str] = Field(default_factory=list)
type: Literal[ParameterType.OBJECT] = ParameterType.OBJECT
properties: Dict[str, "AnyToolParameter"]
required: List[str] = []
additional_properties: bool = True

@model_validator(mode="before")
@classmethod
def _preprocess_properties(cls, values: dict[str, Any]) -> dict[str, Any]:
# Convert properties from dicts to ToolParameter if needed
props = values.get("properties")
if isinstance(props, dict):
new_props = {}
for k, v in props.items():
if isinstance(v, dict):
new_props[k] = ToolParameter.from_dict(v)
else:
new_props[k] = v
values["properties"] = new_props
return values

def model_dump_tool(self, exclude: Optional[list[str]] = None) -> Dict[str, Any]:
exclude = exclude or []
properties_dict: Dict[str, Any] = {}
for name, param in self.properties.items():
if name in exclude:
continue
properties_dict[name] = param.model_dump_tool()

result = super().model_dump_tool()
result["properties"] = properties_dict

if self.required_properties and "required" not in exclude:
result["required"] = self.required_properties

if not self.additional_properties and "additional_properties" not in exclude:
result["additionalProperties"] = False

return result

@model_validator(mode="after")
def validate_properties(self) -> "ObjectParameter":
validated_properties = {}
for name, param in self.properties.items():
if not isinstance(param, ToolParameter):
if isinstance(param, dict):
validated_properties[name] = ToolParameter.from_dict(param)
else:
raise ValueError(
f"Property {name} must be a ToolParameter or dict, got {type(param)}"
)
elif type(param) is ToolParameter:
# Promote base ToolParameter to correct subclass if possible
validated_properties[name] = ToolParameter.from_dict(param.model_dump())
else:
validated_properties[name] = param
self.properties = validated_properties
return self
model_config = ConfigDict(
alias_generator=AliasGenerator(
validation_alias=to_snake,
serialization_alias=to_camel,
)
)


AnyToolParameter = Annotated[
Union[
StringParameter,
IntegerParameter,
NumberParameter,
BooleanParameter,
ObjectParameter,
ArrayParameter,
],
Field(discriminator="type"),
]


class Tool(ABC):
Expand All @@ -222,11 +111,7 @@ def __init__(
self._description = description

# Allow parameters to be provided as a dictionary
if isinstance(parameters, dict):
self._parameters = ObjectParameter.model_validate(parameters)
else:
self._parameters = parameters

self._parameters = ObjectParameter.model_validate(parameters)
self._execute_func = execute_func

def get_name(self) -> str:
Expand All @@ -251,7 +136,11 @@ def get_parameters(self, exclude: Optional[list[str]] = None) -> Dict[str, Any]:
Returns:
Dict[str, Any]: Dictionary containing parameter schema information.
"""
return self._parameters.model_dump_tool(exclude)
return self._parameters.model_dump(
by_alias=True, # camelCase
exclude_none=True, # exclude None values
exclude=exclude, # exclude any specific field
)

def execute(self, query: str, **kwargs: Any) -> Any:
"""Execute the tool with the given query and additional parameters.
Expand Down
Loading
Loading