@@ -10,7 +10,7 @@ protocol PSQLChannelHandlerNotificationDelegate: AnyObject {
10
10
final class PSQLChannelHandler : ChannelDuplexHandler {
11
11
typealias OutboundIn = PSQLTask
12
12
typealias InboundIn = ByteBuffer
13
- typealias OutboundOut = PSQLFrontendMessage
13
+ typealias OutboundOut = ByteBuffer
14
14
15
15
private let logger : Logger
16
16
private var state : ConnectionStateMachine {
@@ -25,32 +25,33 @@ final class PSQLChannelHandler: ChannelDuplexHandler {
25
25
private var handlerContext : ChannelHandlerContext !
26
26
private var rowStream : PSQLRowStream ?
27
27
private var decoder : NIOSingleStepByteToMessageProcessor < PSQLBackendMessageDecoder >
28
- private let authentificationConfiguration : PSQLConnection . Configuration . Authentication ?
28
+ private var encoder : BufferedMessageEncoder < PSQLFrontendMessageEncoder > !
29
+ private let configuration : PSQLConnection . Configuration
29
30
private let configureSSLCallback : ( ( Channel ) throws -> Void ) ?
30
31
31
32
/// this delegate should only be accessed on the connections `EventLoop`
32
33
weak var notificationDelegate : PSQLChannelHandlerNotificationDelegate ?
33
34
34
- init ( authentification : PSQLConnection . Configuration . Authentication ? ,
35
+ init ( configuration : PSQLConnection . Configuration ,
35
36
logger: Logger ,
36
37
configureSSLCallback: ( ( Channel ) throws -> Void ) ? )
37
38
{
38
39
self . state = ConnectionStateMachine ( )
39
- self . authentificationConfiguration = authentification
40
+ self . configuration = configuration
40
41
self . configureSSLCallback = configureSSLCallback
41
42
self . logger = logger
42
43
self . decoder = NIOSingleStepByteToMessageProcessor ( PSQLBackendMessageDecoder ( ) )
43
44
}
44
45
45
46
#if DEBUG
46
47
/// for testing purposes only
47
- init ( authentification : PSQLConnection . Configuration . Authentication ? ,
48
+ init ( configuration : PSQLConnection . Configuration ,
48
49
state: ConnectionStateMachine = . init( . initialized) ,
49
50
logger: Logger = . psqlNoOpLogger,
50
51
configureSSLCallback: ( ( Channel ) throws -> Void ) ? )
51
52
{
52
53
self . state = state
53
- self . authentificationConfiguration = authentification
54
+ self . configuration = configuration
54
55
self . configureSSLCallback = configureSSLCallback
55
56
self . logger = logger
56
57
self . decoder = NIOSingleStepByteToMessageProcessor ( PSQLBackendMessageDecoder ( ) )
@@ -61,6 +62,11 @@ final class PSQLChannelHandler: ChannelDuplexHandler {
61
62
62
63
func handlerAdded( context: ChannelHandlerContext ) {
63
64
self . handlerContext = context
65
+ self . encoder = BufferedMessageEncoder (
66
+ buffer: context. channel. allocator. buffer ( capacity: 256 ) ,
67
+ encoder: PSQLFrontendMessageEncoder ( jsonEncoder: self . configuration. coders. jsonEncoder)
68
+ )
69
+
64
70
if context. channel. isActive {
65
71
self . connected ( context: context)
66
72
}
@@ -222,15 +228,19 @@ final class PSQLChannelHandler: ChannelDuplexHandler {
222
228
case . wait:
223
229
break
224
230
case . sendStartupMessage( let authContext) :
225
- context. writeAndFlush ( . startup( . versionThree( parameters: authContext. toStartupParameters ( ) ) ) , promise: nil )
231
+ try ! self . encoder. encode ( . startup( . versionThree( parameters: authContext. toStartupParameters ( ) ) ) )
232
+ context. writeAndFlush ( self . wrapOutboundOut ( self . encoder. flush ( ) !) , promise: nil )
226
233
case . sendSSLRequest:
227
- context. writeAndFlush ( . sslRequest( . init( ) ) , promise: nil )
234
+ try ! self . encoder. encode ( . sslRequest( . init( ) ) )
235
+ context. writeAndFlush ( self . wrapOutboundOut ( self . encoder. flush ( ) !) , promise: nil )
228
236
case . sendPasswordMessage( let mode, let authContext) :
229
237
self . sendPasswordMessage ( mode: mode, authContext: authContext, context: context)
230
238
case . sendSaslInitialResponse( let name, let initialResponse) :
231
- context. writeAndFlush ( . saslInitialResponse( . init( saslMechanism: name, initialData: initialResponse) ) )
239
+ try ! self . encoder. encode ( . saslInitialResponse( . init( saslMechanism: name, initialData: initialResponse) ) )
240
+ context. writeAndFlush ( self . wrapOutboundOut ( self . encoder. flush ( ) !) , promise: nil )
232
241
case . sendSaslResponse( let bytes) :
233
- context. writeAndFlush ( . saslResponse( . init( data: bytes) ) )
242
+ try ! self . encoder. encode ( . saslResponse( . init( data: bytes) ) )
243
+ context. writeAndFlush ( self . wrapOutboundOut ( self . encoder. flush ( ) !) , promise: nil )
234
244
case . closeConnectionAndCleanup( let cleanupContext) :
235
245
self . closeConnectionAndCleanup ( cleanupContext, context: context)
236
246
case . fireChannelInactive:
@@ -277,7 +287,7 @@ final class PSQLChannelHandler: ChannelDuplexHandler {
277
287
case . provideAuthenticationContext:
278
288
context. fireUserInboundEventTriggered ( PSQLEvent . readyForStartup)
279
289
280
- if let authentication = self . authentificationConfiguration {
290
+ if let authentication = self . configuration . authentication {
281
291
let authContext = AuthContext (
282
292
username: authentication. username,
283
293
password: authentication. password,
@@ -293,7 +303,8 @@ final class PSQLChannelHandler: ChannelDuplexHandler {
293
303
// The normal, graceful termination procedure is that the frontend sends a Terminate
294
304
// message and immediately closes the connection. On receipt of this message, the
295
305
// backend closes the connection and terminates.
296
- context. write ( . terminate, promise: nil )
306
+ try ! self . encoder. encode ( . terminate)
307
+ context. writeAndFlush ( self . wrapOutboundOut ( self . encoder. flush ( ) !) , promise: nil )
297
308
}
298
309
context. close ( mode: . all, promise: promise)
299
310
case . succeedPreparedStatementCreation( let preparedContext, with: let rowDescription) :
@@ -357,22 +368,26 @@ final class PSQLChannelHandler: ChannelDuplexHandler {
357
368
hash2. append ( salt. 3 )
358
369
let hash = " md5 " + Insecure. MD5. hash ( data: hash2) . hexdigest ( )
359
370
360
- context. writeAndFlush ( . password( . init( value: hash) ) , promise: nil )
371
+ try ! self . encoder. encode ( . password( . init( value: hash) ) )
372
+ context. writeAndFlush ( self . wrapOutboundOut ( self . encoder. flush ( ) !) , promise: nil )
373
+
361
374
case . cleartext:
362
- context. writeAndFlush ( . password( . init( value: authContext. password ?? " " ) ) , promise: nil )
375
+ try ! self . encoder. encode ( . password( . init( value: authContext. password ?? " " ) ) )
376
+ context. writeAndFlush ( self . wrapOutboundOut ( self . encoder. flush ( ) !) , promise: nil )
363
377
}
364
378
}
365
379
366
380
private func sendCloseAndSyncMessage( _ sendClose: CloseTarget , context: ChannelHandlerContext ) {
367
381
switch sendClose {
368
382
case . preparedStatement( let name) :
369
- context. write ( . close( . preparedStatement( name) ) , promise: nil )
370
- context. write ( . sync, promise: nil )
371
- context. flush ( )
383
+ try ! self . encoder. encode ( . close( . preparedStatement( name) ) )
384
+ try ! self . encoder. encode ( . sync)
385
+ context. writeAndFlush ( self . wrapOutboundOut ( self . encoder. flush ( ) !) , promise: nil )
386
+
372
387
case . portal( let name) :
373
- context . write ( . close( . portal( name) ) , promise : nil )
374
- context . write ( . sync, promise : nil )
375
- context. flush ( )
388
+ try ! self . encoder . encode ( . close( . portal( name) ) )
389
+ try ! self . encoder . encode ( . sync)
390
+ context. writeAndFlush ( self . wrapOutboundOut ( self . encoder . flush ( ) ! ) , promise : nil )
376
391
}
377
392
}
378
393
@@ -387,10 +402,16 @@ final class PSQLChannelHandler: ChannelDuplexHandler {
387
402
query: query,
388
403
parameters: [ ] )
389
404
390
- context. write ( . parse( parse) , promise: nil )
391
- context. write ( . describe( . preparedStatement( statementName) ) , promise: nil )
392
- context. write ( . sync, promise: nil )
393
- context. flush ( )
405
+
406
+ do {
407
+ try self . encoder. encode ( . parse( parse) )
408
+ try self . encoder. encode ( . describe( . preparedStatement( statementName) ) )
409
+ try self . encoder. encode ( . sync)
410
+ context. writeAndFlush ( self . wrapOutboundOut ( self . encoder. flush ( ) !) , promise: nil )
411
+ } catch {
412
+ let action = self . state. errorHappened ( . channel( underlying: error) )
413
+ self . run ( action, with: context)
414
+ }
394
415
}
395
416
396
417
private func sendBindExecuteAndSyncMessage(
@@ -403,10 +424,15 @@ final class PSQLChannelHandler: ChannelDuplexHandler {
403
424
preparedStatementName: statementName,
404
425
parameters: binds)
405
426
406
- context. write ( . bind( bind) , promise: nil )
407
- context. write ( . execute( . init( portalName: " " ) ) , promise: nil )
408
- context. write ( . sync, promise: nil )
409
- context. flush ( )
427
+ do {
428
+ try self . encoder. encode ( . bind( bind) )
429
+ try self . encoder. encode ( . execute( . init( portalName: " " ) ) )
430
+ try self . encoder. encode ( . sync)
431
+ context. writeAndFlush ( self . wrapOutboundOut ( self . encoder. flush ( ) !) , promise: nil )
432
+ } catch {
433
+ let action = self . state. errorHappened ( . channel( underlying: error) )
434
+ self . run ( action, with: context)
435
+ }
410
436
}
411
437
412
438
private func sendParseDescribeBindExecuteAndSyncMessage(
@@ -424,12 +450,17 @@ final class PSQLChannelHandler: ChannelDuplexHandler {
424
450
preparedStatementName: unnamedStatementName,
425
451
parameters: binds)
426
452
427
- context. write ( wrapOutboundOut ( . parse( parse) ) , promise: nil )
428
- context. write ( wrapOutboundOut ( . describe( . preparedStatement( " " ) ) ) , promise: nil )
429
- context. write ( wrapOutboundOut ( . bind( bind) ) , promise: nil )
430
- context. write ( wrapOutboundOut ( . execute( . init( portalName: " " ) ) ) , promise: nil )
431
- context. write ( wrapOutboundOut ( . sync) , promise: nil )
432
- context. flush ( )
453
+ do {
454
+ try self . encoder. encode ( . parse( parse) )
455
+ try self . encoder. encode ( . describe( . preparedStatement( " " ) ) )
456
+ try self . encoder. encode ( . bind( bind) )
457
+ try self . encoder. encode ( . execute( . init( portalName: " " ) ) )
458
+ try self . encoder. encode ( . sync)
459
+ context. writeAndFlush ( self . wrapOutboundOut ( self . encoder. flush ( ) !) , promise: nil )
460
+ } catch {
461
+ let action = self . state. errorHappened ( . channel( underlying: error) )
462
+ self . run ( action, with: context)
463
+ }
433
464
}
434
465
435
466
private func succeedQueryWithRowStream(
@@ -503,16 +534,6 @@ extension PSQLChannelHandler: PSQLRowsDataSource {
503
534
}
504
535
}
505
536
506
- extension ChannelHandlerContext {
507
- func write( _ psqlMessage: PSQLFrontendMessage , promise: EventLoopPromise < Void > ? = nil ) {
508
- self . write ( NIOAny ( psqlMessage) , promise: promise)
509
- }
510
-
511
- func writeAndFlush( _ psqlMessage: PSQLFrontendMessage , promise: EventLoopPromise < Void > ? = nil ) {
512
- self . writeAndFlush ( NIOAny ( psqlMessage) , promise: promise)
513
- }
514
- }
515
-
516
537
extension PSQLConnection . Configuration . Authentication {
517
538
func toAuthContext( ) -> AuthContext {
518
539
AuthContext (
0 commit comments