Skip to content

Commit 7f49900

Browse files
authored
Fix stream functionality (#214)
Resolves #213 This PR fixes a few bugs in the streaming implementation added with #204. --------- Signed-off-by: Mattt Zmuda <[email protected]>
1 parent b478f16 commit 7f49900

File tree

2 files changed

+11
-13
lines changed

2 files changed

+11
-13
lines changed

README.md

+1-5
Original file line numberDiff line numberDiff line change
@@ -90,17 +90,13 @@ import replicate
9090
# https://replicate.com/meta/llama-2-70b-chat
9191
model_version = "meta/llama-2-70b-chat:02e509c789964a7ea8736978a43525956ef40397be9033abf9fd2badfe68c9e3"
9292
93-
tokens = []
9493
for event in replicate.stream(
9594
model_version,
9695
input={
9796
"prompt": "Please write a haiku about llamas.",
9897
},
9998
):
100-
print(event)
101-
tokens.append(str(event))
102-
103-
print("".join(tokens))
99+
print(str(event), end="")
104100
```
105101
106102
For more information, see

replicate/stream.py

+10-8
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ class EventType(Enum):
5252
retry: Optional[int]
5353

5454
def __str__(self) -> str:
55-
if self.event == "output":
55+
if self.event == ServerSentEvent.EventType.OUTPUT:
5656
return self.data
5757

5858
return ""
@@ -114,7 +114,7 @@ def decode(self, line: str) -> Optional[ServerSentEvent]:
114114
return None
115115

116116
fieldname, _, value = line.partition(":")
117-
value = value.lstrip()
117+
value = value.removeprefix(" ")
118118

119119
if fieldname == "event":
120120
if event := ServerSentEvent.EventType(value):
@@ -138,26 +138,28 @@ def __iter__(self) -> Iterator[ServerSentEvent]:
138138
line = line.rstrip("\n")
139139
sse = decoder.decode(line)
140140
if sse is not None:
141-
if sse.event == "done":
142-
return
143-
if sse.event == "error":
141+
if sse.event == ServerSentEvent.EventType.ERROR:
144142
raise RuntimeError(sse.data)
145143

146144
yield sse
147145

146+
if sse.event == ServerSentEvent.EventType.DONE:
147+
return
148+
148149
async def __aiter__(self) -> AsyncIterator[ServerSentEvent]:
149150
decoder = EventSource.Decoder()
150151
async for line in self.response.aiter_lines():
151152
line = line.rstrip("\n")
152153
sse = decoder.decode(line)
153154
if sse is not None:
154-
if sse.event == "done":
155-
return
156-
if sse.event == "error":
155+
if sse.event == ServerSentEvent.EventType.ERROR:
157156
raise RuntimeError(sse.data)
158157

159158
yield sse
160159

160+
if sse.event == ServerSentEvent.EventType.DONE:
161+
return
162+
161163

162164
def stream(
163165
client: "Client",

0 commit comments

Comments
 (0)