diff --git a/aws-crt-kotlin/native/src/aws/sdk/kotlin/crt/auth/signing/AwsSignerNative.kt b/aws-crt-kotlin/native/src/aws/sdk/kotlin/crt/auth/signing/AwsSignerNative.kt index e8746653..8af98232 100644 --- a/aws-crt-kotlin/native/src/aws/sdk/kotlin/crt/auth/signing/AwsSignerNative.kt +++ b/aws-crt-kotlin/native/src/aws/sdk/kotlin/crt/auth/signing/AwsSignerNative.kt @@ -8,11 +8,7 @@ package aws.sdk.kotlin.crt.auth.signing import aws.sdk.kotlin.crt.* import aws.sdk.kotlin.crt.auth.credentials.Credentials import aws.sdk.kotlin.crt.http.* -import aws.sdk.kotlin.crt.util.asAwsByteCursor -import aws.sdk.kotlin.crt.util.initFromCursor -import aws.sdk.kotlin.crt.util.toAwsString -import aws.sdk.kotlin.crt.util.toKString -import aws.sdk.kotlin.crt.util.use +import aws.sdk.kotlin.crt.util.* import kotlinx.cinterop.* import kotlinx.coroutines.channels.Channel import kotlinx.coroutines.runBlocking @@ -223,15 +219,10 @@ private fun AwsSigningConfig.toNativeSigningConfig(): CPointer Boolean private fun nativeShouldSignHeaderFn(headerName: CPointer?, userData: COpaquePointer?): Boolean { checkNotNull(headerName) { "aws_should_sign_header_fn expected non-null header name" } - if (userData == null) { - return true - } - - userData.asStableRef().use { - val kShouldSignHeaderFn = it.get() + return userData?.withDereferenced(dispose = true) { kShouldSignHeaderFn -> val kHeaderName = headerName.pointed.toKString() - return kShouldSignHeaderFn(kHeaderName) - } + kShouldSignHeaderFn(kHeaderName) + } ?: error("Expected non-null userData") } /** @@ -243,17 +234,17 @@ private fun signCallback(signingResult: CPointer?, errorCode checkNotNull(signingResult) { "signing callback received null aws_signing_result" } checkNotNull(userData) { "signing callback received null user data" } - val (pinnedRequestToSign, callbackChannel) = userData - .asStableRef>, Channel>>() - .get() + userData.withDereferenced>, Channel>> { pair -> + val (pinnedRequestToSign, callbackChannel) = pair - val requestToSign = pinnedRequestToSign.get() + val requestToSign = pinnedRequestToSign.get() - awsAssertOpSuccess(aws_apply_signing_result_to_http_request(requestToSign, Allocator.Default.allocator, signingResult)) { - "aws_apply_signing_result_to_http_request" - } + awsAssertOpSuccess(aws_apply_signing_result_to_http_request(requestToSign, Allocator.Default.allocator, signingResult)) { + "aws_apply_signing_result_to_http_request" + } - runBlocking { callbackChannel.send(signingResult.getSignature()) } + runBlocking { callbackChannel.send(signingResult.getSignature()) } + } } /** @@ -264,8 +255,9 @@ private fun signChunkCallback(signingResult: CPointer?, erro checkNotNull(signingResult) { "signing callback received null aws_signing_result" } checkNotNull(userData) { "signing callback received null user data" } - val callbackChannel = userData.asStableRef>().get() - runBlocking { callbackChannel.send(signingResult.getSignature()) } + userData.withDereferenced> { callbackChannel -> + runBlocking { callbackChannel.send(signingResult.getSignature()) } + } } private fun Credentials.toNativeCredentials(): CPointer? = aws_credentials_new_from_string( diff --git a/aws-crt-kotlin/native/src/aws/sdk/kotlin/crt/http/HttpClientConnectionManagerNative.kt b/aws-crt-kotlin/native/src/aws/sdk/kotlin/crt/http/HttpClientConnectionManagerNative.kt index 0c7b452b..28e82786 100644 --- a/aws-crt-kotlin/native/src/aws/sdk/kotlin/crt/http/HttpClientConnectionManagerNative.kt +++ b/aws-crt-kotlin/native/src/aws/sdk/kotlin/crt/http/HttpClientConnectionManagerNative.kt @@ -183,13 +183,10 @@ private fun SocketDomain.toNativeSocketDomain() = when (this) { } private fun onShutdownComplete(userdata: COpaquePointer?) { - if (userdata == null) return - val notify = userdata.asStableRef() - with(notify.get()) { - trySend(Unit) - close() + userdata?.withDereferenced(dispose = true) { notify -> + notify.trySend(Unit) + notify.close() } - notify.dispose() } private data class HttpConnectionAcquisitionRequest( @@ -202,20 +199,16 @@ private fun onConnectionAcquired( errCode: Int, userdata: COpaquePointer?, ) { - if (userdata == null) return - val stableRef = userdata.asStableRef() - val request = stableRef.get() - - when { - errCode != AWS_OP_SUCCESS -> request.cont.resumeWithException(HttpException(errCode)) - conn == null -> request.cont.resumeWithException( - CrtRuntimeException("acquireConnection(): http connection null", ec = errCode), - ) - else -> { - val kconn = HttpClientConnectionNative(request.manager, conn) - request.cont.resume(kconn) + userdata?.withDereferenced(dispose = true) { request -> + when { + errCode != AWS_OP_SUCCESS -> request.cont.resumeWithException(HttpException(errCode)) + conn == null -> request.cont.resumeWithException( + CrtRuntimeException("acquireConnection(): http connection null", ec = errCode), + ) + else -> { + val kconn = HttpClientConnectionNative(request.manager, conn) + request.cont.resume(kconn) + } } } - - stableRef.dispose() } diff --git a/aws-crt-kotlin/native/src/aws/sdk/kotlin/crt/http/HttpClientConnectionNative.kt b/aws-crt-kotlin/native/src/aws/sdk/kotlin/crt/http/HttpClientConnectionNative.kt index 8fa3298f..7db5ee0f 100644 --- a/aws-crt-kotlin/native/src/aws/sdk/kotlin/crt/http/HttpClientConnectionNative.kt +++ b/aws-crt-kotlin/native/src/aws/sdk/kotlin/crt/http/HttpClientConnectionNative.kt @@ -7,11 +7,7 @@ package aws.sdk.kotlin.crt.http import aws.sdk.kotlin.crt.* import aws.sdk.kotlin.crt.io.Buffer import aws.sdk.kotlin.crt.io.ByteCursorBuffer -import aws.sdk.kotlin.crt.util.asAwsByteCursor -import aws.sdk.kotlin.crt.util.initFromCursor -import aws.sdk.kotlin.crt.util.toKString -import aws.sdk.kotlin.crt.util.use -import aws.sdk.kotlin.crt.util.withAwsByteCursor +import aws.sdk.kotlin.crt.util.* import kotlinx.atomicfu.atomic import kotlinx.cinterop.* import libcrt.* @@ -87,105 +83,100 @@ private class HttpStreamContext( val nativeReq: CPointer, ) +private fun callbackError(): Int = aws_raise_error(AWS_ERROR_HTTP_CALLBACK_FAILURE.toInt()) + private fun onResponseHeaders( nativeStream: CPointer?, blockType: aws_http_header_block, headerArray: CPointer?, numHeaders: size_t, userdata: COpaquePointer?, -): Int { - val ctxStableRef = userdata?.asStableRef() ?: return aws_raise_error(AWS_ERROR_HTTP_CALLBACK_FAILURE.toInt()) - ctxStableRef.use { - val ctx = it.get() - val stream = ctx.stream ?: return AWS_OP_ERR - - val hdrCnt = numHeaders.toInt() - val headers: List? = if (hdrCnt > 0 && headerArray != null) { - val kheaders = mutableListOf() - for (i in 0 until hdrCnt) { - val nativeHdr = headerArray[i] - val hdr = HttpHeader(nativeHdr.name.toKString(), nativeHdr.value.toKString()) - kheaders.add(hdr) +): Int = + userdata?.withDereferenced { ctx -> + ctx.stream?.let { stream -> + val hdrCnt = numHeaders.toInt() + val headers: List? = if (hdrCnt > 0 && headerArray != null) { + val kheaders = mutableListOf() + for (i in 0 until hdrCnt) { + val nativeHdr = headerArray[i] + val hdr = HttpHeader(nativeHdr.name.toKString(), nativeHdr.value.toKString()) + kheaders.add(hdr) + } + kheaders + } else { + null } - kheaders - } else { - null - } - try { - ctx.handler.onResponseHeaders(stream, stream.responseStatusCode, blockType.value.toInt(), headers) - } catch (ex: Exception) { - log(LogLevel.Error, "onResponseHeaders: $ex") - return aws_raise_error(AWS_ERROR_HTTP_CALLBACK_FAILURE.toInt()) + try { + ctx.handler.onResponseHeaders(stream, stream.responseStatusCode, blockType.value.toInt(), headers) + AWS_OP_SUCCESS + } catch (ex: Exception) { + log(LogLevel.Error, "onResponseHeaders: $ex") + null + } } - - return AWS_OP_SUCCESS - } -} + } ?: callbackError() private fun onResponseHeaderBlockDone( nativeStream: CPointer?, blockType: aws_http_header_block, userdata: COpaquePointer?, -): Int { - val ctx = userdata?.asStableRef()?.get() ?: return AWS_OP_ERR - val stream = ctx.stream ?: return AWS_OP_ERR - - try { - ctx.handler.onResponseHeadersDone(stream, blockType.value.toInt()) - } catch (ex: Exception) { - log(LogLevel.Error, "onResponseHeaderBlockDone: $ex") - return aws_raise_error(AWS_ERROR_HTTP_CALLBACK_FAILURE.toInt()) - } - - return AWS_OP_SUCCESS -} +): Int = + userdata?.withDereferenced { ctx -> + ctx.stream?.let { stream -> + try { + ctx.handler.onResponseHeadersDone(stream, blockType.value.toInt()) + AWS_OP_SUCCESS + } catch (ex: Exception) { + log(LogLevel.Error, "onResponseHeaderBlockDone: $ex") + null + } + } + } ?: callbackError() private fun onIncomingBody( nativeStream: CPointer?, data: CPointer?, userdata: COpaquePointer?, -): Int { - val ctx = userdata?.asStableRef()?.get() ?: return AWS_OP_ERR - val stream = ctx.stream ?: return AWS_OP_ERR - - try { - val body = if (data != null) ByteCursorBuffer(data) else Buffer.Empty - val windowIncrement = ctx.handler.onResponseBody(stream, body) - if (windowIncrement < 0) { - return aws_raise_error(AWS_ERROR_HTTP_CALLBACK_FAILURE.toInt()) - } - - if (windowIncrement > 0) { - aws_http_stream_update_window(nativeStream, windowIncrement.convert()) +): Int = + userdata?.withDereferenced { ctx -> + ctx.stream?.let { stream -> + try { + val body = if (data != null) ByteCursorBuffer(data) else Buffer.Empty + val windowIncrement = ctx.handler.onResponseBody(stream, body) + + if (windowIncrement < 0) { + null + } else { + if (windowIncrement > 0) { + aws_http_stream_update_window(nativeStream, windowIncrement.convert()) + } + AWS_OP_SUCCESS + } + } catch (ex: Exception) { + log(LogLevel.Error, "onIncomingBody: $ex") + null + } } - } catch (ex: Exception) { - log(LogLevel.Error, "onIncomingBody: $ex") - return aws_raise_error(AWS_ERROR_HTTP_CALLBACK_FAILURE.toInt()) - } - - return AWS_OP_SUCCESS -} + } ?: callbackError() private fun onStreamComplete( nativeStream: CPointer?, errorCode: Int, userdata: COpaquePointer?, ) { - val stableRef = userdata?.asStableRef() ?: return - val ctx = stableRef.get() - val stream = ctx.stream ?: return - - try { - ctx.handler.onResponseComplete(stream, errorCode) - } catch (ex: Exception) { - log(LogLevel.Error, "onStreamComplete: $ex") - // close connection if callback throws an exception - aws_http_connection_close(aws_http_stream_get_connection(nativeStream)) - } finally { - // cleanup stream resources - stableRef.dispose() - aws_http_message_destroy(ctx.nativeReq) + userdata?.withDereferenced(dispose = true) { ctx -> + try { + val stream = ctx.stream ?: return + ctx.handler.onResponseComplete(stream, errorCode) + } catch (ex: Exception) { + log(LogLevel.Error, "onStreamComplete: $ex") + // close connection if callback throws an exception + aws_http_connection_close(aws_http_stream_get_connection(nativeStream)) + } finally { + // cleanup request object + aws_http_message_release(ctx.nativeReq) + } } } diff --git a/aws-crt-kotlin/native/src/aws/sdk/kotlin/crt/http/HttpStreamNative.kt b/aws-crt-kotlin/native/src/aws/sdk/kotlin/crt/http/HttpStreamNative.kt index 38814421..f3012838 100644 --- a/aws-crt-kotlin/native/src/aws/sdk/kotlin/crt/http/HttpStreamNative.kt +++ b/aws-crt-kotlin/native/src/aws/sdk/kotlin/crt/http/HttpStreamNative.kt @@ -10,6 +10,7 @@ import aws.sdk.kotlin.crt.NativeHandle import aws.sdk.kotlin.crt.awsAssertOpSuccess import aws.sdk.kotlin.crt.util.asAwsByteCursor import aws.sdk.kotlin.crt.util.use +import aws.sdk.kotlin.crt.util.withDereferenced import kotlinx.atomicfu.atomic import kotlinx.cinterop.* import libcrt.* @@ -67,12 +68,13 @@ internal class HttpStreamNative( throw CrtRuntimeException("aws_input_stream_new_from_cursor()") } - StableRef.create(WriteChunkRequest(cont, byteBuf, stream)).use { req -> + val req = WriteChunkRequest(cont, byteBuf, stream) + StableRef.create(req).use { stableRef -> val chunkOpts = cValue { chunk_data_size = chunkData.size.convert() chunk_data = stream on_complete = staticCFunction(::onWriteChunkComplete) - user_data = req.asCPointer() + user_data = stableRef.asCPointer() } awsAssertOpSuccess( aws_http1_stream_write_chunk(ptr, chunkOpts), @@ -113,19 +115,18 @@ private fun onWriteChunkComplete( userData: COpaquePointer?, ) { if (stream == null) return - val stableRef = userData?.asStableRef() ?: return - val req = stableRef.get() - when { - errCode != AWS_OP_SUCCESS -> req.cont.resumeWithException(HttpException(errCode)) - else -> req.cont.resume(Unit) + userData?.withDereferenced { req -> + checkNotNull(req) { "Received null request in onWriteChunkComplete" } + when { + errCode != AWS_OP_SUCCESS -> req.cont.resumeWithException(HttpException(errCode)) + else -> req.cont.resume(Unit) + } + cleanupWriteChunkCbData(req) } - cleanupWriteChunkCbData(stableRef) } -private fun cleanupWriteChunkCbData(stableRef: StableRef) { - val req = stableRef.get() +private fun cleanupWriteChunkCbData(req: WriteChunkRequest) { aws_input_stream_destroy(req.inputStream) aws_byte_buf_clean_up(req.chunkData) Allocator.Default.free(req.inputStream) - stableRef.dispose() } diff --git a/aws-crt-kotlin/native/src/aws/sdk/kotlin/crt/http/RequestBodyStream.kt b/aws-crt-kotlin/native/src/aws/sdk/kotlin/crt/http/RequestBodyStream.kt index 1dee8ca4..063a6843 100644 --- a/aws-crt-kotlin/native/src/aws/sdk/kotlin/crt/http/RequestBodyStream.kt +++ b/aws-crt-kotlin/native/src/aws/sdk/kotlin/crt/http/RequestBodyStream.kt @@ -8,6 +8,7 @@ import aws.sdk.kotlin.crt.Allocator import aws.sdk.kotlin.crt.LogLevel import aws.sdk.kotlin.crt.io.MutableBuffer import aws.sdk.kotlin.crt.log +import aws.sdk.kotlin.crt.util.withDereferenced import kotlinx.cinterop.* import libcrt.* @@ -17,22 +18,24 @@ private fun streamSeek( basis: aws_stream_seek_basis, ): Int { if (stream == null || basis != AWS_SSB_BEGIN || offset != 0L) return AWS_OP_ERR - val handler = stream.pointed.impl?.asStableRef()?.get() ?: return AWS_OP_ERR - var result = AWS_OP_SUCCESS - try { - if (!handler.resetPosition()) { - result = AWS_OP_ERR + return stream.pointed.impl?.withDereferenced { handler -> + var result = AWS_OP_SUCCESS + + try { + if (!handler.resetPosition()) { + result = AWS_OP_ERR + } + } catch (ex: Exception) { + log(LogLevel.Error, "streamSeek: $ex") + return aws_raise_error(AWS_ERROR_HTTP_CALLBACK_FAILURE.toInt()) } - } catch (ex: Exception) { - log(LogLevel.Error, "streamSeek: $ex") - return aws_raise_error(AWS_ERROR_HTTP_CALLBACK_FAILURE.toInt()) - } - if (result == AWS_OP_SUCCESS) { - handler.bodyDone = false - } - return result + if (result == AWS_OP_SUCCESS) { + handler.bodyDone = false + } + result + } ?: AWS_OP_ERR } private fun streamRead( @@ -40,21 +43,23 @@ private fun streamRead( dest: CPointer?, ): Int { if (stream == null || dest == null) return AWS_OP_ERR - val handler = stream.pointed.impl?.asStableRef()?.get() ?: return AWS_OP_ERR - if (handler.bodyDone) return AWS_OP_SUCCESS - - try { - // MutableBuffer handles updating dest->len - val buffer = MutableBuffer(dest) - if (handler.khandler.sendRequestBody(buffer)) { - handler.bodyDone = true + return stream.pointed.impl?.withDereferenced { handler -> + if (handler.bodyDone) { + AWS_OP_SUCCESS + } else { + try { + // MutableBuffer handles updating dest->len + val buffer = MutableBuffer(dest) + if (handler.khandler.sendRequestBody(buffer)) { + handler.bodyDone = true + } + AWS_OP_SUCCESS + } catch (ex: Exception) { + log(LogLevel.Error, "streamRead: $ex") + aws_raise_error(AWS_ERROR_HTTP_CALLBACK_FAILURE.toInt()) + } } - } catch (ex: Exception) { - log(LogLevel.Error, "streamRead: $ex") - return aws_raise_error(AWS_ERROR_HTTP_CALLBACK_FAILURE.toInt()) - } - - return AWS_OP_SUCCESS + } ?: AWS_OP_ERR } private fun streamGetStatus( @@ -62,10 +67,11 @@ private fun streamGetStatus( status: CPointer?, ): Int { if (stream == null || status == null) return AWS_OP_ERR - val handler = stream.pointed.impl?.asStableRef()?.get() ?: return AWS_OP_ERR - status.pointed.is_end_of_stream = handler.bodyDone - status.pointed.is_valid = true - return AWS_OP_SUCCESS + return stream.pointed.impl?.withDereferenced { handler -> + status.pointed.is_end_of_stream = handler.bodyDone + status.pointed.is_valid = true + AWS_OP_SUCCESS + } ?: AWS_OP_ERR } @Suppress("unused") @@ -87,9 +93,9 @@ private fun streamRelease( if (stream == null) return val refCnt = aws_ref_count_release(stream.pointed.ref_count.ptr) if (refCnt.toInt() == 0) { - log(LogLevel.Trace, "releasing RequestBodyStream") - val stableRef = stream.pointed.impl?.asStableRef() ?: return - stableRef.dispose() + stream.pointed.impl?.withDereferenced(dispose = true) { _ -> + log(LogLevel.Trace, "releasing RequestBodyStream") + } Allocator.Default.free(stream) } } @@ -117,10 +123,10 @@ internal fun inputStream(khandler: HttpRequestBodyStream): CPointer.toHttpRequestBodyStream(): HttpRequestBodyStream { - val stableRef = checkNotNull(this.pointed.impl?.asStableRef()) { "toHttpRequestBodyStream() expected non-null `impl`" } - return stableRef.get().khandler -} +internal fun CPointer.toHttpRequestBodyStream(): HttpRequestBodyStream = + pointed.impl?.withDereferenced { handler -> + handler.khandler + } ?: error("toHttpRequestBodyStream() expected non-null `impl`") // wrapper around the actual implementation private class RequestBodyStream( diff --git a/aws-crt-kotlin/native/src/aws/sdk/kotlin/crt/io/ClientBootstrapNative.kt b/aws-crt-kotlin/native/src/aws/sdk/kotlin/crt/io/ClientBootstrapNative.kt index d291aeb0..28b9efb4 100644 --- a/aws-crt-kotlin/native/src/aws/sdk/kotlin/crt/io/ClientBootstrapNative.kt +++ b/aws-crt-kotlin/native/src/aws/sdk/kotlin/crt/io/ClientBootstrapNative.kt @@ -8,6 +8,7 @@ package aws.sdk.kotlin.crt.io import aws.sdk.kotlin.crt.* import aws.sdk.kotlin.crt.util.ShutdownChannel import aws.sdk.kotlin.crt.util.shutdownChannel +import aws.sdk.kotlin.crt.util.withDereferenced import kotlinx.cinterop.* import libcrt.aws_client_bootstrap import libcrt.aws_client_bootstrap_new @@ -61,10 +62,8 @@ public actual class ClientBootstrap private constructor( @OptIn(ExperimentalForeignApi::class) private fun onShutdownComplete(userData: COpaquePointer?) { - if (userData == null) return - val stableRef = userData.asStableRef() - val ch = stableRef.get() - ch.trySend(Unit) - ch.close() - stableRef.dispose() + userData?.withDereferenced { ch -> + ch.trySend(Unit) + ch.close() + } } diff --git a/aws-crt-kotlin/native/src/aws/sdk/kotlin/crt/io/EventLoopGroupNative.kt b/aws-crt-kotlin/native/src/aws/sdk/kotlin/crt/io/EventLoopGroupNative.kt index 3f5d1c4f..9ca9bc9c 100644 --- a/aws-crt-kotlin/native/src/aws/sdk/kotlin/crt/io/EventLoopGroupNative.kt +++ b/aws-crt-kotlin/native/src/aws/sdk/kotlin/crt/io/EventLoopGroupNative.kt @@ -8,6 +8,7 @@ package aws.sdk.kotlin.crt.io import aws.sdk.kotlin.crt.* import aws.sdk.kotlin.crt.util.ShutdownChannel import aws.sdk.kotlin.crt.util.shutdownChannel +import aws.sdk.kotlin.crt.util.withDereferenced import cnames.structs.aws_event_loop_group import kotlinx.cinterop.* import libcrt.aws_event_loop_group_new @@ -64,10 +65,8 @@ public actual class EventLoopGroup actual constructor(maxThreads: Int) : @OptIn(ExperimentalForeignApi::class) private fun onShutdownComplete(userData: COpaquePointer?) { - if (userData == null) return - val stableRef = userData.asStableRef() - val ch = stableRef.get() - ch.trySend(Unit) - ch.close() - stableRef.dispose() + userData?.withDereferenced(dispose = true) { ch -> + ch?.trySend(Unit) + ch?.close() + } } diff --git a/aws-crt-kotlin/native/src/aws/sdk/kotlin/crt/io/HostResolverNative.kt b/aws-crt-kotlin/native/src/aws/sdk/kotlin/crt/io/HostResolverNative.kt index 7a343260..26718c8e 100644 --- a/aws-crt-kotlin/native/src/aws/sdk/kotlin/crt/io/HostResolverNative.kt +++ b/aws-crt-kotlin/native/src/aws/sdk/kotlin/crt/io/HostResolverNative.kt @@ -6,10 +6,7 @@ package aws.sdk.kotlin.crt.io import aws.sdk.kotlin.crt.* -import aws.sdk.kotlin.crt.util.ShutdownChannel -import aws.sdk.kotlin.crt.util.shutdownChannel -import aws.sdk.kotlin.crt.util.toAwsString -import aws.sdk.kotlin.crt.util.toKString +import aws.sdk.kotlin.crt.util.* import kotlinx.cinterop.* import kotlinx.coroutines.channels.Channel import libcrt.* @@ -82,14 +79,10 @@ public actual class HostResolver private constructor( @OptIn(ExperimentalForeignApi::class) private fun onShutdownComplete(userData: COpaquePointer?) { - if (userData == null) { - return + userData?.withDereferenced(dispose = true) { ch -> + ch.trySend(Unit) + ch.close() } - val stableRef = userData.asStableRef() - val ch = stableRef.get() - ch.trySend(Unit) - ch.close() - stableRef.dispose() } // implementation of `aws_on_host_resolved_result_fn`: https://github.com/awslabs/aws-c-io/blob/db7a1bddc9a29eca18734d0af189c3924775dcf1/include/aws/io/host_resolver.h#L53C14-L53C44 @@ -104,58 +97,57 @@ private fun awsOnHostResolveFn( throw CrtRuntimeException("aws_on_host_resolved_result_fn: userData unexpectedly null") } - val stableRef = userData.asStableRef>>>() - val channel = stableRef.get() - - try { - if (errCode != AWS_OP_SUCCESS) { - throw CrtRuntimeException("aws_on_host_resolved_result_fn", ec = errCode) - } - - val length = aws_array_list_length(hostAddresses) - if (length == 0uL) { - throw CrtRuntimeException("Failed to resolve host address for ${hostName?.toKString()}") - } - - val addressList = ArrayList(length.toInt()) - - val element = alloc() - for (i in 0uL until length) { - awsAssertOpSuccess( - aws_array_list_get_at_ptr( - hostAddresses, - element.ptr, - i, - ), - ) { "aws_array_list_get_at_ptr failed at index $i" } - - val elemOpaque = element.value ?: run { - throw CrtRuntimeException("aws_host_addresses value at index $i unexpectedly null") + userData.withDereferenced>>> { channel -> + try { + if (errCode != AWS_OP_SUCCESS) { + throw CrtRuntimeException("aws_on_host_resolved_result_fn", ec = errCode) } - val addr = elemOpaque.reinterpret().pointed - - val hostStr = addr.host?.toKString() ?: run { - throw CrtRuntimeException("aws_host_addresses `host` at index $i unexpectedly null") - } - val addressStr = addr.address?.toKString() ?: run { - throw CrtRuntimeException("aws_host_addresses `address` at index $i unexpectedly null") + val length = aws_array_list_length(hostAddresses) + if (length == 0uL) { + throw CrtRuntimeException("Failed to resolve host address for ${hostName?.toKString()}") } - val addressType = when (addr.record_type) { - aws_address_record_type.AWS_ADDRESS_RECORD_TYPE_A -> AddressType.IpV4 - aws_address_record_type.AWS_ADDRESS_RECORD_TYPE_AAAA -> AddressType.IpV6 - else -> throw CrtRuntimeException("received unsupported aws_host_address `aws_address_record_type`: ${addr.record_type}") + val addressList = ArrayList(length.toInt()) + + val element = alloc() + for (i in 0uL until length) { + awsAssertOpSuccess( + aws_array_list_get_at_ptr( + hostAddresses, + element.ptr, + i, + ), + ) { "aws_array_list_get_at_ptr failed at index $i" } + + val elemOpaque = element.value ?: run { + throw CrtRuntimeException("aws_host_addresses value at index $i unexpectedly null") + } + + val addr = elemOpaque.reinterpret().pointed + + val hostStr = addr.host?.toKString() ?: run { + throw CrtRuntimeException("aws_host_addresses `host` at index $i unexpectedly null") + } + val addressStr = addr.address?.toKString() ?: run { + throw CrtRuntimeException("aws_host_addresses `address` at index $i unexpectedly null") + } + + val addressType = when (addr.record_type) { + aws_address_record_type.AWS_ADDRESS_RECORD_TYPE_A -> AddressType.IpV4 + aws_address_record_type.AWS_ADDRESS_RECORD_TYPE_AAAA -> AddressType.IpV6 + else -> throw CrtRuntimeException("received unsupported aws_host_address `aws_address_record_type`: ${addr.record_type}") + } + + addressList += CrtHostAddress(host = hostStr, address = addressStr, addressType) } - addressList += CrtHostAddress(host = hostStr, address = addressStr, addressType) + channel.trySend(Result.success(addressList)) + } catch (e: Exception) { + channel.trySend(Result.failure(e)) + } finally { + channel.close() } - - channel.trySend(Result.success(addressList)) - } catch (e: Exception) { - channel.trySend(Result.failure(e)) - } finally { - channel.close() } } diff --git a/aws-crt-kotlin/native/src/aws/sdk/kotlin/crt/util/Interop.kt b/aws-crt-kotlin/native/src/aws/sdk/kotlin/crt/util/Interop.kt index 28e1bc8c..0f98da47 100644 --- a/aws-crt-kotlin/native/src/aws/sdk/kotlin/crt/util/Interop.kt +++ b/aws-crt-kotlin/native/src/aws/sdk/kotlin/crt/util/Interop.kt @@ -4,7 +4,9 @@ */ package aws.sdk.kotlin.crt.util +import kotlinx.cinterop.COpaquePointer import kotlinx.cinterop.StableRef +import kotlinx.cinterop.asStableRef import kotlinx.coroutines.channels.Channel /** @@ -27,3 +29,26 @@ internal inline fun StableRef.use(block: (StableRef) -> R): R dispose() } } + +internal inline fun COpaquePointer.withDereferenced( + dispose: Boolean = false, + block: (T) -> R, +): R? = + try { + val stableRef = asStableRef() // can throw NPE when target type can't be coerced to type arg + try { + val value = stableRef.get() // can throw NPE when pointer has been cleaned up by CRT + block(value) + } finally { + if (dispose) { + stableRef.dispose() + } + } + } catch (_: NullPointerException) { + null + } + +internal inline fun COpaquePointer.withDereferenced( + dispose: Boolean = false, + block: (T) -> Unit, +) = withDereferenced(dispose, block) diff --git a/build-support/src/main/kotlin/aws/sdk/kotlin/gradle/crt/CMakeUtils.kt b/build-support/src/main/kotlin/aws/sdk/kotlin/gradle/crt/CMakeUtils.kt index 2f6ae476..d8223b66 100644 --- a/build-support/src/main/kotlin/aws/sdk/kotlin/gradle/crt/CMakeUtils.kt +++ b/build-support/src/main/kotlin/aws/sdk/kotlin/gradle/crt/CMakeUtils.kt @@ -66,7 +66,6 @@ val KonanTarget.osxArchitectureName KonanTarget.WATCHOS_ARM64 -> "arm64_32" else -> null } - else -> null } fun Project.cmakeBuildDir(target: KotlinNativeTarget): File =