Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,7 @@ package aws.sdk.kotlin.crt.auth.signing
import aws.sdk.kotlin.crt.*
import aws.sdk.kotlin.crt.auth.credentials.Credentials
import aws.sdk.kotlin.crt.http.*
import aws.sdk.kotlin.crt.util.asAwsByteCursor
import aws.sdk.kotlin.crt.util.initFromCursor
import aws.sdk.kotlin.crt.util.toAwsString
import aws.sdk.kotlin.crt.util.toKString
import aws.sdk.kotlin.crt.util.use
import aws.sdk.kotlin.crt.util.*
import kotlinx.cinterop.*
import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.runBlocking
Expand Down Expand Up @@ -223,15 +219,10 @@ private fun AwsSigningConfig.toNativeSigningConfig(): CPointer<aws_signing_confi
private typealias ShouldSignHeaderFunction = (String) -> Boolean
private fun nativeShouldSignHeaderFn(headerName: CPointer<aws_byte_cursor>?, userData: COpaquePointer?): Boolean {
checkNotNull(headerName) { "aws_should_sign_header_fn expected non-null header name" }
if (userData == null) {
return true
}

userData.asStableRef<ShouldSignHeaderFunction>().use {
val kShouldSignHeaderFn = it.get()
return userData?.withDereferenced<ShouldSignHeaderFunction, _>(dispose = true) { kShouldSignHeaderFn ->
val kHeaderName = headerName.pointed.toKString()
return kShouldSignHeaderFn(kHeaderName)
}
kShouldSignHeaderFn(kHeaderName)
} ?: error("Expected non-null userData")
}

