Skip to content

Commit 083d04d

Browse files
Fix chunk loss in the long streaming response with native response field (#8881)
* Fix chunk loss in the long streaming response * minor validation * comment * do not hit buffer yield block when chunk_message is empty * fix condition check --------- Co-authored-by: chenmoneygithub <[email protected]>
1 parent 7f822be commit 083d04d

File tree

3 files changed

+33
-15
lines changed

3 files changed

+33
-15
lines changed

dspy/streaming/streamify.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -185,12 +185,11 @@ async def async_streamer(*args, **kwargs):
185185
else:
186186
# We are receiving a chunk from the LM's response stream, delegate it to the listeners to
187187
# determine if we should yield a value to the user.
188-
output = None
189188
for listener in predict_id_to_listener[value.predict_id]:
190-
# There should be at most one listener that provides a return value.
191-
output = listener.receive(value) or output
192-
if output:
193-
yield output
189+
# In some special cases such as Citation API, it is possible that multiple listeners
190+
# return values at the same time due to the chunk buffer of the listener.
191+
if output := listener.receive(value):
192+
yield output
194193
elif isinstance(value, StatusMessage):
195194
yield value
196195
elif isinstance(value, Prediction):

dspy/streaming/streaming_listener.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ def receive(self, chunk: ModelResponseStream):
159159
self.field_start_queue = []
160160
return
161161

162-
if self.stream_start:
162+
if self.stream_start and chunk_message:
163163
# The stream is started, we keep returning the token until we see the start of the next field.
164164
token = None
165165
self.field_end_queue.put(chunk_message)
@@ -168,6 +168,7 @@ def receive(self, chunk: ModelResponseStream):
168168
# i.e., "[[ ## {next_field_name} ## ]]" for ChatAdapter to identify the end of the current field.
169169
# In most cases 10 tokens are enough to cover the end_identifier for all adapters.
170170
token = self.field_end_queue.get()
171+
171172
concat_message = "".join(self.field_end_queue.queue).strip()
172173
if re.search(end_identifier, concat_message):
173174
# The next field is identified, we can end the stream and flush out all tokens in the buffer.

tests/streaming/test_streaming.py

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -949,29 +949,39 @@ def forward(self, documents, question, **kwargs):
949949

950950
async def citation_stream(*args, **kwargs):
951951
# Stream chunks with citation data in provider_specific_fields
952+
# To verify the realistic scenario with more than 10 chunks in the stream, include more than 10 chunks before the citation.
952953
yield ModelResponseStream(model="claude", choices=[StreamingChoices(delta=Delta(content="[[ ##"))])
953954
yield ModelResponseStream(model="claude", choices=[StreamingChoices(delta=Delta(content=" answer"))])
954955
yield ModelResponseStream(model="claude", choices=[StreamingChoices(delta=Delta(content=" ## ]]\n\n"))])
955-
yield ModelResponseStream(model="claude", choices=[StreamingChoices(delta=Delta(content="Water"))])
956-
yield ModelResponseStream(model="claude", choices=[StreamingChoices(delta=Delta(content=" boils"))])
957-
yield ModelResponseStream(model="claude", choices=[StreamingChoices(delta=Delta(content=" at"))])
958-
yield ModelResponseStream(model="claude", choices=[StreamingChoices(delta=Delta(content=" 100°C"))])
959-
yield ModelResponseStream(model="claude", choices=[StreamingChoices(delta=Delta(content="."))])
960-
yield ModelResponseStream(model="claude", choices=[StreamingChoices(delta=Delta(content="\n\n"))])
961-
yield ModelResponseStream(model="claude", choices=[StreamingChoices(delta=Delta(content='[{"type": "char_location", "cited_text": "Water boils at 100°C", "document_index": 0, "document_title": "Physics Facts", "start_char_index": 0, "end_char_index": 19}]'))])
956+
yield ModelResponseStream(model="claude", choices=[StreamingChoices(delta=Delta(content="A"))])
957+
yield ModelResponseStream(model="claude", choices=[StreamingChoices(delta=Delta(content="c"))])
958+
yield ModelResponseStream(model="claude", choices=[StreamingChoices(delta=Delta(content="c"))])
959+
yield ModelResponseStream(model="claude", choices=[StreamingChoices(delta=Delta(content="o"))])
960+
yield ModelResponseStream(model="claude", choices=[StreamingChoices(delta=Delta(content="r"))])
961+
yield ModelResponseStream(model="claude", choices=[StreamingChoices(delta=Delta(content="d"))])
962+
yield ModelResponseStream(model="claude", choices=[StreamingChoices(delta=Delta(content="i"))])
963+
yield ModelResponseStream(model="claude", choices=[StreamingChoices(delta=Delta(content="n"))])
964+
yield ModelResponseStream(model="claude", choices=[StreamingChoices(delta=Delta(content="g"))])
965+
yield ModelResponseStream(model="claude", choices=[StreamingChoices(delta=Delta(content=" to "))])
966+
yield ModelResponseStream(model="claude", choices=[StreamingChoices(delta=Delta(content="the references,"))])
962967
yield ModelResponseStream(model="claude", choices=[StreamingChoices(delta=Delta(
963968
content="",
964969
provider_specific_fields={
965970
"citation": {
966971
"type": "char_location",
967-
"cited_text": "Water boils at 100°C",
972+
"cited_text": "water boils at 100°C",
968973
"document_index": 0,
969974
"document_title": "Physics Facts",
970975
"start_char_index": 0,
971976
"end_char_index": 19
972977
}
973978
}
974979
))])
980+
yield ModelResponseStream(model="claude", choices=[StreamingChoices(delta=Delta(content=" water"))])
981+
yield ModelResponseStream(model="claude", choices=[StreamingChoices(delta=Delta(content=" boils"))])
982+
yield ModelResponseStream(model="claude", choices=[StreamingChoices(delta=Delta(content=" at"))])
983+
yield ModelResponseStream(model="claude", choices=[StreamingChoices(delta=Delta(content=" 100°C"))])
984+
yield ModelResponseStream(model="claude", choices=[StreamingChoices(delta=Delta(content="."))])
975985
yield ModelResponseStream(model="claude", choices=[StreamingChoices(delta=Delta(content="\n\n"))])
976986
yield ModelResponseStream(model="claude", choices=[StreamingChoices(delta=Delta(content="[[ ##"))])
977987
yield ModelResponseStream(model="claude", choices=[StreamingChoices(delta=Delta(content=" completed"))])
@@ -982,6 +992,7 @@ async def citation_stream(*args, **kwargs):
982992
program = dspy.streamify(
983993
MyProgram(),
984994
stream_listeners=[
995+
dspy.streaming.StreamListener(signature_field_name="answer"),
985996
dspy.streaming.StreamListener(signature_field_name="citations"),
986997
],
987998
)
@@ -992,10 +1003,13 @@ async def citation_stream(*args, **kwargs):
9921003
with dspy.context(lm=dspy.LM("anthropic/claude-3-5-sonnet-20241022", cache=False), adapter=dspy.ChatAdapter(native_response_types=[Citations])):
9931004
output = program(documents=docs, question="What temperature does water boil?")
9941005
citation_chunks = []
1006+
answer_chunks = []
9951007
final_prediction = None
9961008
async for value in output:
9971009
if isinstance(value, dspy.streaming.StreamResponse) and value.signature_field_name == "citations":
9981010
citation_chunks.append(value)
1011+
elif isinstance(value, dspy.streaming.StreamResponse) and value.signature_field_name == "answer":
1012+
answer_chunks.append(value.chunk)
9991013
elif isinstance(value, dspy.Prediction):
10001014
final_prediction = value
10011015

@@ -1004,10 +1018,14 @@ async def citation_stream(*args, **kwargs):
10041018
citation_chunk = citation_chunks[0]
10051019
assert isinstance(citation_chunk.chunk, Citations)
10061020
assert len(citation_chunk.chunk) == 1
1007-
assert citation_chunk.chunk[0].cited_text == "Water boils at 100°C"
1021+
assert citation_chunk.chunk[0].cited_text == "water boils at 100°C"
10081022
assert citation_chunk.chunk[0].document_title == "Physics Facts"
10091023

1024+
# Verify the answer chunks are correct
1025+
assert "".join(answer_chunks) == "According to the references, water boils at 100°C."
1026+
10101027
# Test that prediction contains the expected fields
10111028
assert final_prediction is not None
10121029
assert hasattr(final_prediction, "answer")
10131030
assert hasattr(final_prediction, "citations")
1031+
assert final_prediction.answer == "According to the references, water boils at 100°C."

0 commit comments

Comments
 (0)