@@ -949,29 +949,39 @@ def forward(self, documents, question, **kwargs):
949949
950950 async def citation_stream (* args , ** kwargs ):
951951 # Stream chunks with citation data in provider_specific_fields
952+ # Yield more than 10 chunks before the citation so the test covers the realistic case of a stream with more than 10 chunks.
952953 yield ModelResponseStream (model = "claude" , choices = [StreamingChoices (delta = Delta (content = "[[ ##" ))])
953954 yield ModelResponseStream (model = "claude" , choices = [StreamingChoices (delta = Delta (content = " answer" ))])
954955 yield ModelResponseStream (model = "claude" , choices = [StreamingChoices (delta = Delta (content = " ## ]]\n \n " ))])
955- yield ModelResponseStream (model = "claude" , choices = [StreamingChoices (delta = Delta (content = "Water" ))])
956- yield ModelResponseStream (model = "claude" , choices = [StreamingChoices (delta = Delta (content = " boils" ))])
957- yield ModelResponseStream (model = "claude" , choices = [StreamingChoices (delta = Delta (content = " at" ))])
958- yield ModelResponseStream (model = "claude" , choices = [StreamingChoices (delta = Delta (content = " 100°C" ))])
959- yield ModelResponseStream (model = "claude" , choices = [StreamingChoices (delta = Delta (content = "." ))])
960- yield ModelResponseStream (model = "claude" , choices = [StreamingChoices (delta = Delta (content = "\n \n " ))])
961- yield ModelResponseStream (model = "claude" , choices = [StreamingChoices (delta = Delta (content = '[{"type": "char_location", "cited_text": "Water boils at 100°C", "document_index": 0, "document_title": "Physics Facts", "start_char_index": 0, "end_char_index": 19}]' ))])
956+ yield ModelResponseStream (model = "claude" , choices = [StreamingChoices (delta = Delta (content = "A" ))])
957+ yield ModelResponseStream (model = "claude" , choices = [StreamingChoices (delta = Delta (content = "c" ))])
958+ yield ModelResponseStream (model = "claude" , choices = [StreamingChoices (delta = Delta (content = "c" ))])
959+ yield ModelResponseStream (model = "claude" , choices = [StreamingChoices (delta = Delta (content = "o" ))])
960+ yield ModelResponseStream (model = "claude" , choices = [StreamingChoices (delta = Delta (content = "r" ))])
961+ yield ModelResponseStream (model = "claude" , choices = [StreamingChoices (delta = Delta (content = "d" ))])
962+ yield ModelResponseStream (model = "claude" , choices = [StreamingChoices (delta = Delta (content = "i" ))])
963+ yield ModelResponseStream (model = "claude" , choices = [StreamingChoices (delta = Delta (content = "n" ))])
964+ yield ModelResponseStream (model = "claude" , choices = [StreamingChoices (delta = Delta (content = "g" ))])
965+ yield ModelResponseStream (model = "claude" , choices = [StreamingChoices (delta = Delta (content = " to " ))])
966+ yield ModelResponseStream (model = "claude" , choices = [StreamingChoices (delta = Delta (content = "the references," ))])
962967 yield ModelResponseStream (model = "claude" , choices = [StreamingChoices (delta = Delta (
963968 content = "" ,
964969 provider_specific_fields = {
965970 "citation" : {
966971 "type" : "char_location" ,
967- "cited_text" : "Water boils at 100°C" ,
972+ "cited_text" : "water boils at 100°C" ,
968973 "document_index" : 0 ,
969974 "document_title" : "Physics Facts" ,
970975 "start_char_index" : 0 ,
971976 "end_char_index" : 19
972977 }
973978 }
974979 ))])
980+ yield ModelResponseStream (model = "claude" , choices = [StreamingChoices (delta = Delta (content = " water" ))])
981+ yield ModelResponseStream (model = "claude" , choices = [StreamingChoices (delta = Delta (content = " boils" ))])
982+ yield ModelResponseStream (model = "claude" , choices = [StreamingChoices (delta = Delta (content = " at" ))])
983+ yield ModelResponseStream (model = "claude" , choices = [StreamingChoices (delta = Delta (content = " 100°C" ))])
984+ yield ModelResponseStream (model = "claude" , choices = [StreamingChoices (delta = Delta (content = "." ))])
975985 yield ModelResponseStream (model = "claude" , choices = [StreamingChoices (delta = Delta (content = "\n \n " ))])
976986 yield ModelResponseStream (model = "claude" , choices = [StreamingChoices (delta = Delta (content = "[[ ##" ))])
977987 yield ModelResponseStream (model = "claude" , choices = [StreamingChoices (delta = Delta (content = " completed" ))])
@@ -982,6 +992,7 @@ async def citation_stream(*args, **kwargs):
982992 program = dspy .streamify (
983993 MyProgram (),
984994 stream_listeners = [
995+ dspy .streaming .StreamListener (signature_field_name = "answer" ),
985996 dspy .streaming .StreamListener (signature_field_name = "citations" ),
986997 ],
987998 )
@@ -992,10 +1003,13 @@ async def citation_stream(*args, **kwargs):
9921003 with dspy .context (lm = dspy .LM ("anthropic/claude-3-5-sonnet-20241022" , cache = False ), adapter = dspy .ChatAdapter (native_response_types = [Citations ])):
9931004 output = program (documents = docs , question = "What temperature does water boil?" )
9941005 citation_chunks = []
1006+ answer_chunks = []
9951007 final_prediction = None
9961008 async for value in output :
9971009 if isinstance (value , dspy .streaming .StreamResponse ) and value .signature_field_name == "citations" :
9981010 citation_chunks .append (value )
1011+ elif isinstance (value , dspy .streaming .StreamResponse ) and value .signature_field_name == "answer" :
1012+ answer_chunks .append (value .chunk )
9991013 elif isinstance (value , dspy .Prediction ):
10001014 final_prediction = value
10011015
@@ -1004,10 +1018,14 @@ async def citation_stream(*args, **kwargs):
10041018 citation_chunk = citation_chunks [0 ]
10051019 assert isinstance (citation_chunk .chunk , Citations )
10061020 assert len (citation_chunk .chunk ) == 1
1007- assert citation_chunk .chunk [0 ].cited_text == "Water boils at 100°C"
1021+ assert citation_chunk .chunk [0 ].cited_text == "water boils at 100°C"
10081022 assert citation_chunk .chunk [0 ].document_title == "Physics Facts"
10091023
1024+ # Verify the answer chunks are correct
1025+ assert "" .join (answer_chunks ) == "According to the references, water boils at 100°C."
1026+
10101027 # Test that prediction contains the expected fields
10111028 assert final_prediction is not None
10121029 assert hasattr (final_prediction , "answer" )
10131030 assert hasattr (final_prediction , "citations" )
1031+ assert final_prediction .answer == "According to the references, water boils at 100°C."
0 commit comments