Skip to content

Commit 131deb3

Browse files
authored
Move message encoding into PSQLChannelHandler (#181)
1 parent 28ab2df commit 131deb3

File tree

4 files changed

+80
-54
lines changed

4 files changed

+80
-54
lines changed

Sources/PostgresNIO/New/PSQLChannelHandler.swift

Lines changed: 65 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ protocol PSQLChannelHandlerNotificationDelegate: AnyObject {
1010
final class PSQLChannelHandler: ChannelDuplexHandler {
1111
typealias OutboundIn = PSQLTask
1212
typealias InboundIn = ByteBuffer
13-
typealias OutboundOut = PSQLFrontendMessage
13+
typealias OutboundOut = ByteBuffer
1414

1515
private let logger: Logger
1616
private var state: ConnectionStateMachine {
@@ -25,32 +25,33 @@ final class PSQLChannelHandler: ChannelDuplexHandler {
2525
private var handlerContext: ChannelHandlerContext!
2626
private var rowStream: PSQLRowStream?
2727
private var decoder: NIOSingleStepByteToMessageProcessor<PSQLBackendMessageDecoder>
28-
private let authentificationConfiguration: PSQLConnection.Configuration.Authentication?
28+
private var encoder: BufferedMessageEncoder<PSQLFrontendMessageEncoder>!
29+
private let configuration: PSQLConnection.Configuration
2930
private let configureSSLCallback: ((Channel) throws -> Void)?
3031

3132
/// this delegate should only be accessed on the connections `EventLoop`
3233
weak var notificationDelegate: PSQLChannelHandlerNotificationDelegate?
3334

34-
init(authentification: PSQLConnection.Configuration.Authentication?,
35+
init(configuration: PSQLConnection.Configuration,
3536
logger: Logger,
3637
configureSSLCallback: ((Channel) throws -> Void)?)
3738
{
3839
self.state = ConnectionStateMachine()
39-
self.authentificationConfiguration = authentification
40+
self.configuration = configuration
4041
self.configureSSLCallback = configureSSLCallback
4142
self.logger = logger
4243
self.decoder = NIOSingleStepByteToMessageProcessor(PSQLBackendMessageDecoder())
4344
}
4445

4546
#if DEBUG
4647
/// for testing purposes only
47-
init(authentification: PSQLConnection.Configuration.Authentication?,
48+
init(configuration: PSQLConnection.Configuration,
4849
state: ConnectionStateMachine = .init(.initialized),
4950
logger: Logger = .psqlNoOpLogger,
5051
configureSSLCallback: ((Channel) throws -> Void)?)
5152
{
5253
self.state = state
53-
self.authentificationConfiguration = authentification
54+
self.configuration = configuration
5455
self.configureSSLCallback = configureSSLCallback
5556
self.logger = logger
5657
self.decoder = NIOSingleStepByteToMessageProcessor(PSQLBackendMessageDecoder())
@@ -61,6 +62,11 @@ final class PSQLChannelHandler: ChannelDuplexHandler {
6162

6263
func handlerAdded(context: ChannelHandlerContext) {
6364
self.handlerContext = context
65+
self.encoder = BufferedMessageEncoder(
66+
buffer: context.channel.allocator.buffer(capacity: 256),
67+
encoder: PSQLFrontendMessageEncoder(jsonEncoder: self.configuration.coders.jsonEncoder)
68+
)
69+
6470
if context.channel.isActive {
6571
self.connected(context: context)
6672
}
@@ -222,15 +228,19 @@ final class PSQLChannelHandler: ChannelDuplexHandler {
222228
case .wait:
223229
break
224230
case .sendStartupMessage(let authContext):
225-
context.writeAndFlush(.startup(.versionThree(parameters: authContext.toStartupParameters())), promise: nil)
231+
try! self.encoder.encode(.startup(.versionThree(parameters: authContext.toStartupParameters())))
232+
context.writeAndFlush(self.wrapOutboundOut(self.encoder.flush()!), promise: nil)
226233
case .sendSSLRequest:
227-
context.writeAndFlush(.sslRequest(.init()), promise: nil)
234+
try! self.encoder.encode(.sslRequest(.init()))
235+
context.writeAndFlush(self.wrapOutboundOut(self.encoder.flush()!), promise: nil)
228236
case .sendPasswordMessage(let mode, let authContext):
229237
self.sendPasswordMessage(mode: mode, authContext: authContext, context: context)
230238
case .sendSaslInitialResponse(let name, let initialResponse):
231-
context.writeAndFlush(.saslInitialResponse(.init(saslMechanism: name, initialData: initialResponse)))
239+
try! self.encoder.encode(.saslInitialResponse(.init(saslMechanism: name, initialData: initialResponse)))
240+
context.writeAndFlush(self.wrapOutboundOut(self.encoder.flush()!), promise: nil)
232241
case .sendSaslResponse(let bytes):
233-
context.writeAndFlush(.saslResponse(.init(data: bytes)))
242+
try! self.encoder.encode(.saslResponse(.init(data: bytes)))
243+
context.writeAndFlush(self.wrapOutboundOut(self.encoder.flush()!), promise: nil)
234244
case .closeConnectionAndCleanup(let cleanupContext):
235245
self.closeConnectionAndCleanup(cleanupContext, context: context)
236246
case .fireChannelInactive:
@@ -277,7 +287,7 @@ final class PSQLChannelHandler: ChannelDuplexHandler {
277287
case .provideAuthenticationContext:
278288
context.fireUserInboundEventTriggered(PSQLEvent.readyForStartup)
279289

280-
if let authentication = self.authentificationConfiguration {
290+
if let authentication = self.configuration.authentication {
281291
let authContext = AuthContext(
282292
username: authentication.username,
283293
password: authentication.password,
@@ -293,7 +303,8 @@ final class PSQLChannelHandler: ChannelDuplexHandler {
293303
// The normal, graceful termination procedure is that the frontend sends a Terminate
294304
// message and immediately closes the connection. On receipt of this message, the
295305
// backend closes the connection and terminates.
296-
context.write(.terminate, promise: nil)
306+
try! self.encoder.encode(.terminate)
307+
context.writeAndFlush(self.wrapOutboundOut(self.encoder.flush()!), promise: nil)
297308
}
298309
context.close(mode: .all, promise: promise)
299310
case .succeedPreparedStatementCreation(let preparedContext, with: let rowDescription):
@@ -357,22 +368,26 @@ final class PSQLChannelHandler: ChannelDuplexHandler {
357368
hash2.append(salt.3)
358369
let hash = "md5" + Insecure.MD5.hash(data: hash2).hexdigest()
359370

360-
context.writeAndFlush(.password(.init(value: hash)), promise: nil)
371+
try! self.encoder.encode(.password(.init(value: hash)))
372+
context.writeAndFlush(self.wrapOutboundOut(self.encoder.flush()!), promise: nil)
373+
361374
case .cleartext:
362-
context.writeAndFlush(.password(.init(value: authContext.password ?? "")), promise: nil)
375+
try! self.encoder.encode(.password(.init(value: authContext.password ?? "")))
376+
context.writeAndFlush(self.wrapOutboundOut(self.encoder.flush()!), promise: nil)
363377
}
364378
}
365379

366380
private func sendCloseAndSyncMessage(_ sendClose: CloseTarget, context: ChannelHandlerContext) {
367381
switch sendClose {
368382
case .preparedStatement(let name):
369-
context.write(.close(.preparedStatement(name)), promise: nil)
370-
context.write(.sync, promise: nil)
371-
context.flush()
383+
try! self.encoder.encode(.close(.preparedStatement(name)))
384+
try! self.encoder.encode(.sync)
385+
context.writeAndFlush(self.wrapOutboundOut(self.encoder.flush()!), promise: nil)
386+
372387
case .portal(let name):
373-
context.write(.close(.portal(name)), promise: nil)
374-
context.write(.sync, promise: nil)
375-
context.flush()
388+
try! self.encoder.encode(.close(.portal(name)))
389+
try! self.encoder.encode(.sync)
390+
context.writeAndFlush(self.wrapOutboundOut(self.encoder.flush()!), promise: nil)
376391
}
377392
}
378393

@@ -387,10 +402,16 @@ final class PSQLChannelHandler: ChannelDuplexHandler {
387402
query: query,
388403
parameters: [])
389404

390-
context.write(.parse(parse), promise: nil)
391-
context.write(.describe(.preparedStatement(statementName)), promise: nil)
392-
context.write(.sync, promise: nil)
393-
context.flush()
405+
406+
do {
407+
try self.encoder.encode(.parse(parse))
408+
try self.encoder.encode(.describe(.preparedStatement(statementName)))
409+
try self.encoder.encode(.sync)
410+
context.writeAndFlush(self.wrapOutboundOut(self.encoder.flush()!), promise: nil)
411+
} catch {
412+
let action = self.state.errorHappened(.channel(underlying: error))
413+
self.run(action, with: context)
414+
}
394415
}
395416

396417
private func sendBindExecuteAndSyncMessage(
@@ -403,10 +424,15 @@ final class PSQLChannelHandler: ChannelDuplexHandler {
403424
preparedStatementName: statementName,
404425
parameters: binds)
405426

406-
context.write(.bind(bind), promise: nil)
407-
context.write(.execute(.init(portalName: "")), promise: nil)
408-
context.write(.sync, promise: nil)
409-
context.flush()
427+
do {
428+
try self.encoder.encode(.bind(bind))
429+
try self.encoder.encode(.execute(.init(portalName: "")))
430+
try self.encoder.encode(.sync)
431+
context.writeAndFlush(self.wrapOutboundOut(self.encoder.flush()!), promise: nil)
432+
} catch {
433+
let action = self.state.errorHappened(.channel(underlying: error))
434+
self.run(action, with: context)
435+
}
410436
}
411437

412438
private func sendParseDescribeBindExecuteAndSyncMessage(
@@ -424,12 +450,17 @@ final class PSQLChannelHandler: ChannelDuplexHandler {
424450
preparedStatementName: unnamedStatementName,
425451
parameters: binds)
426452

427-
context.write(wrapOutboundOut(.parse(parse)), promise: nil)
428-
context.write(wrapOutboundOut(.describe(.preparedStatement(""))), promise: nil)
429-
context.write(wrapOutboundOut(.bind(bind)), promise: nil)
430-
context.write(wrapOutboundOut(.execute(.init(portalName: ""))), promise: nil)
431-
context.write(wrapOutboundOut(.sync), promise: nil)
432-
context.flush()
453+
do {
454+
try self.encoder.encode(.parse(parse))
455+
try self.encoder.encode(.describe(.preparedStatement("")))
456+
try self.encoder.encode(.bind(bind))
457+
try self.encoder.encode(.execute(.init(portalName: "")))
458+
try self.encoder.encode(.sync)
459+
context.writeAndFlush(self.wrapOutboundOut(self.encoder.flush()!), promise: nil)
460+
} catch {
461+
let action = self.state.errorHappened(.channel(underlying: error))
462+
self.run(action, with: context)
463+
}
433464
}
434465

435466
private func succeedQueryWithRowStream(
@@ -503,16 +534,6 @@ extension PSQLChannelHandler: PSQLRowsDataSource {
503534
}
504535
}
505536

506-
extension ChannelHandlerContext {
507-
func write(_ psqlMessage: PSQLFrontendMessage, promise: EventLoopPromise<Void>? = nil) {
508-
self.write(NIOAny(psqlMessage), promise: promise)
509-
}
510-
511-
func writeAndFlush(_ psqlMessage: PSQLFrontendMessage, promise: EventLoopPromise<Void>? = nil) {
512-
self.writeAndFlush(NIOAny(psqlMessage), promise: promise)
513-
}
514-
}
515-
516537
extension PSQLConnection.Configuration.Authentication {
517538
func toAuthContext() -> AuthContext {
518539
AuthContext(

Sources/PostgresNIO/New/PSQLConnection.swift

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -228,9 +228,8 @@ final class PSQLConnection {
228228
}
229229

230230
return channel.pipeline.addHandlers([
231-
MessageToByteHandler(PSQLFrontendMessageEncoder(jsonEncoder: configuration.coders.jsonEncoder)),
232231
PSQLChannelHandler(
233-
authentification: configuration.authentication,
232+
configuration: configuration,
234233
logger: logger,
235234
configureSSLCallback: configureSSLCallback),
236235
PSQLEventsHandler(logger: logger)

Tests/PostgresNIOTests/New/Extensions/PSQLFrontendMessageDecoder.swift

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,13 @@ struct PSQLFrontendMessageDecoder: NIOSingleStepByteToMessageDecoder {
2020
return nil
2121
}
2222

23-
guard var messageSlice = buffer.getSlice(at: buffer.readerIndex &+ 4, length: Int(length)) else {
23+
guard var messageSlice = buffer.getSlice(at: buffer.readerIndex + 4, length: Int(length) - 4) else {
2424
return nil
2525
}
26-
buffer.moveReaderIndex(forwardBy: 4 &+ Int(length))
26+
buffer.moveReaderIndex(to: Int(length))
2727
let finalIndex = buffer.readerIndex
2828

29-
guard let code = buffer.readInteger(as: UInt32.self) else {
29+
guard let code = messageSlice.readInteger(as: UInt32.self) else {
3030
throw PSQLPartialDecodingError.fieldNotDecodable(type: UInt32.self)
3131
}
3232

Tests/PostgresNIOTests/New/PSQLChannelHandlerTests.swift

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,11 @@ class PSQLChannelHandlerTests: XCTestCase {
1111

1212
func testHandlerAddedWithoutSSL() {
1313
let config = self.testConnectionConfiguration()
14+
let handler = PSQLChannelHandler(configuration: config, configureSSLCallback: nil)
1415
let embedded = EmbeddedChannel(handlers: [
16+
ReverseByteToMessageHandler(PSQLFrontendMessageDecoder()),
1517
ReverseMessageToByteHandler(PSQLBackendMessageEncoder()),
16-
PSQLChannelHandler(authentification: config.authentication, configureSSLCallback: nil)
18+
handler
1719
])
1820
defer { XCTAssertNoThrow(try embedded.finish()) }
1921

@@ -38,10 +40,11 @@ class PSQLChannelHandlerTests: XCTestCase {
3840
var config = self.testConnectionConfiguration()
3941
config.tlsConfiguration = .makeClientConfiguration()
4042
var addSSLCallbackIsHit = false
41-
let handler = PSQLChannelHandler(authentification: config.authentication) { channel in
43+
let handler = PSQLChannelHandler(configuration: config) { channel in
4244
addSSLCallbackIsHit = true
4345
}
4446
let embedded = EmbeddedChannel(handlers: [
47+
ReverseByteToMessageHandler(PSQLFrontendMessageDecoder()),
4548
ReverseMessageToByteHandler(PSQLBackendMessageEncoder()),
4649
handler
4750
])
@@ -79,11 +82,12 @@ class PSQLChannelHandlerTests: XCTestCase {
7982
var config = self.testConnectionConfiguration()
8083
config.tlsConfiguration = .makeClientConfiguration()
8184

82-
let handler = PSQLChannelHandler(authentification: config.authentication) { channel in
85+
let handler = PSQLChannelHandler(configuration: config) { channel in
8386
XCTFail("This callback should never be exectuded")
8487
throw PSQLError.sslUnsupported
8588
}
8689
let embedded = EmbeddedChannel(handlers: [
90+
ReverseByteToMessageHandler(PSQLFrontendMessageDecoder()),
8791
ReverseMessageToByteHandler(PSQLBackendMessageEncoder()),
8892
handler
8993
])
@@ -114,8 +118,9 @@ class PSQLChannelHandlerTests: XCTestCase {
114118
database: config.authentication?.database
115119
)
116120
let state = ConnectionStateMachine(.waitingToStartAuthentication)
117-
let handler = PSQLChannelHandler(authentification: config.authentication, state: state, configureSSLCallback: nil)
121+
let handler = PSQLChannelHandler(configuration: config, state: state, configureSSLCallback: nil)
118122
let embedded = EmbeddedChannel(handlers: [
123+
ReverseByteToMessageHandler(PSQLFrontendMessageDecoder()),
119124
ReverseMessageToByteHandler(PSQLBackendMessageEncoder()),
120125
handler
121126
])
@@ -142,8 +147,9 @@ class PSQLChannelHandlerTests: XCTestCase {
142147
database: config.authentication?.database
143148
)
144149
let state = ConnectionStateMachine(.waitingToStartAuthentication)
145-
let handler = PSQLChannelHandler(authentification: config.authentication, state: state, configureSSLCallback: nil)
150+
let handler = PSQLChannelHandler(configuration: config, state: state, configureSSLCallback: nil)
146151
let embedded = EmbeddedChannel(handlers: [
152+
ReverseByteToMessageHandler(PSQLFrontendMessageDecoder()),
147153
ReverseMessageToByteHandler(PSQLBackendMessageEncoder()),
148154
handler
149155
])

0 commit comments

Comments
 (0)