Commit f7cfb4e

Added streaming support for BAMLAdapter
1 parent 1eccc38 commit f7cfb4e
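
What this commit enables, as a minimal usage sketch (the model name and signature are placeholders; the pattern mirrors the new test added below): with this change, BAMLAdapter joins ChatAdapter, JSONAdapter, and XMLAdapter in ADAPTER_SUPPORT_STREAMING, so dspy.streamify can surface per-field chunks while a BAMLAdapter is active.

    import asyncio

    import dspy

    # Minimal sketch, assuming an OpenAI-style model; mirrors the new test below.
    predict = dspy.Predict("question->answer")
    program = dspy.streamify(
        predict,
        stream_listeners=[dspy.streaming.StreamListener(signature_field_name="answer")],
    )

    async def main():
        with dspy.context(lm=dspy.LM("openai/gpt-4o-mini"), adapter=dspy.BAMLAdapter()):
            output = program(question="why did a chicken cross the kitchen?")
            async for value in output:
                if isinstance(value, dspy.streaming.StreamResponse):
                    print(value.chunk, end="", flush=True)

    asyncio.run(main())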

4 files changed: +89, -6 lines

dspy/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -6,7 +6,7 @@
 
 from dspy.evaluate import Evaluate # isort: skip
 from dspy.clients import * # isort: skip
-from dspy.adapters import Adapter, ChatAdapter, JSONAdapter, XMLAdapter, TwoStepAdapter, Image, Audio, History, Type, Tool, ToolCalls, Code # isort: skip
+from dspy.adapters import Adapter, ChatAdapter, JSONAdapter, XMLAdapter, BAMLAdapter, TwoStepAdapter, Image, Audio, History, Type, Tool, ToolCalls, Code # isort: skip
 from dspy.utils.logging_utils import configure_dspy_loggers, disable_logging, enable_logging
 from dspy.utils.asyncify import asyncify
 from dspy.utils.syncify import syncify

dspy/adapters/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -1,4 +1,5 @@
 from dspy.adapters.base import Adapter
+from dspy.adapters.baml_adapter import BAMLAdapter
 from dspy.adapters.chat_adapter import ChatAdapter
 from dspy.adapters.json_adapter import JSONAdapter
 from dspy.adapters.two_step_adapter import TwoStepAdapter
@@ -7,6 +8,7 @@
 
 __all__ = [
     "Adapter",
+    "BAMLAdapter",
     "ChatAdapter",
     "Type",
     "History",

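Since the class is now re-exported at the top level as well (see the dspy/__init__.py hunk above), either import path resolves to the same object; a quick sanity check:

    import dspy
    from dspy.adapters import BAMLAdapter

    assert dspy.BAMLAdapter is BAMLAdapter  # one class, two import paths
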
dspy/streaming/streaming_listener.py

Lines changed: 11 additions & 5 deletions
@@ -5,6 +5,7 @@
 
 from litellm import ModelResponseStream
 
+from dspy.adapters.baml_adapter import BAMLAdapter
 from dspy.adapters.chat_adapter import ChatAdapter
 from dspy.adapters.json_adapter import JSONAdapter
 from dspy.adapters.types import Type
@@ -15,7 +16,7 @@
 if TYPE_CHECKING:
     from dspy.primitives.module import Module
 
-ADAPTER_SUPPORT_STREAMING = [ChatAdapter, XMLAdapter, JSONAdapter]
+ADAPTER_SUPPORT_STREAMING = [ChatAdapter, XMLAdapter, JSONAdapter, BAMLAdapter]
 
 
 class StreamListener:
@@ -65,6 +66,11 @@ def __init__(
                 "end_identifier": re.compile(rf"</{self.signature_field_name}>"),
                 "start_indicator": "<",
             },
+            "BAMLAdapter": {
+                "start_identifier": f'"{self.signature_field_name}":',
+                "end_identifier": re.compile(r"\w*\"(,|\s*})"),
+                "start_indicator": '"',
+            },
         }
 
     def _buffered_message_end_with_start_identifier(self, concat_message: str, start_identifier: str) -> str:
@@ -145,8 +151,8 @@ def receive(self, chunk: ModelResponseStream):
                 # Keep the part after the start_identifier from the concat_message, we need to write it to the buffer.
                 value_start_index = concat_message.find(start_identifier) + len(start_identifier)
                 chunk_message = concat_message[value_start_index:].lstrip()
-                if isinstance(settings.adapter, JSONAdapter) and chunk_message.startswith('"'):
-                    # For JSONAdapter, we need to remove the leading ". We cannot do this with the start_identifier
+                if isinstance(settings.adapter, (JSONAdapter, BAMLAdapter)) and chunk_message.startswith('"'):
+                    # For JSONAdapter and BAMLAdapter, we need to remove the leading ". We cannot do this with the start_identifier
                     # because there could be a few splitters between ':' and '"', e.g., '"name": "value"'.
                     chunk_message = chunk_message[1:]
 
@@ -194,7 +200,7 @@ def flush(self) -> str:
         """
         last_tokens = "".join(self.field_end_queue.queue)
         self.field_end_queue = Queue()
-        if isinstance(settings.adapter, JSONAdapter):
+        if isinstance(settings.adapter, (JSONAdapter, BAMLAdapter)):
             match = re.search(r'",|"\s*}', last_tokens)
             if match:
                 boundary_index = match.start()
@@ -206,7 +212,7 @@ def flush(self) -> str:
             if boundary_index == -1:
                 boundary_index = len(last_tokens)
             return last_tokens[:boundary_index]
-        elif isinstance(settings.adapter, ChatAdapter) or settings.adapter is None:
+        elif isinstance(settings.adapter, (ChatAdapter, BAMLAdapter)) or settings.adapter is None:
             boundary_index = last_tokens.find("[[")
             return last_tokens[:boundary_index]
         else:
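
To make the new identifiers concrete, here is a standalone sketch (not the library internals) of how a field value is carved out of a BAMLAdapter-style JSON stream, using the same start identifier added above and the same end-boundary pattern that flush() applies:

    import re

    signature_field_name = "answer"                       # example field
    start_identifier = f'"{signature_field_name}":'       # same f-string as the new dict entry
    streamed = '{"answer": "To get to the other side!"}'  # a fully reassembled response

    # receive(): everything after the start identifier belongs to the value,
    # minus a leading quote, which is now stripped for BAMLAdapter too.
    value = streamed[streamed.find(start_identifier) + len(start_identifier):].lstrip()
    if value.startswith('"'):
        value = value[1:]

    # flush(): the value ends at a closing quote followed by ',' or '}'.
    match = re.search(r'",|"\s*}', value)
    print(value[:match.start()])  # -> To get to the other side!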

tests/streaming/test_streaming.py

Lines changed: 75 additions & 0 deletions
@@ -851,6 +851,81 @@ async def completion_side_effect(*args, **kwargs):
     assert all_chunks[1].chunk == "The answer is humorous."
 
 
+@pytest.mark.anyio
+async def test_stream_listener_returns_correct_chunk_baml_adapter():
+    class MyProgram(dspy.Module):
+        def __init__(self):
+            super().__init__()
+            self.predict1 = dspy.Predict("question->answer")
+            self.predict2 = dspy.Predict("question,answer->judgement")
+
+        def forward(self, question, **kwargs):
+            answer = self.predict1(question=question, **kwargs).answer
+            judgement = self.predict2(question=question, answer=answer, **kwargs)
+            return judgement
+
+    async def baml_stream_1(*args, **kwargs):
+        # BAML uses JSON format for responses but ChatAdapter-style field delimiters in prompts
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content='{"'))])
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="answer"))])
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content='":'))])
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="To"))])
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" get"))])
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" to"))])
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" the"))])
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" other"))])
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" side"))])
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="!"))])
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content='"'))])
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="}\n"))])
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="None"))])
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="None"))])
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="None"))])
+
+    async def baml_stream_2(*args, **kwargs):
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content='{"'))])
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="judgement"))])
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content='":"'))])
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="The"))])
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" answer"))])
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" is"))])
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content=" humorous"))])
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="."))])
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content='"'))])
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="}"))])
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="None"))])
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="None"))])
+        yield ModelResponseStream(model="gpt-4o-mini", choices=[StreamingChoices(delta=Delta(content="None"))])
+
+    stream_generators = [baml_stream_1, baml_stream_2]
+
+    async def completion_side_effect(*args, **kwargs):
+        return stream_generators.pop(0)()
+
+    with mock.patch("litellm.acompletion", side_effect=completion_side_effect):
+        program = dspy.streamify(
+            MyProgram(),
+            stream_listeners=[
+                dspy.streaming.StreamListener(signature_field_name="answer"),
+                dspy.streaming.StreamListener(signature_field_name="judgement"),
+            ],
+        )
+        with dspy.context(lm=dspy.LM("openai/gpt-4o-mini", cache=False), adapter=dspy.BAMLAdapter()):
+            output = program(question="why did a chicken cross the kitchen?")
+            all_chunks = []
+            async for value in output:
+                if isinstance(value, dspy.streaming.StreamResponse):
+                    all_chunks.append(value)
+
+    assert all_chunks[0].predict_name == "predict1"
+    assert all_chunks[0].signature_field_name == "answer"
+    assert all_chunks[0].chunk == "To get to the other side!"
+
+    assert all_chunks[1].predict_name == "predict2"
+    assert all_chunks[1].signature_field_name == "judgement"
+    assert all_chunks[1].chunk == "The answer is humorous."
+
+
 @pytest.mark.anyio
 async def test_streaming_allows_custom_chunk_types():
     @dataclass
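
For reference, the deltas in the two fake streams concatenate to the raw completions below (an illustration, not part of the commit). Note that baml_stream_1 omits the opening quote before the value while baml_stream_2 includes it; the leading-quote strip added to receive() lets both extract cleanly, and the trailing "None" chunks check that the listener stops at the closing brace.

    # What the fake streams above reassemble to (illustration only):
    stream_1_text = '{"answer":To get to the other side!"}\nNoneNoneNone'
    stream_2_text = '{"judgement":"The answer is humorous."}NoneNoneNone'
    # Extracted chunks, per the asserts:
    #   "To get to the other side!" and "The answer is humorous."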
