| @@ -15,7 +15,17 @@ | 
| 15 | 15 | from __future__ import annotations | 
| 16 | 16 |  | 
| 17 | 17 | import warnings | 
| 18 |  | -from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Sequence, Union, cast | 
|  | 18 | +from typing import ( | 
|  | 19 | +    TYPE_CHECKING, | 
|  | 20 | +    Any, | 
|  | 21 | +    Dict, | 
|  | 22 | +    Iterable, | 
|  | 23 | +    List, | 
|  | 24 | +    Optional, | 
|  | 25 | +    Sequence, | 
|  | 26 | +    Union, | 
|  | 27 | +    cast, | 
|  | 28 | +) | 
| 19 | 29 |  | 
| 20 | 30 | from pydantic import ValidationError | 
| 21 | 31 | 
 | 
| @@ -33,9 +43,12 @@ | 
| 33 | 43 |     BaseMessage, | 
| 34 | 44 |     LLMResponse, | 
| 35 | 45 |     MessageList, | 
|  | 46 | +    ToolCall, | 
|  | 47 | +    ToolCallResponse, | 
| 36 | 48 |     SystemMessage, | 
| 37 | 49 |     UserMessage, | 
| 38 | 50 | ) | 
|  | 51 | +from neo4j_graphrag.tool import Tool | 
| 39 | 52 |  | 
| 40 | 53 | if TYPE_CHECKING: | 
| 41 | 54 |     from ollama import Message | 
| @@ -163,3 +176,146 @@ async def ainvoke( | 
| 163 | 176 |             return LLMResponse(content=content) | 
| 164 | 177 |         except self.ollama.ResponseError as e: | 
| 165 | 178 |             raise LLMGenerationError(e) | 
|  | 179 | + | 
|  | 180 | +    @rate_limit_handler | 
|  | 181 | +    def invoke_with_tools( | 
|  | 182 | +        self, | 
|  | 183 | +        input: str, | 
|  | 184 | +        tools: Sequence[Tool],  # Tools definition as a sequence of Tool objects | 
|  | 185 | +        message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None, | 
|  | 186 | +        system_instruction: Optional[str] = None, | 
|  | 187 | +    ) -> ToolCallResponse: | 
|  | 188 | +        """Sends a text input to the LLM with tool definitions | 
|  | 189 | +        and retrieves a tool call response. | 
|  | 190 | + | 
|  | 191 | +        Args: | 
|  | 192 | +            input (str): Text sent to the LLM. | 
|  | 193 | +            tools (Sequence[Tool]): Sequence of Tools for the LLM to choose from. | 
|  | 194 | +            message_history (Optional[Union[List[LLMMessage], MessageHistory]]): A collection of previous messages, | 
|  | 195 | +                with each message having a specific role assigned. | 
|  | 196 | +            system_instruction (Optional[str]): An option to override the LLM system message for this invocation. | 
|  | 197 | + | 
|  | 198 | +        Returns: | 
|  | 199 | +            ToolCallResponse: The response from the LLM containing a tool call. | 
|  | 200 | + | 
|  | 201 | +        Raises: | 
|  | 202 | +            LLMGenerationError: If anything goes wrong. | 
|  | 203 | +        """ | 
|  | 204 | +        try: | 
|  | 205 | +            if isinstance(message_history, MessageHistory): | 
|  | 206 | +                message_history = message_history.messages | 
|  | 207 | + | 
|  | 208 | +            # Convert tools to Ollama's expected type | 
|  | 209 | +            ollama_tools = [] | 
|  | 210 | +            for tool in tools: | 
|  | 211 | +                ollama_tool_format = self._convert_tool_to_ollama_format(tool) | 
|  | 212 | +                ollama_tools.append(ollama_tool_format) | 
|  | 213 | +            response = self.client.chat( | 
|  | 214 | +                model=self.model_name, | 
|  | 215 | +                messages=self.get_messages(input, message_history, system_instruction), | 
|  | 216 | +                tools=ollama_tools, | 
|  | 217 | +                **self.model_params, | 
|  | 218 | +            ) | 
|  | 219 | +            message = response.message | 
|  | 220 | +            # If there's no tool call, return the content as a regular response | 
|  | 221 | +            if not message.tool_calls: | 
|  | 222 | +                return ToolCallResponse( | 
|  | 223 | +                    tool_calls=[], | 
|  | 224 | +                    content=message.content, | 
|  | 225 | +                ) | 
|  | 226 | + | 
|  | 227 | +            # Process all tool calls | 
|  | 228 | +            tool_calls = [] | 
|  | 229 | + | 
|  | 230 | +            for tool_call in message.tool_calls: | 
|  | 231 | +                args = tool_call.function.arguments | 
|  | 232 | +                tool_calls.append( | 
|  | 233 | +                    ToolCall(name=tool_call.function.name, arguments=args) | 
|  | 234 | +                ) | 
|  | 235 | + | 
|  | 236 | +            return ToolCallResponse(tool_calls=tool_calls, content=message.content) | 
|  | 237 | +        except self.ollama.ResponseError as e: | 
|  | 238 | +            raise LLMGenerationError(e) | 
|  | 239 | + | 
|  | 240 | +    @async_rate_limit_handler | 
|  | 241 | +    async def ainvoke_with_tools( | 
|  | 242 | +        self, | 
|  | 243 | +        input: str, | 
|  | 244 | +        tools: Sequence[Tool],  # Tools definition as a sequence of Tool objects | 
|  | 245 | +        message_history: Optional[Union[List[LLMMessage], MessageHistory]] = None, | 
|  | 246 | +        system_instruction: Optional[str] = None, | 
|  | 247 | +    ) -> ToolCallResponse: | 
|  | 248 | +        """Sends a text input to the LLM with tool definitions | 
|  | 249 | +        and retrieves a tool call response. | 
|  | 250 | + | 
|  | 251 | +        Args: | 
|  | 252 | +            input (str): Text sent to the LLM. | 
|  | 253 | +            tools (Sequence[Tool]): Sequence of Tools for the LLM to choose from. | 
|  | 254 | +            message_history (Optional[Union[List[LLMMessage], MessageHistory]]): A collection of previous messages, | 
|  | 255 | +                with each message having a specific role assigned. | 
|  | 256 | +            system_instruction (Optional[str]): An option to override the LLM system message for this invocation. | 
|  | 257 | + | 
|  | 258 | +        Returns: | 
|  | 259 | +            ToolCallResponse: The response from the LLM containing a tool call. | 
|  | 260 | + | 
|  | 261 | +        Raises: | 
|  | 262 | +            LLMGenerationError: If anything goes wrong. | 
|  | 263 | +        """ | 
|  | 264 | +        try: | 
|  | 265 | +            if isinstance(message_history, MessageHistory): | 
|  | 266 | +                message_history = message_history.messages | 
|  | 267 | + | 
|  | 268 | +            # Convert tools to Ollama's expected type | 
|  | 269 | +            ollama_tools = [] | 
|  | 270 | +            for tool in tools: | 
|  | 271 | +                ollama_tool_format = self._convert_tool_to_ollama_format(tool) | 
|  | 272 | +                ollama_tools.append(ollama_tool_format) | 
|  | 273 | + | 
|  | 274 | +            response = await self.async_client.chat( | 
|  | 275 | +                model=self.model_name, | 
|  | 276 | +                messages=self.get_messages(input, message_history, system_instruction), | 
|  | 277 | +                tools=ollama_tools, | 
|  | 278 | +                **self.model_params, | 
|  | 279 | +            ) | 
|  | 280 | +            message = response.message | 
|  | 281 | + | 
|  | 282 | +            # If there's no tool call, return the content as a regular response | 
|  | 283 | +            if not message.tool_calls: | 
|  | 284 | +                return ToolCallResponse( | 
|  | 285 | +                    tool_calls=[], | 
|  | 286 | +                    content=message.content, | 
|  | 287 | +                ) | 
|  | 288 | + | 
|  | 289 | +            # Process all tool calls | 
|  | 290 | +            tool_calls = [] | 
|  | 291 | + | 
|  | 292 | +            for tool_call in message.tool_calls: | 
|  | 293 | +                args = tool_call.function.arguments | 
|  | 294 | +                tool_calls.append( | 
|  | 295 | +                    ToolCall(name=tool_call.function.name, arguments=args) | 
|  | 296 | +                ) | 
|  | 297 | + | 
|  | 298 | +            return ToolCallResponse(tool_calls=tool_calls, content=message.content) | 
|  | 299 | +        except self.ollama.ResponseError as e: | 
|  | 300 | +            raise LLMGenerationError(e) | 
|  | 301 | + | 
|  | 302 | +    def _convert_tool_to_ollama_format(self, tool: Tool) -> Dict[str, Any]: | 
|  | 303 | +        """Convert a Tool object to Ollama's expected format. | 
|  | 304 | + | 
|  | 305 | +        Args: | 
|  | 306 | +            tool: A Tool object to convert to Ollama's format. | 
|  | 307 | + | 
|  | 308 | +        Returns: | 
|  | 309 | +            A dictionary in Ollama's tool format. | 
|  | 310 | +        """ | 
|  | 311 | +        try: | 
|  | 312 | +            return { | 
|  | 313 | +                "type": "function", | 
|  | 314 | +                "function": { | 
|  | 315 | +                    "name": tool.get_name(), | 
|  | 316 | +                    "description": tool.get_description(), | 
|  | 317 | +                    "parameters": tool.get_parameters(), | 
|  | 318 | +                }, | 
|  | 319 | +            } | 
|  | 320 | +        except AttributeError: | 
|  | 321 | +            raise LLMGenerationError(f"Tool {tool} is not a valid Tool object") | 
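For context on how these new methods are meant to be driven, below is a minimal usage sketch. It assumes a locally running Ollama server and a tool-capable model (`llama3.1` here is an assumption, not part of this PR); `WeatherTool` is a hypothetical stand-in that only duck-types the three accessors `_convert_tool_to_ollama_format` actually calls (`get_name`, `get_description`, `get_parameters`).

```python
from neo4j_graphrag.llm import OllamaLLM


class WeatherTool:
    """Hypothetical stand-in; implements only the accessors this PR uses."""

    def get_name(self) -> str:
        return "get_weather"

    def get_description(self) -> str:
        return "Look up the current weather for a city."

    def get_parameters(self) -> dict:
        # JSON-Schema-style parameters, forwarded to Ollama as-is
        return {
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        }


llm = OllamaLLM(model_name="llama3.1")  # assumed tool-capable model
response = llm.invoke_with_tools(
    "What is the weather in Paris?", tools=[WeatherTool()]
)

if response.tool_calls:
    for call in response.tool_calls:
        print(call.name, call.arguments)  # e.g. get_weather {'city': 'Paris'}
else:
    # The model answered directly instead of requesting a tool call
    print(response.content)
```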
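For reference, `_convert_tool_to_ollama_format` turns the stand-in above into the following dictionary, the `{"type": "function", ...}` shape that Ollama's chat API expects in its `tools` argument:

```python
{
    "type": "function",
    "function": {
        "name": "get_weather",
        "description": "Look up the current weather for a city.",
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
    },
}
```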
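The async path is symmetrical. A sketch under the same assumptions, reusing `llm` and `WeatherTool` from the first example:

```python
import asyncio


async def main() -> None:
    response = await llm.ainvoke_with_tools(
        "What is the weather in Paris?", tools=[WeatherTool()]
    )
    for call in response.tool_calls:
        print(call.name, call.arguments)


asyncio.run(main())
```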