From 05e44d7c44c1e1381f13425d5f80c98abaa3d1c1 Mon Sep 17 00:00:00 2001 From: TomuHirata Date: Thu, 18 Sep 2025 11:59:29 +0900 Subject: [PATCH 1/3] Refactor ToolCalls handling to use generic base type function --- dspy/adapters/base.py | 21 +++++---------------- dspy/adapters/two_step_adapter.py | 15 --------------- dspy/adapters/types/tool.py | 28 +++++++++++++++++++++++++++- 3 files changed, 32 insertions(+), 32 deletions(-) diff --git a/dspy/adapters/base.py b/dspy/adapters/base.py index 701f84ccef..3bf0fdbc29 100644 --- a/dspy/adapters/base.py +++ b/dspy/adapters/base.py @@ -1,7 +1,6 @@ import logging from typing import TYPE_CHECKING, Any, get_origin -import json_repair import litellm from dspy.adapters.types import History, Type @@ -16,7 +15,7 @@ if TYPE_CHECKING: from dspy.clients.lm import LM -_DEFAULT_NATIVE_RESPONSE_TYPES = [Citations] +_DEFAULT_NATIVE_RESPONSE_TYPES = [Citations, ToolCalls] class Adapter: def __init__(self, callbacks: list[BaseCallback] | None = None, use_native_function_calling: bool = False, native_response_types: list[type[Type]] | None = None): @@ -82,17 +81,13 @@ def _call_postprocess( ) -> list[dict[str, Any]]: values = [] - tool_call_output_field_name = self._get_tool_call_output_field_name(original_signature) - for output in outputs: output_logprobs = None - tool_calls = None text = output if isinstance(output, dict): text = output["text"] output_logprobs = output.get("logprobs") - tool_calls = output.get("tool_calls") if text: value = self.parse(processed_signature, text) @@ -105,20 +100,14 @@ def _call_postprocess( for field_name in original_signature.output_fields.keys(): value[field_name] = None - if tool_calls and tool_call_output_field_name: - tool_calls = [ - { - "name": v["function"]["name"], - "args": json_repair.loads(v["function"]["arguments"]), - } - for v in tool_calls - ] - value[tool_call_output_field_name] = ToolCalls.from_dict_list(tool_calls) # Parse custom types that does not rely on the adapter parsing for name, field in original_signature.output_fields.items(): if isinstance(field.annotation, type) and issubclass(field.annotation, Type) and field.annotation in self.native_response_types: - value[name] = field.annotation.parse_lm_response(output) + if name not in value: + parsed_value = field.annotation.parse_lm_response(output) + if parsed_value is not None: + value[name] = parsed_value if output_logprobs: value["logprobs"] = output_logprobs diff --git a/dspy/adapters/two_step_adapter.py b/dspy/adapters/two_step_adapter.py index b78375f140..b2361be3df 100644 --- a/dspy/adapters/two_step_adapter.py +++ b/dspy/adapters/two_step_adapter.py @@ -1,10 +1,7 @@ from typing import Any -import json_repair - from dspy.adapters.base import Adapter from dspy.adapters.chat_adapter import ChatAdapter -from dspy.adapters.types import ToolCalls from dspy.adapters.utils import get_field_description_string from dspy.clients import LM from dspy.signatures.field import InputField @@ -119,16 +116,13 @@ async def acall( values = [] - tool_call_output_field_name = self._get_tool_call_output_field_name(signature) for output in outputs: output_logprobs = None - tool_calls = None text = output if isinstance(output, dict): text = output["text"] output_logprobs = output.get("logprobs") - tool_calls = output.get("tool_calls") try: # Call the smaller LM to extract structured data from the raw completion text with ChatAdapter @@ -144,15 +138,6 @@ async def acall( except Exception as e: raise ValueError(f"Failed to parse response from the original completion: {output}") from e - if tool_calls and tool_call_output_field_name: - tool_calls = [ - { - "name": v["function"]["name"], - "args": json_repair.loads(v["function"]["arguments"]), - } - for v in tool_calls - ] - value[tool_call_output_field_name] = ToolCalls.from_dict_list(tool_calls) if output_logprobs is not None: value["logprobs"] = output_logprobs diff --git a/dspy/adapters/types/tool.py b/dspy/adapters/types/tool.py index 5cde703a41..6ec73e8a0d 100644 --- a/dspy/adapters/types/tool.py +++ b/dspy/adapters/types/tool.py @@ -1,6 +1,6 @@ import asyncio import inspect -from typing import TYPE_CHECKING, Any, Callable, Type, get_origin, get_type_hints +from typing import TYPE_CHECKING, Any, Callable, Optional, get_origin, get_type_hints import pydantic from jsonschema import ValidationError, validate @@ -307,6 +307,32 @@ def format(self) -> list[dict[str, Any]]: "tool_calls": [tool_call.format() for tool_call in self.tool_calls], } + @classmethod + def parse_lm_response(cls, response: str | dict[str, Any]) -> Optional["ToolCalls"]: + """Parse a LM response into ToolCalls. + + Args: + response: A LM response that may contain tool call data. + + Returns: + A ToolCalls object if tool call data is found, None otherwise. + """ + import json_repair + + if isinstance(response, dict) and "tool_calls" in response: + tool_calls_data = response["tool_calls"] + if isinstance(tool_calls_data, list): + tool_calls = [ + { + "name": v["function"]["name"], + "args": json_repair.loads(v["function"]["arguments"]), + } + for v in tool_calls_data + ] + return ToolCalls.from_dict_list(tool_calls) + + return None + @pydantic.model_validator(mode="before") @classmethod def validate_input(cls, data: Any): From 2f06f90f04d9e356a45cca62bf277940e685f77e Mon Sep 17 00:00:00 2001 From: TomuHirata Date: Thu, 18 Sep 2025 12:08:03 +0900 Subject: [PATCH 2/3] comment --- dspy/adapters/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dspy/adapters/base.py b/dspy/adapters/base.py index 3bf0fdbc29..2aebee838d 100644 --- a/dspy/adapters/base.py +++ b/dspy/adapters/base.py @@ -104,7 +104,7 @@ def _call_postprocess( # Parse custom types that does not rely on the adapter parsing for name, field in original_signature.output_fields.items(): if isinstance(field.annotation, type) and issubclass(field.annotation, Type) and field.annotation in self.native_response_types: - if name not in value: + if value.get(name) is None: parsed_value = field.annotation.parse_lm_response(output) if parsed_value is not None: value[name] = parsed_value From 136b5eae4d1dd6b74d8104d8d08d36f8ec687e96 Mon Sep 17 00:00:00 2001 From: TomuHirata Date: Thu, 18 Sep 2025 12:17:06 +0900 Subject: [PATCH 3/3] fix test --- dspy/adapters/base.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/dspy/adapters/base.py b/dspy/adapters/base.py index 2aebee838d..d7ce3cfa9a 100644 --- a/dspy/adapters/base.py +++ b/dspy/adapters/base.py @@ -15,13 +15,15 @@ if TYPE_CHECKING: from dspy.clients.lm import LM -_DEFAULT_NATIVE_RESPONSE_TYPES = [Citations, ToolCalls] +_DEFAULT_NATIVE_RESPONSE_TYPES = [Citations] class Adapter: def __init__(self, callbacks: list[BaseCallback] | None = None, use_native_function_calling: bool = False, native_response_types: list[type[Type]] | None = None): self.callbacks = callbacks or [] self.use_native_function_calling = use_native_function_calling - self.native_response_types = native_response_types or _DEFAULT_NATIVE_RESPONSE_TYPES + self.native_response_types = native_response_types or _DEFAULT_NATIVE_RESPONSE_TYPES.copy() + if self.use_native_function_calling: + self.native_response_types.append(ToolCalls) def __init_subclass__(cls, **kwargs) -> None: super().__init_subclass__(**kwargs)