/**
Expand All @@ -243,17 +234,17 @@ private fun signCallback(signingResult: CPointer<aws_signing_result>?, errorCode
checkNotNull(signingResult) { "signing callback received null aws_signing_result" }
checkNotNull(userData) { "signing callback received null user data" }

val (pinnedRequestToSign, callbackChannel) = userData
.asStableRef<Pair<Pinned<CPointer<cnames.structs.aws_http_message>>, Channel<ByteArray>>>()
.get()
userData.withDereferenced<Pair<Pinned<CPointer<cnames.structs.aws_http_message>>, Channel<ByteArray>>> { pair ->
val (pinnedRequestToSign, callbackChannel) = pair

val requestToSign = pinnedRequestToSign.get()
val requestToSign = pinnedRequestToSign.get()

awsAssertOpSuccess(aws_apply_signing_result_to_http_request(requestToSign, Allocator.Default.allocator, signingResult)) {
"aws_apply_signing_result_to_http_request"
}
awsAssertOpSuccess(aws_apply_signing_result_to_http_request(requestToSign, Allocator.Default.allocator, signingResult)) {
"aws_apply_signing_result_to_http_request"
}

runBlocking { callbackChannel.send(signingResult.getSignature()) }
runBlocking { callbackChannel.send(signingResult.getSignature()) }
}
}

/**
Expand All @@ -264,8 +255,9 @@ private fun signChunkCallback(signingResult: CPointer<aws_signing_result>?, erro
checkNotNull(signingResult) { "signing callback received null aws_signing_result" }
checkNotNull(userData) { "signing callback received null user data" }

val callbackChannel = userData.asStableRef<Channel<ByteArray>>().get()
runBlocking { callbackChannel.send(signingResult.getSignature()) }
userData.withDereferenced<Channel<ByteArray>> { callbackChannel ->
runBlocking { callbackChannel.send(signingResult.getSignature()) }
}
}

private fun Credentials.toNativeCredentials(): CPointer<cnames.structs.aws_credentials>? = aws_credentials_new_from_string(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -183,13 +183,10 @@ private fun SocketDomain.toNativeSocketDomain() = when (this) {
}

private fun onShutdownComplete(userdata: COpaquePointer?) {
if (userdata == null) return
val notify = userdata.asStableRef<ShutdownChannel>()
with(notify.get()) {
trySend(Unit)
close()
userdata?.withDereferenced<ShutdownChannel>(dispose = true) { notify ->
notify.trySend(Unit)
notify.close()
}
notify.dispose()
}

private data class HttpConnectionAcquisitionRequest(
Expand All @@ -202,20 +199,16 @@ private fun onConnectionAcquired(
errCode: Int,
userdata: COpaquePointer?,
) {
if (userdata == null) return
val stableRef = userdata.asStableRef<HttpConnectionAcquisitionRequest>()
val request = stableRef.get()

when {
errCode != AWS_OP_SUCCESS -> request.cont.resumeWithException(HttpException(errCode))
conn == null -> request.cont.resumeWithException(
CrtRuntimeException("acquireConnection(): http connection null", ec = errCode),
)
else -> {
val kconn = HttpClientConnectionNative(request.manager, conn)
request.cont.resume(kconn)
userdata?.withDereferenced<HttpConnectionAcquisitionRequest>(dispose = true) { request ->
when {
errCode != AWS_OP_SUCCESS -> request.cont.resumeWithException(HttpException(errCode))
conn == null -> request.cont.resumeWithException(
CrtRuntimeException("acquireConnection(): http connection null", ec = errCode),
)
else -> {
val kconn = HttpClientConnectionNative(request.manager, conn)
request.cont.resume(kconn)
}
}
}

stableRef.dispose()
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,7 @@ package aws.sdk.kotlin.crt.http
import aws.sdk.kotlin.crt.*
import aws.sdk.kotlin.crt.io.Buffer
import aws.sdk.kotlin.crt.io.ByteCursorBuffer
import aws.sdk.kotlin.crt.util.asAwsByteCursor
import aws.sdk.kotlin.crt.util.initFromCursor
import aws.sdk.kotlin.crt.util.toKString
import aws.sdk.kotlin.crt.util.use
import aws.sdk.kotlin.crt.util.withAwsByteCursor
import aws.sdk.kotlin.crt.util.*
import kotlinx.atomicfu.atomic
import kotlinx.cinterop.*
import libcrt.*
Expand Down Expand Up @@ -87,105 +83,100 @@ private class HttpStreamContext(
val nativeReq: CPointer<cnames.structs.aws_http_message>,
)

private fun callbackError(): Int = aws_raise_error(AWS_ERROR_HTTP_CALLBACK_FAILURE.toInt())

private fun onResponseHeaders(
nativeStream: CPointer<cnames.structs.aws_http_stream>?,
blockType: aws_http_header_block,
headerArray: CPointer<aws_http_header>?,
numHeaders: size_t,
userdata: COpaquePointer?,
): Int {
val ctxStableRef = userdata?.asStableRef<HttpStreamContext>() ?: return aws_raise_error(AWS_ERROR_HTTP_CALLBACK_FAILURE.toInt())
ctxStableRef.use {
val ctx = it.get()
val stream = ctx.stream ?: return AWS_OP_ERR

val hdrCnt = numHeaders.toInt()
val headers: List<HttpHeader>? = if (hdrCnt > 0 && headerArray != null) {
val kheaders = mutableListOf<HttpHeader>()
for (i in 0 until hdrCnt) {
val nativeHdr = headerArray[i]
val hdr = HttpHeader(nativeHdr.name.toKString(), nativeHdr.value.toKString())
kheaders.add(hdr)
): Int =
userdata?.withDereferenced<HttpStreamContext, _> { ctx ->
ctx.stream?.let { stream ->
val hdrCnt = numHeaders.toInt()
val headers: List<HttpHeader>? = if (hdrCnt > 0 && headerArray != null) {
val kheaders = mutableListOf<HttpHeader>()
for (i in 0 until hdrCnt) {
val nativeHdr = headerArray[i]
val hdr = HttpHeader(nativeHdr.name.toKString(), nativeHdr.value.toKString())
kheaders.add(hdr)
}
kheaders
} else {
null
}
kheaders
} else {
null
}

try {
ctx.handler.onResponseHeaders(stream, stream.responseStatusCode, blockType.value.toInt(), headers)
} catch (ex: Exception) {
log(LogLevel.Error, "onResponseHeaders: $ex")
return aws_raise_error(AWS_ERROR_HTTP_CALLBACK_FAILURE.toInt())
try {
ctx.handler.onResponseHeaders(stream, stream.responseStatusCode, blockType.value.toInt(), headers)
AWS_OP_SUCCESS
} catch (ex: Exception) {
log(LogLevel.Error, "onResponseHeaders: $ex")
null
}
}

return AWS_OP_SUCCESS
}
}
} ?: callbackError()

private fun onResponseHeaderBlockDone(
nativeStream: CPointer<cnames.structs.aws_http_stream>?,
blockType: aws_http_header_block,
userdata: COpaquePointer?,
): Int {
val ctx = userdata?.asStableRef<HttpStreamContext>()?.get() ?: return AWS_OP_ERR
val stream = ctx.stream ?: return AWS_OP_ERR

try {
ctx.handler.onResponseHeadersDone(stream, blockType.value.toInt())
} catch (ex: Exception) {
log(LogLevel.Error, "onResponseHeaderBlockDone: $ex")
return aws_raise_error(AWS_ERROR_HTTP_CALLBACK_FAILURE.toInt())
}

return AWS_OP_SUCCESS
}
): Int =
userdata?.withDereferenced<HttpStreamContext, _> { ctx ->
ctx.stream?.let { stream ->
try {
ctx.handler.onResponseHeadersDone(stream, blockType.value.toInt())
AWS_OP_SUCCESS
} catch (ex: Exception) {
log(LogLevel.Error, "onResponseHeaderBlockDone: $ex")
null
}
}
} ?: callbackError()

private fun onIncomingBody(
nativeStream: CPointer<cnames.structs.aws_http_stream>?,
data: CPointer<aws_byte_cursor>?,
userdata: COpaquePointer?,
): Int {
val ctx = userdata?.asStableRef<HttpStreamContext>()?.get() ?: return AWS_OP_ERR
val stream = ctx.stream ?: return AWS_OP_ERR

try {
val body = if (data != null) ByteCursorBuffer(data) else Buffer.Empty
val windowIncrement = ctx.handler.onResponseBody(stream, body)
if (windowIncrement < 0) {
return aws_raise_error(AWS_ERROR_HTTP_CALLBACK_FAILURE.toInt())
}

if (windowIncrement > 0) {
aws_http_stream_update_window(nativeStream, windowIncrement.convert())
): Int =
userdata?.withDereferenced<HttpStreamContext, _> { ctx ->
ctx.stream?.let { stream ->
try {
val body = if (data != null) ByteCursorBuffer(data) else Buffer.Empty
val windowIncrement = ctx.handler.onResponseBody(stream, body)

if (windowIncrement < 0) {
null
} else {
if (windowIncrement > 0) {
aws_http_stream_update_window(nativeStream, windowIncrement.convert())
}
AWS_OP_SUCCESS
}
} catch (ex: Exception) {
log(LogLevel.Error, "onIncomingBody: $ex")
null
}
}
} catch (ex: Exception) {
log(LogLevel.Error, "onIncomingBody: $ex")
return aws_raise_error(AWS_ERROR_HTTP_CALLBACK_FAILURE.toInt())
}

return AWS_OP_SUCCESS
}
} ?: callbackError()

private fun onStreamComplete(
nativeStream: CPointer<cnames.structs.aws_http_stream>?,
errorCode: Int,
userdata: COpaquePointer?,
) {
val stableRef = userdata?.asStableRef<HttpStreamContext>() ?: return
val ctx = stableRef.get()
val stream = ctx.stream ?: return

try {
ctx.handler.onResponseComplete(stream, errorCode)
} catch (ex: Exception) {
log(LogLevel.Error, "onStreamComplete: $ex")
// close connection if callback throws an exception
aws_http_connection_close(aws_http_stream_get_connection(nativeStream))
} finally {
// cleanup stream resources
stableRef.dispose()
aws_http_message_destroy(ctx.nativeReq)
userdata?.withDereferenced<HttpStreamContext>(dispose = true) { ctx ->
try {
val stream = ctx.stream ?: return
ctx.handler.onResponseComplete(stream, errorCode)
} catch (ex: Exception) {
log(LogLevel.Error, "onStreamComplete: $ex")
// close connection if callback throws an exception
aws_http_connection_close(aws_http_stream_get_connection(nativeStream))
} finally {
// cleanup request object
aws_http_message_release(ctx.nativeReq)
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import aws.sdk.kotlin.crt.NativeHandle
import aws.sdk.kotlin.crt.awsAssertOpSuccess
import aws.sdk.kotlin.crt.util.asAwsByteCursor
import aws.sdk.kotlin.crt.util.use
import aws.sdk.kotlin.crt.util.withDereferenced
import kotlinx.atomicfu.atomic
import kotlinx.cinterop.*
import libcrt.*
Expand Down Expand Up @@ -67,12 +68,13 @@ internal class HttpStreamNative(
throw CrtRuntimeException("aws_input_stream_new_from_cursor()")
}

StableRef.create(WriteChunkRequest(cont, byteBuf, stream)).use { req ->
val req = WriteChunkRequest(cont, byteBuf, stream)
StableRef.create(req).use { stableRef ->
val chunkOpts = cValue<aws_http1_chunk_options> {
chunk_data_size = chunkData.size.convert()
chunk_data = stream
on_complete = staticCFunction(::onWriteChunkComplete)
user_data = req.asCPointer()
user_data = stableRef.asCPointer()
}
awsAssertOpSuccess(
aws_http1_stream_write_chunk(ptr, chunkOpts),
Expand Down Expand Up @@ -113,19 +115,18 @@ private fun onWriteChunkComplete(
userData: COpaquePointer?,
) {
if (stream == null) return
val stableRef = userData?.asStableRef<WriteChunkRequest>() ?: return
val req = stableRef.get()
when {
errCode != AWS_OP_SUCCESS -> req.cont.resumeWithException(HttpException(errCode))
else -> req.cont.resume(Unit)
userData?.withDereferenced<WriteChunkRequest> { req ->
checkNotNull(req) { "Received null request in onWriteChunkComplete" }
when {
errCode != AWS_OP_SUCCESS -> req.cont.resumeWithException(HttpException(errCode))
else -> req.cont.resume(Unit)
}
cleanupWriteChunkCbData(req)
}
cleanupWriteChunkCbData(stableRef)
}

private fun cleanupWriteChunkCbData(stableRef: StableRef<WriteChunkRequest>) {
val req = stableRef.get()
private fun cleanupWriteChunkCbData(req: WriteChunkRequest) {
aws_input_stream_destroy(req.inputStream)
aws_byte_buf_clean_up(req.chunkData)
Allocator.Default.free(req.inputStream)
stableRef.dispose()
}
Loading
Loading