Introduce dspy.Reasoning to capture native reasoning from reasoning models #8986
New file introducing the `Reasoning` type (`@@ -0,0 +1,118 @@`):

```python
from typing import TYPE_CHECKING, Any, Optional

import litellm
import pydantic

from dspy.adapters.types.base_type import Type

if TYPE_CHECKING:
    from dspy.clients.lm import LM
    from dspy.signatures.signature import Signature


class Reasoning(Type):
    """Reasoning type in DSPy.

    This type is useful when you want the DSPy output to include the LM's reasoning. It lets DSPy
    support reasoning models and non-reasoning models with the same code.

    This is a str-like type: you can convert a string directly to a `Reasoning` object, and from the
    DSPy adapters' perspective, `Reasoning` is treated as a string.
    """

    content: str

    def format(self):
        return f"{self.content}"

    @pydantic.model_validator(mode="before")
    @classmethod
    def validate_input(cls, data: Any):
        if isinstance(data, cls):
            return data

        if isinstance(data, str):
            return {"content": data}

        if isinstance(data, dict):
            if "content" not in data:
                raise ValueError("`content` field is required for `dspy.Reasoning`")
            if not isinstance(data["content"], str):
                raise ValueError(f"`content` field must be a string, but received type: {type(data['content'])}")
            return {"content": data["content"]}

        raise ValueError(f"Received invalid value for `dspy.Reasoning`: {data}")

    @classmethod
    def adapt_to_native_lm_feature(
        cls,
        signature: type["Signature"],
        field_name: str,
        lm: "LM",
        lm_kwargs: dict[str, Any],
    ) -> type["Signature"]:
        if "reasoning_effort" in lm_kwargs:
            # `lm_kwargs` overrides `lm.kwargs`.
            reasoning_effort = lm_kwargs["reasoning_effort"]
        elif "reasoning_effort" in lm.kwargs:
            reasoning_effort = lm.kwargs["reasoning_effort"]
        else:
            # Turn on native reasoning explicitly if a Reasoning field is present in the signature
            # and no explicit reasoning effort is set in `lm_kwargs` or `lm.kwargs`.
            reasoning_effort = "low"

        if reasoning_effort is None or not litellm.supports_reasoning(lm.model):
            # If users explicitly set `reasoning_effort` to None or the LM doesn't support
            # reasoning, we don't enable native reasoning.
            return signature

        if "gpt-5" in lm.model and lm.model_type == "chat":
            # As of LiteLLM 1.79.0, the reasoning content is not available in the response when
            # using the chat completion API with GPT-5 family models. As a workaround, we don't
            # enable the native reasoning feature for GPT-5 family models on the chat completion
            # API. LiteLLM issue: https://github.com/BerriAI/litellm/issues/14748
            return signature

        lm_kwargs["reasoning_effort"] = reasoning_effort
        # Delete the reasoning field from the signature to use the native reasoning feature.
        return signature.delete(field_name)

    @classmethod
    def parse_lm_response(cls, response: str | dict[str, Any]) -> Optional["Reasoning"]:
        """Parse the LM response into a Reasoning object."""
        if "reasoning_content" in response:
            return Reasoning(content=response["reasoning_content"])
        return None

    @classmethod
    def parse_stream_chunk(cls, chunk) -> str | None:
        """Parse a stream chunk into reasoning content if available.

        Args:
            chunk: A stream chunk from the LM.

        Returns:
            The reasoning content (str) if available, None otherwise.
        """
        try:
            if choices := getattr(chunk, "choices", None):
                return getattr(choices[0].delta, "reasoning_content", None)
        except Exception:
            return None

    @classmethod
    def is_streamable(cls) -> bool:
        return True

    def __repr__(self) -> str:
        return f"{self.content!r}"

    def __str__(self) -> str:
        return self.content

    def __eq__(self, other: object) -> bool:
        if isinstance(other, Reasoning):
            return self.content == other.content
        if isinstance(other, str):
            return self.content == other
        return NotImplemented  # fall back for unsupported comparison types
```
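For readers following along, here is a minimal sketch of the str-like behavior. It assumes `Reasoning` is re-exported as `dspy.Reasoning` (as the PR title suggests) and that `Type` is a pydantic model, so `model_validate` is available:

```python
import dspy

# A plain string validates into Reasoning via the `mode="before"` validator.
r = dspy.Reasoning.model_validate("Add the numbers, then halve the sum.")
assert r.content == "Add the numbers, then halve the sum."
assert r == "Add the numbers, then halve the sum."  # __eq__ also accepts plain strings
assert str(r) == r.format()

# A dict needs a string `content` field; anything else raises ValueError.
r2 = dspy.Reasoning.model_validate({"content": "Step-by-step reasoning."})

# Parsing an LM response dict that carries native reasoning content:
parsed = dspy.Reasoning.parse_lm_response({"reasoning_content": "Chain of thought...", "content": "42"})
assert parsed == "Chain of thought..."
```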
Second changed file (field-type translation utilities):

```diff
@@ -12,6 +12,7 @@
 from pydantic.fields import FieldInfo

 from dspy.adapters.types.base_type import Type as DspyType
+from dspy.adapters.types.reasoning import Reasoning
 from dspy.signatures.utils import get_dspy_field_type
```
```diff
@@ -84,7 +85,7 @@ def move_type_to_front(d):
 def translate_field_type(field_name, field_info):
     field_type = field_info.annotation

-    if get_dspy_field_type(field_info) == "input" or field_type is str:
+    if get_dspy_field_type(field_info) == "input" or field_type is str or field_type is Reasoning:
         desc = ""
     elif field_type is bool:
         desc = "must be True or False"
```

Reviewer: ditto

Author: same as above, let me know your thoughts!
```diff
@@ -190,6 +191,9 @@ def get_annotation_name(annotation):
     origin = get_origin(annotation)
     args = get_args(annotation)
     if origin is None:
+        if annotation is Reasoning:
+            # Reasoning field type is treated as a string.
+            return "str"
         if hasattr(annotation, "__name__"):
             return annotation.__name__
         else:
```

Reviewer: Is there any way to implement the conversion more generically? Ideally this information should reside in `Reasoning`.

Author: That's a good question. I did think about the same thing, but changing the …

Reviewer: Got it, we need this conversion so that the LLM won't return …
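To make the intent concrete, here is a hedged sketch of a signature using the new type (the class name and fields are illustrative, not from the PR):

```python
import dspy

class MathQA(dspy.Signature):
    """Answer the math question."""

    question: str = dspy.InputField()
    reasoning: dspy.Reasoning = dspy.OutputField()
    answer: str = dspy.OutputField()

# With the changes above, the adapter describes `reasoning` to the LM as a plain "str"
# with no extra type hint, so prompts for non-reasoning models look the same as the
# classic string-based reasoning field, while reasoning models can fill it natively.
```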
Third changed file (streaming listener):

```diff
@@ -130,13 +130,6 @@ def receive(self, chunk: ModelResponseStream):
         else:
             return

-        try:
-            chunk_message = chunk.choices[0].delta.content
-            if chunk_message is None:
-                return
-        except Exception:
-            return
-
         # Handle custom streamable types
         if self._output_type and issubclass(self._output_type, Type) and self._output_type.is_streamable():
             if parsed_chunk := self._output_type.parse_stream_chunk(chunk):
```
```diff
@@ -147,6 +140,14 @@ def receive(self, chunk: ModelResponseStream):
                     is_last_chunk=self.stream_end,
                 )

+        # For non-custom streamable types, the streaming chunks come from the content field of the ModelResponseStream.
+        try:
+            chunk_message = chunk.choices[0].delta.content
+            if chunk_message is None:
+                return
+        except Exception:
+            return
+
         if chunk_message and start_identifier in chunk_message:
             # If the cache is hit, the chunk_message could be the full response. When it happens we can
             # directly end the stream listening. In some models like gemini, each stream chunk can be multiple
```

Reviewer: nit: maybe add a comment on why this logic should come after the native response handling?

Author: good call, done!
Reviewer: Is there any GitHub issue for this on LiteLLM?

Author: good call! added