Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 29 additions & 6 deletions dspy/adapters/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from dspy.adapters.types import History, Type
from dspy.adapters.types.base_type import split_message_content_for_custom_types
from dspy.adapters.types.tool import Tool, ToolCalls
from dspy.adapters.utils import LargePayloadHashManager
from dspy.experimental import Citations
from dspy.signatures.signature import Signature
from dspy.utils.callback import BaseCallback, with_callbacks
Expand All @@ -18,8 +19,14 @@

_DEFAULT_NATIVE_RESPONSE_TYPES = [Citations]


class Adapter:
def __init__(self, callbacks: list[BaseCallback] | None = None, use_native_function_calling: bool = False, native_response_types: list[type[Type]] | None = None):
def __init__(
    self,
    callbacks: list[BaseCallback] | None = None,
    use_native_function_calling: bool = False,
    native_response_types: list[type[Type]] | None = None,
):
    """Initialize the adapter.

    Args:
        callbacks: Optional list of callback handlers invoked around adapter
            calls; defaults to an empty list.
        use_native_function_calling: Whether to use the LM provider's native
            function-calling mechanism instead of text-based tool prompting.
        native_response_types: Custom output types parsed from the LM's native
            response rather than by the adapter; defaults to
            ``_DEFAULT_NATIVE_RESPONSE_TYPES`` (currently ``[Citations]``).
    """
    self.callbacks = callbacks or []
    self.use_native_function_calling = use_native_function_calling
    self.native_response_types = native_response_types or _DEFAULT_NATIVE_RESPONSE_TYPES
