@@ -52,7 +52,7 @@ class EventType(Enum):
52
52
retry : Optional [int ]
53
53
54
54
def __str__ (self ) -> str :
55
- if self .event == "output" :
55
+ if self .event == ServerSentEvent . EventType . OUTPUT :
56
56
return self .data
57
57
58
58
return ""
@@ -114,7 +114,7 @@ def decode(self, line: str) -> Optional[ServerSentEvent]:
114
114
return None
115
115
116
116
fieldname , _ , value = line .partition (":" )
117
- value = value .lstrip ( )
117
+ value = value .removeprefix ( " " )
118
118
119
119
if fieldname == "event" :
120
120
if event := ServerSentEvent .EventType (value ):
@@ -138,26 +138,28 @@ def __iter__(self) -> Iterator[ServerSentEvent]:
138
138
line = line .rstrip ("\n " )
139
139
sse = decoder .decode (line )
140
140
if sse is not None :
141
- if sse .event == "done" :
142
- return
143
- if sse .event == "error" :
141
+ if sse .event == ServerSentEvent .EventType .ERROR :
144
142
raise RuntimeError (sse .data )
145
143
146
144
yield sse
147
145
146
+ if sse .event == ServerSentEvent .EventType .DONE :
147
+ return
148
+
148
149
async def __aiter__ (self ) -> AsyncIterator [ServerSentEvent ]:
149
150
decoder = EventSource .Decoder ()
150
151
async for line in self .response .aiter_lines ():
151
152
line = line .rstrip ("\n " )
152
153
sse = decoder .decode (line )
153
154
if sse is not None :
154
- if sse .event == "done" :
155
- return
156
- if sse .event == "error" :
155
+ if sse .event == ServerSentEvent .EventType .ERROR :
157
156
raise RuntimeError (sse .data )
158
157
159
158
yield sse
160
159
160
+ if sse .event == ServerSentEvent .EventType .DONE :
161
+ return
162
+
161
163
162
164
def stream (
163
165
client : "Client" ,
0 commit comments