Skip to content

Commit 80881b5

Browse files
fix: ensure streams are always closed
1 parent db08b3f commit 80881b5

File tree

1 file changed

+48
-46
lines changed

1 file changed

+48
-46
lines changed

src/gradient/_streaming.py

Lines changed: 48 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -55,29 +55,30 @@ def __stream__(self) -> Iterator[_T]:
5555
process_data = self._client._process_response_data
5656
iterator = self._iter_events()
5757

58-
for sse in iterator:
59-
if sse.data.startswith("[DONE]"):
60-
break
61-
62-
data = sse.json()
63-
if is_mapping(data) and data.get("error"):
64-
message = None
65-
error = data.get("error")
66-
if is_mapping(error):
67-
message = error.get("message")
68-
if not message or not isinstance(message, str):
69-
message = "An error occurred during streaming"
70-
71-
raise APIError(
72-
message=message,
73-
request=self.response.request,
74-
body=data["error"],
75-
)
76-
77-
yield process_data(data=data, cast_to=cast_to, response=response)
78-
79-
# As we might not fully consume the response stream, we need to close it explicitly
80-
response.close()
58+
try:
59+
for sse in iterator:
60+
if sse.data.startswith("[DONE]"):
61+
break
62+
63+
data = sse.json()
64+
if is_mapping(data) and data.get("error"):
65+
message = None
66+
error = data.get("error")
67+
if is_mapping(error):
68+
message = error.get("message")
69+
if not message or not isinstance(message, str):
70+
message = "An error occurred during streaming"
71+
72+
raise APIError(
73+
message=message,
74+
request=self.response.request,
75+
body=data["error"],
76+
)
77+
78+
yield process_data(data=data, cast_to=cast_to, response=response)
79+
finally:
80+
# Ensure the response is closed even if the consumer doesn't read all data
81+
response.close()
8182

8283
def __enter__(self) -> Self:
8384
return self
@@ -136,29 +137,30 @@ async def __stream__(self) -> AsyncIterator[_T]:
136137
process_data = self._client._process_response_data
137138
iterator = self._iter_events()
138139

139-
async for sse in iterator:
140-
if sse.data.startswith("[DONE]"):
141-
break
142-
143-
data = sse.json()
144-
if is_mapping(data) and data.get("error"):
145-
message = None
146-
error = data.get("error")
147-
if is_mapping(error):
148-
message = error.get("message")
149-
if not message or not isinstance(message, str):
150-
message = "An error occurred during streaming"
151-
152-
raise APIError(
153-
message=message,
154-
request=self.response.request,
155-
body=data["error"],
156-
)
157-
158-
yield process_data(data=data, cast_to=cast_to, response=response)
159-
160-
# As we might not fully consume the response stream, we need to close it explicitly
161-
await response.aclose()
140+
try:
141+
async for sse in iterator:
142+
if sse.data.startswith("[DONE]"):
143+
break
144+
145+
data = sse.json()
146+
if is_mapping(data) and data.get("error"):
147+
message = None
148+
error = data.get("error")
149+
if is_mapping(error):
150+
message = error.get("message")
151+
if not message or not isinstance(message, str):
152+
message = "An error occurred during streaming"
153+
154+
raise APIError(
155+
message=message,
156+
request=self.response.request,
157+
body=data["error"],
158+
)
159+
160+
yield process_data(data=data, cast_to=cast_to, response=response)
161+
finally:
162+
# Ensure the response is closed even if the consumer doesn't read all data
163+
await response.aclose()
162164

163165
async def __aenter__(self) -> Self:
164166
return self

0 commit comments

Comments
 (0)