Skip to content

Commit 222f1ae

Browse files
committed
more careful handling of CRT native resources in callbacks
1 parent 1b80a50 commit 222f1ae

File tree

1 file changed

+58
-41
lines changed

1 file changed

+58
-41
lines changed

aws-crt-kotlin/native/src/aws/sdk/kotlin/crt/http/HttpClientConnectionNative.kt

Lines changed: 58 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -83,55 +83,56 @@ private class HttpStreamContext(
8383
val nativeReq: CPointer<cnames.structs.aws_http_message>,
8484
)
8585

86+
private fun callbackError(): Int = aws_raise_error(AWS_ERROR_HTTP_CALLBACK_FAILURE.toInt())
87+
8688
private fun onResponseHeaders(
8789
nativeStream: CPointer<cnames.structs.aws_http_stream>?,
8890
blockType: aws_http_header_block,
8991
headerArray: CPointer<aws_http_header>?,
9092
numHeaders: size_t,
9193
userdata: COpaquePointer?,
9294
): Int {
93-
val ctxStableRef = userdata?.asStableRef<HttpStreamContext>() ?: return aws_raise_error(AWS_ERROR_HTTP_CALLBACK_FAILURE.toInt())
94-
ctxStableRef.use {
95-
val ctx = it.get()
96-
val stream = ctx.stream ?: return AWS_OP_ERR
97-
98-
val hdrCnt = numHeaders.toInt()
99-
val headers: List<HttpHeader>? = if (hdrCnt > 0 && headerArray != null) {
100-
val kheaders = mutableListOf<HttpHeader>()
101-
for (i in 0 until hdrCnt) {
102-
val nativeHdr = headerArray[i]
103-
val hdr = HttpHeader(nativeHdr.name.toKString(), nativeHdr.value.toKString())
104-
kheaders.add(hdr)
105-
}
106-
kheaders
107-
} else {
108-
null
109-
}
110-
111-
try {
112-
ctx.handler.onResponseHeaders(stream, stream.responseStatusCode, blockType.value.toInt(), headers)
113-
} catch (ex: Exception) {
114-
log(LogLevel.Error, "onResponseHeaders: $ex")
115-
return aws_raise_error(AWS_ERROR_HTTP_CALLBACK_FAILURE.toInt())
95+
val stableRef = dereferenceUserdata(userdata) ?: return callbackError()
96+
val ctx = stableRef.safeGet() ?: return callbackError()
97+
val stream = ctx.stream ?: return callbackError()
98+
99+
val hdrCnt = numHeaders.toInt()
100+
val headers: List<HttpHeader>? = if (hdrCnt > 0 && headerArray != null) {
101+
val kheaders = mutableListOf<HttpHeader>()
102+
for (i in 0 until hdrCnt) {
103+
val nativeHdr = headerArray[i]
104+
val hdr = HttpHeader(nativeHdr.name.toKString(), nativeHdr.value.toKString())
105+
kheaders.add(hdr)
116106
}
107+
kheaders
108+
} else {
109+
null
110+
}
117111

118-
return AWS_OP_SUCCESS
112+
try {
113+
ctx.handler.onResponseHeaders(stream, stream.responseStatusCode, blockType.value.toInt(), headers)
114+
} catch (ex: Exception) {
115+
log(LogLevel.Error, "onResponseHeaders: $ex")
116+
return callbackError()
119117
}
118+
119+
return AWS_OP_SUCCESS
120120
}
121121

122122
private fun onResponseHeaderBlockDone(
123123
nativeStream: CPointer<cnames.structs.aws_http_stream>?,
124124
blockType: aws_http_header_block,
125125
userdata: COpaquePointer?,
126126
): Int {
127-
val ctx = userdata?.asStableRef<HttpStreamContext>()?.get() ?: return AWS_OP_ERR
128-
val stream = ctx.stream ?: return AWS_OP_ERR
127+
val stableRef = dereferenceUserdata(userdata) ?: return callbackError()
128+
val ctx = stableRef.safeGet() ?: return callbackError()
129+
val stream = ctx.stream ?: return callbackError()
129130

130131
try {
131132
ctx.handler.onResponseHeadersDone(stream, blockType.value.toInt())
132133
} catch (ex: Exception) {
133134
log(LogLevel.Error, "onResponseHeaderBlockDone: $ex")
134-
return aws_raise_error(AWS_ERROR_HTTP_CALLBACK_FAILURE.toInt())
135+
return callbackError()
135136
}
136137

137138
return AWS_OP_SUCCESS
@@ -142,25 +143,23 @@ private fun onIncomingBody(
142143
data: CPointer<aws_byte_cursor>?,
143144
userdata: COpaquePointer?,
144145
): Int {
145-
val stableRef = try { userdata?.asStableRef<HttpStreamContext>() } catch (_: NullPointerException) { return AWS_OP_ERR }
146-
val ctx = try { stableRef?.get() } catch (_: NullPointerException) { return AWS_OP_ERR }
147-
if (ctx == null) return AWS_OP_ERR
148-
val stream = ctx.stream
149-
if (stream == null) return AWS_OP_ERR
146+
val stableRef = dereferenceUserdata(userdata) ?: return callbackError()
147+
val ctx = stableRef.safeGet() ?: return callbackError()
148+
val stream = ctx.stream ?: return callbackError()
150149

151150
try {
152151
val body = if (data != null) ByteCursorBuffer(data) else Buffer.Empty
153152
val windowIncrement = ctx.handler.onResponseBody(stream, body)
154153
if (windowIncrement < 0) {
155-
return aws_raise_error(AWS_ERROR_HTTP_CALLBACK_FAILURE.toInt())
154+
return callbackError()
156155
}
157156

158157
if (windowIncrement > 0) {
159158
aws_http_stream_update_window(nativeStream, windowIncrement.convert())
160159
}
161160
} catch (ex: Exception) {
162161
log(LogLevel.Error, "onIncomingBody: $ex")
163-
return aws_raise_error(AWS_ERROR_HTTP_CALLBACK_FAILURE.toInt())
162+
return callbackError()
164163
}
165164

166165
return AWS_OP_SUCCESS
@@ -171,23 +170,41 @@ private fun onStreamComplete(
171170
errorCode: Int,
172171
userdata: COpaquePointer?,
173172
) {
174-
val stableRef = userdata?.asStableRef<HttpStreamContext>() ?: return
175-
val ctx = stableRef.get()
176-
val stream = ctx.stream ?: return
177-
173+
val stableRef = dereferenceUserdata(userdata) ?: return
178174
try {
179-
ctx.handler.onResponseComplete(stream, errorCode)
175+
val ctx = stableRef.safeGet() ?: return
176+
try {
177+
val stream = ctx.stream ?: return
178+
ctx.handler.onResponseComplete(stream, errorCode)
179+
} finally {
180+
// cleanup request object
181+
aws_http_message_release(ctx.nativeReq)
182+
}
180183
} catch (ex: Exception) {
181184
log(LogLevel.Error, "onStreamComplete: $ex")
182185
// close connection if callback throws an exception
183186
aws_http_connection_close(aws_http_stream_get_connection(nativeStream))
184187
} finally {
185-
// cleanup stream resources
188+
// cleanup userdata
186189
stableRef.dispose()
187-
aws_http_message_destroy(ctx.nativeReq)
188190
}
189191
}
190192

193+
private fun dereferenceUserdata(userdata: COpaquePointer?): StableRef<HttpStreamContext>? =
194+
try {
195+
userdata?.asStableRef<HttpStreamContext>()
196+
} catch (_: NullPointerException) {
197+
null
198+
}
199+
200+
private fun <T : Any> StableRef<T>.safeGet(): T? =
201+
try {
202+
get()
203+
} catch (_: NullPointerException) {
204+
// `get()` can throw `NullPointerException` when stream has been canceled and CRT is cleaning up resources
205+
null
206+
}
207+
191208
internal fun HttpRequest.toNativeRequest(): CPointer<cnames.structs.aws_http_message> {
192209
val nativeReq = checkNotNull(
193210
aws_http_message_new_request(Allocator.Default),

0 commit comments

Comments
 (0)