Skip to content

Commit 28ab2df

Browse files
authored
Move Decoding into PSQLChannelHandler (#182)
1 parent 6611ee1 commit 28ab2df

File tree

5 files changed

+111
-57
lines changed

5 files changed

+111
-57
lines changed

Sources/PostgresNIO/New/PSQLChannelHandler.swift

Lines changed: 58 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@ protocol PSQLChannelHandlerNotificationDelegate: AnyObject {
88
}
99

1010
final class PSQLChannelHandler: ChannelDuplexHandler {
11-
typealias InboundIn = PSQLBackendMessage
1211
typealias OutboundIn = PSQLTask
12+
typealias InboundIn = ByteBuffer
1313
typealias OutboundOut = PSQLFrontendMessage
1414

1515
private let logger: Logger
@@ -24,6 +24,7 @@ final class PSQLChannelHandler: ChannelDuplexHandler {
2424
/// The context is captured in `handlerAdded` and released` in `handlerRemoved`
2525
private var handlerContext: ChannelHandlerContext!
2626
private var rowStream: PSQLRowStream?
27+
private var decoder: NIOSingleStepByteToMessageProcessor<PSQLBackendMessageDecoder>
2728
private let authentificationConfiguration: PSQLConnection.Configuration.Authentication?
2829
private let configureSSLCallback: ((Channel) throws -> Void)?
2930

@@ -38,6 +39,7 @@ final class PSQLChannelHandler: ChannelDuplexHandler {
3839
self.authentificationConfiguration = authentification
3940
self.configureSSLCallback = configureSSLCallback
4041
self.logger = logger
42+
self.decoder = NIOSingleStepByteToMessageProcessor(PSQLBackendMessageDecoder())
4143
}
4244

4345
#if DEBUG
@@ -51,6 +53,7 @@ final class PSQLChannelHandler: ChannelDuplexHandler {
5153
self.authentificationConfiguration = authentification
5254
self.configureSSLCallback = configureSSLCallback
5355
self.logger = logger
56+
self.decoder = NIOSingleStepByteToMessageProcessor(PSQLBackendMessageDecoder())
5457
}
5558
#endif
5659

@@ -91,54 +94,62 @@ final class PSQLChannelHandler: ChannelDuplexHandler {
9194
}
9295

9396
func channelRead(context: ChannelHandlerContext, data: NIOAny) {
94-
let incomingMessage = self.unwrapInboundIn(data)
97+
let buffer = self.unwrapInboundIn(data)
9598

96-
self.logger.trace("Backend message received", metadata: [.message: "\(incomingMessage)"])
97-
98-
let action: ConnectionStateMachine.ConnectionAction
99-
100-
switch incomingMessage {
101-
case .authentication(let authentication):
102-
action = self.state.authenticationMessageReceived(authentication)
103-
case .backendKeyData(let keyData):
104-
action = self.state.backendKeyDataReceived(keyData)
105-
case .bindComplete:
106-
action = self.state.bindCompleteReceived()
107-
case .closeComplete:
108-
action = self.state.closeCompletedReceived()
109-
case .commandComplete(let commandTag):
110-
action = self.state.commandCompletedReceived(commandTag)
111-
case .dataRow(let dataRow):
112-
action = self.state.dataRowReceived(dataRow)
113-
case .emptyQueryResponse:
114-
action = self.state.emptyQueryResponseReceived()
115-
case .error(let errorResponse):
116-
action = self.state.errorReceived(errorResponse)
117-
case .noData:
118-
action = self.state.noDataReceived()
119-
case .notice(let noticeResponse):
120-
action = self.state.noticeReceived(noticeResponse)
121-
case .notification(let notification):
122-
action = self.state.notificationReceived(notification)
123-
case .parameterDescription(let parameterDescription):
124-
action = self.state.parameterDescriptionReceived(parameterDescription)
125-
case .parameterStatus(let parameterStatus):
126-
action = self.state.parameterStatusReceived(parameterStatus)
127-
case .parseComplete:
128-
action = self.state.parseCompleteReceived()
129-
case .portalSuspended:
130-
action = self.state.portalSuspendedReceived()
131-
case .readyForQuery(let transactionState):
132-
action = self.state.readyForQueryReceived(transactionState)
133-
case .rowDescription(let rowDescription):
134-
action = self.state.rowDescriptionReceived(rowDescription)
135-
case .sslSupported:
136-
action = self.state.sslSupportedReceived()
137-
case .sslUnsupported:
138-
action = self.state.sslUnsupportedReceived()
99+
do {
100+
try self.decoder.process(buffer: buffer) { message in
101+
self.logger.trace("Backend message received", metadata: [.message: "\(message)"])
102+
let action: ConnectionStateMachine.ConnectionAction
103+
104+
switch message {
105+
case .authentication(let authentication):
106+
action = self.state.authenticationMessageReceived(authentication)
107+
case .backendKeyData(let keyData):
108+
action = self.state.backendKeyDataReceived(keyData)
109+
case .bindComplete:
110+
action = self.state.bindCompleteReceived()
111+
case .closeComplete:
112+
action = self.state.closeCompletedReceived()
113+
case .commandComplete(let commandTag):
114+
action = self.state.commandCompletedReceived(commandTag)
115+
case .dataRow(let dataRow):
116+
action = self.state.dataRowReceived(dataRow)
117+
case .emptyQueryResponse:
118+
action = self.state.emptyQueryResponseReceived()
119+
case .error(let errorResponse):
120+
action = self.state.errorReceived(errorResponse)
121+
case .noData:
122+
action = self.state.noDataReceived()
123+
case .notice(let noticeResponse):
124+
action = self.state.noticeReceived(noticeResponse)
125+
case .notification(let notification):
126+
action = self.state.notificationReceived(notification)
127+
case .parameterDescription(let parameterDescription):
128+
action = self.state.parameterDescriptionReceived(parameterDescription)
129+
case .parameterStatus(let parameterStatus):
130+
action = self.state.parameterStatusReceived(parameterStatus)
131+
case .parseComplete:
132+
action = self.state.parseCompleteReceived()
133+
case .portalSuspended:
134+
action = self.state.portalSuspendedReceived()
135+
case .readyForQuery(let transactionState):
136+
action = self.state.readyForQueryReceived(transactionState)
137+
case .rowDescription(let rowDescription):
138+
action = self.state.rowDescriptionReceived(rowDescription)
139+
case .sslSupported:
140+
action = self.state.sslSupportedReceived()
141+
case .sslUnsupported:
142+
action = self.state.sslUnsupportedReceived()
143+
}
144+
145+
self.run(action, with: context)
146+
}
147+
} catch let error as PSQLDecodingError {
148+
let action = self.state.errorHappened(.decoding(error))
149+
self.run(action, with: context)
150+
} catch {
151+
preconditionFailure("Expected to only get PSQLDecodingErrors from the PSQLBackendMessageDecoder.")
139152
}
140-
141-
self.run(action, with: context)
142153
}
143154

144155
func channelReadComplete(context: ChannelHandlerContext) {

Sources/PostgresNIO/New/PSQLConnection.swift

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -214,8 +214,6 @@ final class PSQLConnection {
214214
}.flatMap { address -> EventLoopFuture<Channel> in
215215
let bootstrap = ClientBootstrap(group: eventLoop)
216216
.channelInitializer { channel in
217-
let decoder = ByteToMessageHandler(PSQLBackendMessageDecoder())
218-
219217
var configureSSLCallback: ((Channel) throws -> ())? = nil
220218
if let tlsConfiguration = configuration.tlsConfiguration {
221219
configureSSLCallback = { channel in
@@ -225,12 +223,11 @@ final class PSQLConnection {
225223
let sslHandler = try NIOSSLClientHandler(
226224
context: sslContext,
227225
serverHostname: configuration.sslServerHostname)
228-
try channel.pipeline.syncOperations.addHandler(sslHandler, position: .before(decoder))
226+
try channel.pipeline.syncOperations.addHandler(sslHandler, position: .first)
229227
}
230228
}
231229

232230
return channel.pipeline.addHandlers([
233-
decoder,
234231
MessageToByteHandler(PSQLFrontendMessageEncoder(jsonEncoder: configuration.coders.jsonEncoder)),
235232
PSQLChannelHandler(
236233
authentification: configuration.authentication,
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
import NIOCore
2+
3+
/// This is a reverse ``NIOCore/ByteToMessageHandler``. Instead of creating messages from incoming bytes
4+
/// as the normal `ByteToMessageHandler` does, this `ReverseByteToMessageHandler` creates messages
5+
/// from outgoing bytes. This is only important for testing in `EmbeddedChannel`s.
6+
class ReverseMessageToByteHandler<Encoder: MessageToByteEncoder>: ChannelInboundHandler {
7+
typealias InboundIn = Encoder.OutboundIn
8+
typealias InboundOut = ByteBuffer
9+
10+
var byteBuffer: ByteBuffer!
11+
let encoder: Encoder
12+
13+
init(_ encoder: Encoder) {
14+
self.encoder = encoder
15+
}
16+
17+
func handlerAdded(context: ChannelHandlerContext) {
18+
self.byteBuffer = context.channel.allocator.buffer(capacity: 128)
19+
}
20+
21+
func channelRead(context: ChannelHandlerContext, data: NIOAny) {
22+
let message = self.unwrapInboundIn(data)
23+
24+
do {
25+
self.byteBuffer.clear()
26+
try self.encoder.encode(data: message, out: &self.byteBuffer)
27+
context.fireChannelRead(self.wrapInboundOut(self.byteBuffer))
28+
} catch {
29+
context.fireErrorCaught(error)
30+
}
31+
}
32+
}

Tests/PostgresNIOTests/New/PSQLChannelHandlerTests.swift

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,10 @@ class PSQLChannelHandlerTests: XCTestCase {
1111

1212
func testHandlerAddedWithoutSSL() {
1313
let config = self.testConnectionConfiguration()
14-
let handler = PSQLChannelHandler(authentification: config.authentication, configureSSLCallback: nil)
15-
let embedded = EmbeddedChannel(handler: handler)
14+
let embedded = EmbeddedChannel(handlers: [
15+
ReverseMessageToByteHandler(PSQLBackendMessageEncoder()),
16+
PSQLChannelHandler(authentification: config.authentication, configureSSLCallback: nil)
17+
])
1618
defer { XCTAssertNoThrow(try embedded.finish()) }
1719

1820
var maybeMessage: PSQLFrontendMessage?
@@ -39,7 +41,10 @@ class PSQLChannelHandlerTests: XCTestCase {
3941
let handler = PSQLChannelHandler(authentification: config.authentication) { channel in
4042
addSSLCallbackIsHit = true
4143
}
42-
let embedded = EmbeddedChannel(handler: handler)
44+
let embedded = EmbeddedChannel(handlers: [
45+
ReverseMessageToByteHandler(PSQLBackendMessageEncoder()),
46+
handler
47+
])
4348

4449
var maybeMessage: PSQLFrontendMessage?
4550
XCTAssertNoThrow(embedded.connect(to: try .init(ipAddress: "0.0.0.0", port: 5432), promise: nil))
@@ -78,7 +83,10 @@ class PSQLChannelHandlerTests: XCTestCase {
7883
XCTFail("This callback should never be exectuded")
7984
throw PSQLError.sslUnsupported
8085
}
81-
let embedded = EmbeddedChannel(handler: handler)
86+
let embedded = EmbeddedChannel(handlers: [
87+
ReverseMessageToByteHandler(PSQLBackendMessageEncoder()),
88+
handler
89+
])
8290
let eventHandler = TestEventHandler()
8391
XCTAssertNoThrow(try embedded.pipeline.addHandler(eventHandler, position: .last).wait())
8492

@@ -107,7 +115,10 @@ class PSQLChannelHandlerTests: XCTestCase {
107115
)
108116
let state = ConnectionStateMachine(.waitingToStartAuthentication)
109117
let handler = PSQLChannelHandler(authentification: config.authentication, state: state, configureSSLCallback: nil)
110-
let embedded = EmbeddedChannel(handler: handler)
118+
let embedded = EmbeddedChannel(handlers: [
119+
ReverseMessageToByteHandler(PSQLBackendMessageEncoder()),
120+
handler
121+
])
111122

112123
embedded.triggerUserOutboundEvent(PSQLOutgoingEvent.authenticate(authContext), promise: nil)
113124
XCTAssertEqual(try embedded.readOutbound(as: PSQLFrontendMessage.self), .startup(.versionThree(parameters: authContext.toStartupParameters())))
@@ -132,7 +143,10 @@ class PSQLChannelHandlerTests: XCTestCase {
132143
)
133144
let state = ConnectionStateMachine(.waitingToStartAuthentication)
134145
let handler = PSQLChannelHandler(authentification: config.authentication, state: state, configureSSLCallback: nil)
135-
let embedded = EmbeddedChannel(handler: handler)
146+
let embedded = EmbeddedChannel(handlers: [
147+
ReverseMessageToByteHandler(PSQLBackendMessageEncoder()),
148+
handler
149+
])
136150

137151
embedded.triggerUserOutboundEvent(PSQLOutgoingEvent.authenticate(authContext), promise: nil)
138152
XCTAssertEqual(try embedded.readOutbound(as: PSQLFrontendMessage.self), .startup(.versionThree(parameters: authContext.toStartupParameters())))

0 commit comments

Comments
 (0)