Expand Down Expand Up @@ -68,7 +75,11 @@ def _call_preprocess(

# Handle custom types that use native response
for name, field in signature.output_fields.items():
if isinstance(field.annotation, type) and issubclass(field.annotation, Type) and field.annotation in self.native_response_types:
if (
isinstance(field.annotation, type)
and issubclass(field.annotation, Type)
and field.annotation in self.native_response_types
):
signature = signature.delete(name)

return signature
Expand Down Expand Up @@ -117,7 +128,11 @@ def _call_postprocess(

# Parse custom types that does not rely on the adapter parsing
for name, field in original_signature.output_fields.items():
if isinstance(field.annotation, type) and issubclass(field.annotation, Type) and field.annotation in self.native_response_types:
if (
isinstance(field.annotation, type)
and issubclass(field.annotation, Type)
and field.annotation in self.native_response_types
):
value[name] = field.annotation.parse_lm_response(output)

if output_logprobs:
Expand Down Expand Up @@ -200,7 +215,14 @@ def format(
Returns:
A list of multiturn messages as expected by the LM.
"""
inputs_copy = dict(inputs)

# Replace large data with hashes, e.g. image base64 data (speeds up string formatting). Will be restored later.
data_manager = LargePayloadHashManager()
inputs_hashed = data_manager.replace_large_data(inputs)
demos_hashed = data_manager.replace_large_data(demos)

# Work on a shallow copy of inputs
inputs_copy = dict(inputs_hashed)

# If the signature and inputs have conversation history, we need to format the conversation history and
# remove the history field from the signature.
Expand All @@ -221,7 +243,7 @@ def format(
f"{self.format_task_description(signature)}"
)
messages.append({"role": "system", "content": system_message})
messages.extend(self.format_demos(signature, demos))
messages.extend(self.format_demos(signature, demos_hashed))
if history_field_name:
# Conversation history and current input
content = self.format_user_message_content(signature_without_history, inputs_copy, main_request=True)
Expand All @@ -233,6 +255,8 @@ def format(
messages.append({"role": "user", "content": content})

messages = split_message_content_for_custom_types(messages)
messages = data_manager.restore_large_data(messages)

return messages

def format_field_description(self, signature: type[Signature]) -> str:
Expand Down Expand Up @@ -404,7 +428,6 @@ def _get_tool_call_output_field_name(self, signature: type[Signature]) -> bool:
return name
return None


def format_conversation_history(
self,
signature: type[Signature],
Expand Down
3 changes: 3 additions & 0 deletions dspy/adapters/types/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,9 @@ def encode_image(image: Union[str, bytes, "PILImage.Image", dict], download_imag
else:
# Return the URL as is
return image
elif image.startswith("__DSPY_LARGE_DATA_HASH_"):
# DSPy large data hash identifier - return as-is during optimization
return image
else:
# Unsupported string format
raise ValueError(f"Unrecognized file string: {image}; If this file type should be supported, please open an issue.")
Expand Down
129 changes: 129 additions & 0 deletions dspy/adapters/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import ast
import copy
import enum
import inspect
import json
Expand All @@ -12,6 +13,7 @@
from pydantic.fields import FieldInfo

from dspy.adapters.types.base_type import Type as DspyType
from dspy.adapters.types.history import History
from dspy.signatures.utils import get_dspy_field_type


Expand Down Expand Up @@ -280,3 +282,130 @@ def _quoted_string_for_literal_type_annotation(s: str) -> str:
else:
# Neither => enclose in single quotes
return f"'{s}'"


class LargePayloadHashManager:
"""
Used by `format` in adapters.base.py
That function formats the input prompt as one string, which means it stringifies
large data (dspy types like Image, Audio). Then it splits that prompt string into
a list, where special types get their own item. But stringifying large data is slow
and memory-intensive. Instead, we replace the large data with a hash token before
building the prompt, and restore the original data at the end.

This class facilitates that with these two methods:
- replace_large_data: replace large data with hash tokens for inputs, demos, history
- restore_large_data: restore the original data from hash token for LLM messages

Notes:
- Non‑destructive: never mutates inputs; returns new structures mirroring shape.
- Threshold-based: strings longer than `LARGE_DATA_THRESHOLD` are hashed
- Only Image type is implemented - TODO for audio, etc.
"""

# Configurable threshold for what constitutes "large" data
LARGE_DATA_THRESHOLD = 1000 # characters

def __init__(self):
"""Initialize the manager with fresh hash mappings."""
self.hash_to_data = {} # hash_id -> original_data
self.data_to_hash = {} # original_data -> hash_id
self.hash_counter = 0

def replace_large_data(self, obj: Any) -> Any:
"""Return a copy of `obj` with large string fields replaced by hash tokens."""
if isinstance(obj, DspyType):
# Handle DSPy custom types: Images
return self._replace_in_custom_type(obj)
elif hasattr(obj, "items") and not isinstance(obj, dict):
new_mapping: dict[str, Any] = {}
for k, v in obj.items():
new_mapping[k] = self.replace_large_data(v)
return new_mapping
elif isinstance(obj, dict):
new_dict = {}
for k, v in obj.items():
new_dict[k] = self.replace_large_data(v)
return new_dict
elif isinstance(obj, list):
new_list = []
for item in obj:
new_list.append(self.replace_large_data(item))
return new_list
else:
return obj

def restore_large_data(self, obj: Any) -> Any:
"""Return a copy of `obj` with hash tokens restored to original data."""
if isinstance(obj, dict):
if self._is_message_content_with_large_data(obj):
return self._restore_in_message_content(obj)
else:
return {k: self.restore_large_data(v) for k, v in obj.items()}
elif isinstance(obj, list):
return [self.restore_large_data(item) for item in obj]
elif isinstance(obj, str):
if self._is_hash_identifier(obj):
return self.hash_to_data.get(obj, obj)
else:
return obj
else:
return obj

def _create_hash_for_data(self, data: str) -> str:
"""Create a unique hash identifier for large data."""
# Check if we've already hashed this exact data (deduplication)
if data in self.data_to_hash:
return self.data_to_hash[data]

# Create new hash identifier
hash_id = f"__DSPY_LARGE_DATA_HASH_{self.hash_counter}__"
self.hash_counter += 1

# Store bidirectional mapping
self.hash_to_data[hash_id] = data
self.data_to_hash[data] = hash_id

return hash_id

def _is_large_data(self, data: Any) -> bool:
"""Check if data is large enough to warrant optimization."""
if isinstance(data, str):
return len(data) > self.LARGE_DATA_THRESHOLD
return False

def _replace_in_custom_type(self, obj: DspyType) -> DspyType:
"""For Image type, replace large payload fields in known custom types; otherwise return `obj`."""
if isinstance(obj, History):
new_messages = self.replace_large_data(obj.messages)
if new_messages is not obj.messages: # Messages changed
return type(obj)(messages=new_messages)
else:
return obj

elif hasattr(obj, "url") and self._is_large_data(obj.url):
# TODO: it only handles image url, not other types yet
hash_url = self._create_hash_for_data(obj.url)
return type(obj)(url=hash_url)

return obj

def _is_message_content_with_large_data(self, obj: dict) -> bool:
"""Whether this dict is an image content block with a payload field."""
# {"type": "image_url", "image_url": {"url": "hash_id"}}
return obj.get("type") == "image_url" and "image_url" in obj and "url" in obj["image_url"]

def _restore_in_message_content(self, obj: dict) -> dict:
"""Restore payloads inside image/audio message content blocks."""
obj_copy = copy.deepcopy(obj)

if obj.get("type") == "image_url" and "image_url" in obj:
url = obj["image_url"].get("url", "")
if self._is_hash_identifier(url):
obj_copy["image_url"]["url"] = self.hash_to_data.get(url, url)

return obj_copy

def _is_hash_identifier(self, text: str) -> bool:
"""True if `text` looks like a DSPy large-data hash token."""
return isinstance(text, str) and text.startswith("__DSPY_LARGE_DATA_HASH_")