diff --git a/Sources/GRPC/GRPCServerPipelineConfigurator.swift b/Sources/GRPC/GRPCServerPipelineConfigurator.swift new file mode 100644 index 000000000..ba30b00a5 --- /dev/null +++ b/Sources/GRPC/GRPCServerPipelineConfigurator.swift @@ -0,0 +1,389 @@ +/* + * Copyright 2020, gRPC Authors All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +import Logging +import NIO +import NIOHTTP1 +import NIOHTTP2 +import NIOTLS + +/// Configures a server pipeline for gRPC with the appropriate handlers depending on the HTTP +/// version used for transport. +/// +/// If TLS is enabled then the handler listens for an 'TLSUserEvent.handshakeCompleted' event and +/// configures the pipeline appropriately for the protocol negotiated via ALPN. If TLS is not +/// configured then the HTTP version is determined by parsing the inbound byte stream. +final class GRPCServerPipelineConfigurator: ChannelInboundHandler, RemovableChannelHandler { + internal typealias InboundIn = ByteBuffer + internal typealias InboundOut = ByteBuffer + + /// The server configuration. + private let configuration: Server.Configuration + + /// Reads which we're holding on to before the pipeline is configured. + private var bufferedReads = CircularBuffer() + + /// The current state. + private var state: State + + private enum ALPN { + /// ALPN is expected. It may or may not be required, however. + case expected(required: Bool) + + /// ALPN was expected but not required and no protocol was negotiated in the handshake. We may + /// now fall back to parsing bytes on the connection. + case expectedButFallingBack + + /// ALPN is not expected; this is a cleartext connection. + case notExpected + } + + private enum State { + /// The pipeline isn't configured yet. + case notConfigured(alpn: ALPN) + /// We're configuring the pipeline. + case configuring + } + + init(configuration: Server.Configuration) { + if let tls = configuration.tls { + self.state = .notConfigured(alpn: .expected(required: tls.requireALPN)) + } else { + self.state = .notConfigured(alpn: .notExpected) + } + + self.configuration = configuration + } + + /// Makes a gRPC Server keepalive handler. + private func makeKeepaliveHandler() -> GRPCServerKeepaliveHandler { + return .init(configuration: self.configuration.connectionKeepalive) + } + + /// Makes a gRPC idle handler for the server.. + private func makeIdleHandler() -> GRPCIdleHandler { + return .init( + mode: .server, + logger: self.configuration.logger, + idleTimeout: self.configuration.connectionIdleTimeout + ) + } + + /// Makes an HTTP/2 handler. + private func makeHTTP2Handler() -> NIOHTTP2Handler { + return .init(mode: .server) + } + + /// Makes an HTTP/2 multiplexer suitable handling gRPC requests. + private func makeHTTP2Multiplexer(for channel: Channel) -> HTTP2StreamMultiplexer { + var logger = self.configuration.logger + + return .init( + mode: .server, + channel: channel, + targetWindowSize: self.configuration.httpTargetWindowSize + ) { stream in + stream.getOption(HTTP2StreamChannelOptions.streamID).map { streamID -> Logger in + logger[metadataKey: MetadataKey.h2StreamID] = "\(streamID)" + return logger + }.recover { _ in + logger[metadataKey: MetadataKey.h2StreamID] = "" + return logger + }.flatMap { logger in + // TODO: provide user configuration for header normalization. + let handler = self.makeHTTP2ToRawGRPCHandler(normalizeHeaders: true, logger: logger) + return stream.pipeline.addHandler(handler) + } + } + } + + /// Makes an HTTP/2 to raw gRPC server handler. + private func makeHTTP2ToRawGRPCHandler( + normalizeHeaders: Bool, + logger: Logger + ) -> HTTP2ToRawGRPCServerCodec { + return HTTP2ToRawGRPCServerCodec( + servicesByName: self.configuration.serviceProvidersByName, + encoding: self.configuration.messageEncoding, + errorDelegate: self.configuration.errorDelegate, + normalizeHeaders: normalizeHeaders, + logger: logger + ) + } + + /// The pipeline finished configuring. + private func configurationCompleted(result: Result, context: ChannelHandlerContext) { + switch result { + case .success: + context.pipeline.removeHandler(context: context, promise: nil) + case let .failure(error): + self.errorCaught(context: context, error: error) + } + } + + /// Configures the pipeline to handle gRPC requests on an HTTP/2 connection. + private func configureHTTP2(context: ChannelHandlerContext) { + // We're now configuring the pipeline. + self.state = .configuring + + // We could use 'Channel.configureHTTP2Pipeline', but then we'd have to find the right handlers + // to then insert our keepalive and idle handlers between. We can just add everything together. + var handlers: [ChannelHandler] = [] + handlers.reserveCapacity(4) + handlers.append(self.makeHTTP2Handler()) + handlers.append(self.makeKeepaliveHandler()) + handlers.append(self.makeIdleHandler()) + handlers.append(self.makeHTTP2Multiplexer(for: context.channel)) + + // Now configure the pipeline with the handlers. + context.channel.pipeline.addHandlers(handlers).whenComplete { result in + self.configurationCompleted(result: result, context: context) + } + } + + /// Configures the pipeline to handle gRPC-Web requests on an HTTP/1 connection. + private func configureHTTP1(context: ChannelHandlerContext) { + // We're now configuring the pipeline. + self.state = .configuring + + context.pipeline.configureHTTPServerPipeline(withErrorHandling: true).flatMap { + context.pipeline.addHandlers([ + WebCORSHandler(), + GRPCWebToHTTP2ServerCodec(scheme: self.configuration.tls == nil ? "http" : "https"), + // There's no need to normalize headers for HTTP/1. + self.makeHTTP2ToRawGRPCHandler(normalizeHeaders: false, logger: self.configuration.logger), + ]) + }.whenComplete { result in + self.configurationCompleted(result: result, context: context) + } + } + + /// Attempts to determine the HTTP version from the buffer and then configure the pipeline + /// appropriately. Closes the connection if the HTTP version could not be determined. + private func determineHTTPVersionAndConfigurePipeline( + buffer: ByteBuffer, + context: ChannelHandlerContext + ) { + if HTTPVersionParser.prefixedWithHTTP2ConnectionPreface(buffer) { + self.configureHTTP2(context: context) + } else if HTTPVersionParser.prefixedWithHTTP1RequestLine(buffer) { + self.configureHTTP1(context: context) + } else { + self.configuration.logger.error("Unable to determine http version, closing") + context.close(mode: .all, promise: nil) + } + } + + /// Handles a 'TLSUserEvent.handshakeCompleted' event and configures the pipeline to handle gRPC + /// requests. + private func handleHandshakeCompletedEvent( + _ event: TLSUserEvent, + alpnIsRequired: Bool, + context: ChannelHandlerContext + ) { + switch event { + case let .handshakeCompleted(negotiatedProtocol): + self.configuration.logger.debug("TLS handshake completed", metadata: [ + "alpn": "\(negotiatedProtocol ?? "nil")", + ]) + + switch negotiatedProtocol { + case let .some(negotiated): + if GRPCApplicationProtocolIdentifier.isHTTP2Like(negotiated) { + self.configureHTTP2(context: context) + } else if GRPCApplicationProtocolIdentifier.isHTTP1(negotiated) { + self.configureHTTP1(context: context) + } else { + self.configuration.logger.warning("Unsupported ALPN identifier '\(negotiated)', closing") + context.close(mode: .all, promise: nil) + } + + case .none: + if alpnIsRequired { + self.configuration.logger.warning("No ALPN protocol negotiated, closing'") + context.close(mode: .all, promise: nil) + } else { + self.configuration.logger.warning("No ALPN protocol negotiated'") + // We're now falling back to parsing bytes. + self.state = .notConfigured(alpn: .expectedButFallingBack) + self.tryParsingBufferedData(context: context) + } + } + + case .shutdownCompleted: + // We don't care about this here. + () + } + } + + /// Try to parse the buffered data to determine whether or not HTTP/2 or HTTP/1 should be used. + private func tryParsingBufferedData(context: ChannelHandlerContext) { + guard let first = self.bufferedReads.first else { + // No data buffered yet. We'll try when we read. + return + } + + let buffer = self.unwrapInboundIn(first) + self.determineHTTPVersionAndConfigurePipeline(buffer: buffer, context: context) + } + + // MARK: - Channel Handler + + internal func errorCaught(context: ChannelHandlerContext, error: Error) { + if let delegate = self.configuration.errorDelegate { + let baseError: Error + + if let errorWithContext = error as? GRPCError.WithContext { + baseError = errorWithContext.error + } else { + baseError = error + } + + delegate.observeLibraryError(baseError) + } + + context.close(mode: .all, promise: nil) + } + + internal func userInboundEventTriggered(context: ChannelHandlerContext, event: Any) { + switch self.state { + case let .notConfigured(alpn: .expected(required)): + if let event = event as? TLSUserEvent { + self.handleHandshakeCompletedEvent(event, alpnIsRequired: required, context: context) + } + + case .notConfigured(alpn: .expectedButFallingBack), + .notConfigured(alpn: .notExpected), + .configuring: + () + } + + context.fireUserInboundEventTriggered(event) + } + + internal func channelRead(context: ChannelHandlerContext, data: NIOAny) { + self.bufferedReads.append(data) + + switch self.state { + case .notConfigured(alpn: .notExpected), + .notConfigured(alpn: .expectedButFallingBack): + // If ALPN isn't expected, or we didn't negotiate via ALPN and we don't require it then we + // can try parsing the data we just buffered. + self.tryParsingBufferedData(context: context) + + case .notConfigured(alpn: .expected), + .configuring: + // We expect ALPN or we're being configured, just buffer the data, we'll forward it later. + () + } + + // Don't forward the reads: we'll do so when we have configured the pipeline. + } + + internal func removeHandler( + context: ChannelHandlerContext, + removalToken: ChannelHandlerContext.RemovalToken + ) { + // Forward any buffered reads. + while let read = self.bufferedReads.popFirst() { + context.fireChannelRead(read) + } + context.leavePipeline(removalToken: removalToken) + } +} + +// MARK: - HTTP Version Parser + +struct HTTPVersionParser { + /// HTTP/2 connection preface bytes. See RFC 7540 § 5.3. + private static let http2ClientMagic = [ + UInt8(ascii: "P"), + UInt8(ascii: "R"), + UInt8(ascii: "I"), + UInt8(ascii: " "), + UInt8(ascii: "*"), + UInt8(ascii: " "), + UInt8(ascii: "H"), + UInt8(ascii: "T"), + UInt8(ascii: "T"), + UInt8(ascii: "P"), + UInt8(ascii: "/"), + UInt8(ascii: "2"), + UInt8(ascii: "."), + UInt8(ascii: "0"), + UInt8(ascii: "\r"), + UInt8(ascii: "\n"), + UInt8(ascii: "\r"), + UInt8(ascii: "\n"), + UInt8(ascii: "S"), + UInt8(ascii: "M"), + UInt8(ascii: "\r"), + UInt8(ascii: "\n"), + UInt8(ascii: "\r"), + UInt8(ascii: "\n"), + ] + + /// Determines whether the bytes in the `ByteBuffer` are prefixed with the HTTP/2 client + /// connection preface. + static func prefixedWithHTTP2ConnectionPreface(_ buffer: ByteBuffer) -> Bool { + let view = buffer.readableBytesView + + guard view.count >= HTTPVersionParser.http2ClientMagic.count else { + // Not enough bytes. + return false + } + + let slice = view[view.startIndex ..< view.startIndex.advanced(by: self.http2ClientMagic.count)] + return slice.elementsEqual(HTTPVersionParser.http2ClientMagic) + } + + private static let http1_1 = [ + UInt8(ascii: "H"), + UInt8(ascii: "T"), + UInt8(ascii: "T"), + UInt8(ascii: "P"), + UInt8(ascii: "/"), + UInt8(ascii: "1"), + UInt8(ascii: "."), + UInt8(ascii: "1"), + ] + + /// Determines whether the bytes in the `ByteBuffer` are prefixed with an HTTP/1.1 request line. + static func prefixedWithHTTP1RequestLine(_ buffer: ByteBuffer) -> Bool { + var readableBytesView = buffer.readableBytesView + + // From RFC 2616 § 5.1: + // Request-Line = Method SP Request-URI SP HTTP-Version CRLF + + // Read off the Method and Request-URI (and spaces). + guard readableBytesView.trimPrefix(to: UInt8(ascii: " ")) != nil, + readableBytesView.trimPrefix(to: UInt8(ascii: " ")) != nil else { + return false + } + + // Read off the HTTP-Version and CR. + guard let versionView = readableBytesView.trimPrefix(to: UInt8(ascii: "\r")) else { + return false + } + + // Check that the LF followed the CR. + guard readableBytesView.first == UInt8(ascii: "\n") else { + return false + } + + // Now check the HTTP version. + return versionView.elementsEqual(HTTPVersionParser.http1_1) + } +} diff --git a/Sources/GRPC/GRPCServerRequestRoutingHandler.swift b/Sources/GRPC/GRPCServerRequestRoutingHandler.swift index 03e2e326b..23b861ca3 100644 --- a/Sources/GRPC/GRPCServerRequestRoutingHandler.swift +++ b/Sources/GRPC/GRPCServerRequestRoutingHandler.swift @@ -81,10 +81,10 @@ struct CallPath { extension Collection where Self == Self.SubSequence, Self.Element: Equatable { /// Trims out the prefix up to `separator`, and returns it. /// Sets self to the subsequence after the separator, and returns the subsequence before the separator. - /// If self is emtpy returns `nil` + /// If self is empty returns `nil` /// - parameters: /// - separator : The Element between the head which is returned and the rest which is left in self. - /// - returns: SubSequence containing everything between the beginnning and the first occurance of + /// - returns: SubSequence containing everything between the beginning and the first occurrence of /// `separator`. If `separator` is not found this will be the entire Collection. If the collection is empty /// returns `nil` mutating func trimPrefix(to separator: Element) -> SubSequence? { diff --git a/Sources/GRPC/HTTPProtocolSwitcher.swift b/Sources/GRPC/HTTPProtocolSwitcher.swift deleted file mode 100644 index 90b24c372..000000000 --- a/Sources/GRPC/HTTPProtocolSwitcher.swift +++ /dev/null @@ -1,231 +0,0 @@ -/* - * Copyright 2019, gRPC Authors All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -import Foundation -import Logging -import NIO -import NIOHTTP1 -import NIOHTTP2 - -/// Channel handler that creates different processing pipelines depending on whether -/// the incoming request is HTTP 1 or 2. -internal class HTTPProtocolSwitcher { - private let handlersInitializer: (Channel, Logger) -> EventLoopFuture - private let errorDelegate: ServerErrorDelegate? - private let logger: Logger - private let httpTargetWindowSize: Int - private let keepAlive: ServerConnectionKeepalive - private let idleTimeout: TimeAmount - private let scheme: String - - // We could receive additional data after the initial data and before configuring - // the pipeline; buffer it and fire it down the pipeline once it is configured. - private enum State { - case notConfigured - case configuring - case configured - } - - private var state: State = .notConfigured - private var bufferedData: [NIOAny] = [] - - init( - errorDelegate: ServerErrorDelegate?, - httpTargetWindowSize: Int = 65535, - keepAlive: ServerConnectionKeepalive, - idleTimeout: TimeAmount, - scheme: String, - logger: Logger, - handlersInitializer: @escaping (Channel, Logger) -> EventLoopFuture - ) { - self.errorDelegate = errorDelegate - self.httpTargetWindowSize = httpTargetWindowSize - self.keepAlive = keepAlive - self.idleTimeout = idleTimeout - self.scheme = scheme - self.logger = logger - self.handlersInitializer = handlersInitializer - } -} - -extension HTTPProtocolSwitcher: ChannelInboundHandler, RemovableChannelHandler { - typealias InboundIn = ByteBuffer - typealias InboundOut = ByteBuffer - - enum HTTPProtocolVersionError: Error { - /// Raised when it wasn't possible to detect HTTP Protocol version. - case invalidHTTPProtocolVersion - - var localizedDescription: String { - switch self { - case .invalidHTTPProtocolVersion: - return "Could not identify HTTP Protocol Version" - } - } - } - - /// HTTP Protocol Version type - enum HTTPProtocolVersion { - case http1 - case http2 - } - - func channelRead(context: ChannelHandlerContext, data: NIOAny) { - switch self.state { - case .notConfigured: - self.logger.debug("determining http protocol version") - self.state = .configuring - self.logger.debug("buffering data", metadata: ["data": "\(data)"]) - self.bufferedData.append(data) - - // Detect the HTTP protocol version for the incoming request, or error out if it - // couldn't be detected. - var inBuffer = self.unwrapInboundIn(data) - guard let initialData = inBuffer.readString(length: inBuffer.readableBytes), - let firstLine = initialData.split( - separator: "\r\n", - maxSplits: 1, - omittingEmptySubsequences: true - ).first else { - self.logger.error("unable to determine http version") - context.fireErrorCaught(HTTPProtocolVersionError.invalidHTTPProtocolVersion) - return - } - - let version: HTTPProtocolVersion - - if firstLine.contains("HTTP/2") { - version = .http2 - } else if firstLine.contains("HTTP/1") { - version = .http1 - } else { - self.logger.error("unable to determine http version") - context.fireErrorCaught(HTTPProtocolVersionError.invalidHTTPProtocolVersion) - return - } - - self.logger.debug("determined http version", metadata: ["http_version": "\(version)"]) - - // Once configured remove ourself from the pipeline, or handle the error. - let pipelineConfigured: EventLoopPromise = context.eventLoop.makePromise() - pipelineConfigured.futureResult.whenComplete { result in - switch result { - case .success: - context.pipeline.removeHandler(context: context, promise: nil) - - case let .failure(error): - self.state = .notConfigured - self.errorCaught(context: context, error: error) - } - } - - // Depending on whether it is HTTP1 or HTTP2, create different processing pipelines. - // Inbound handlers in handlersInitializer should expect HTTPServerRequestPart objects - // and outbound handlers should return HTTPServerResponsePart objects. - switch version { - case .http1: - // Upgrade connections are not handled since gRPC connections already arrive in HTTP2, - // while gRPC-Web does not support HTTP2 at all, so there are no compelling use cases - // to support this. - context.pipeline.configureHTTPServerPipeline(withErrorHandling: true).flatMap { - context.pipeline.addHandlers([ - WebCORSHandler(), - GRPCWebToHTTP2ServerCodec(scheme: self.scheme), - ]) - }.flatMap { - self.handlersInitializer(context.channel, self.logger) - }.cascade(to: pipelineConfigured) - - case .http2: - context.channel.configureHTTP2Pipeline( - mode: .server, - targetWindowSize: self.httpTargetWindowSize - ) { streamChannel in - var logger = self.logger - - // Grab the streamID from the channel. - return streamChannel.getOption(HTTP2StreamChannelOptions.streamID).map { streamID in - logger[metadataKey: MetadataKey.h2StreamID] = "\(streamID)" - return logger - }.recover { _ in - logger[metadataKey: MetadataKey.h2StreamID] = "" - return logger - }.flatMap { logger in - self.handlersInitializer(streamChannel, logger) - } - }.flatMap { multiplexer -> EventLoopFuture in - // Add a keepalive and idle handlers between the two HTTP2 handlers. - let keepaliveHandler = GRPCServerKeepaliveHandler(configuration: self.keepAlive) - let idleHandler = GRPCIdleHandler( - mode: .server, - logger: self.logger, - idleTimeout: self.idleTimeout - ) - return context.channel.pipeline.addHandlers( - [keepaliveHandler, idleHandler], - position: .before(multiplexer) - ) - } - .cascade(to: pipelineConfigured) - } - - case .configuring: - self.logger.debug("buffering data", metadata: ["data": "\(data)"]) - self.bufferedData.append(data) - - case .configured: - self.logger - .critical( - "unexpectedly received data; this handler should have been removed from the pipeline" - ) - assertionFailure( - "unexpectedly received data; this handler should have been removed from the pipeline" - ) - } - } - - func removeHandler( - context: ChannelHandlerContext, - removalToken: ChannelHandlerContext.RemovalToken - ) { - self.logger.debug("unbuffering data") - self.bufferedData.forEach { - context.fireChannelRead($0) - } - - context.leavePipeline(removalToken: removalToken) - self.state = .configured - } - - func errorCaught(context: ChannelHandlerContext, error: Error) { - switch self.state { - case .notConfigured, .configuring: - let baseError: Error - - if let errorWithContext = error as? GRPCError.WithContext { - baseError = errorWithContext.error - } else { - baseError = error - } - - self.errorDelegate?.observeLibraryError(baseError) - context.close(mode: .all, promise: nil) - - case .configured: - // If we're configured we will rely on a handler further down the pipeline. - context.fireErrorCaught(error) - } - } -} diff --git a/Sources/GRPC/Server.swift b/Sources/GRPC/Server.swift index c720c39ba..1313614d9 100644 --- a/Sources/GRPC/Server.swift +++ b/Sources/GRPC/Server.swift @@ -26,11 +26,11 @@ import NIOTransportServices /// The pipeline is configured in three stages detailed below. Note: handlers marked with /// a '*' are responsible for handling errors. /// -/// 1. Initial stage, prior to HTTP protocol detection. +/// 1. Initial stage, prior to pipeline configuration. /// -/// ┌───────────────────────────┐ -/// │ HTTPProtocolSwitcher* │ -/// └─▲───────────────────────┬─┘ +/// ┌─────────────────────────────────┐ +/// │ GRPCServerPipelineConfigurator* │ +/// └────▲───────────────────────┬────┘ /// ByteBuffer│ │ByteBuffer /// ┌─┴───────────────────────▼─┐ /// │ NIOSSLHandler │ @@ -39,19 +39,19 @@ import NIOTransportServices /// │ ▼ /// /// The `NIOSSLHandler` is optional and depends on how the framework user has configured -/// their server. The `HTTPProtocolSwitcher` detects which HTTP version is being used and +/// their server. The `GRPCServerPipelineConfigurator` detects which HTTP version is being used +/// (via ALPN if TLS is used or by parsing the first bytes on the connection otherwise) and /// configures the pipeline accordingly. /// /// 2. HTTP version detected. "HTTP Handlers" depends on the HTTP version determined by -/// `HTTPProtocolSwitcher`. All of these handlers are provided by NIO except for the -/// `WebCORSHandler` which is used for HTTP/1. +/// `GRPCServerPipelineConfigurator`. In the case of HTTP/2: /// /// ┌─────────────────────────────────┐ -/// │ GRPCServerRequestRoutingHandler │ +/// │ HTTP2StreamMultiplexer │ /// └─▲─────────────────────────────┬─┘ -/// HTTPServerRequestPart│ │HTTPServerResponsePart +/// HTTP2Frame│ │HTTP2Frame /// ┌─┴─────────────────────────────▼─┐ -/// │ HTTP Handlers │ +/// │ HTTP2Handler │ /// └─▲─────────────────────────────┬─┘ /// ByteBuffer│ │ByteBuffer /// ┌─┴─────────────────────────────▼─┐ @@ -60,28 +60,20 @@ import NIOTransportServices /// ByteBuffer│ │ByteBuffer /// │ ▼ /// -/// The `GRPCServerRequestRoutingHandler` resolves the request head and configures the rest of -/// the pipeline based on the RPC call being made. +/// The `HTTP2StreamMultiplexer` provides one `Channel` for each HTTP/2 stream (and thus each +/// RPC). /// -/// 3. The call has been resolved and is a function that this server can handle. Responses are -/// written into `BaseCallHandler` by a user-implemented `CallHandlerProvider`. +/// 3. The frames for each stream channel are routed by the `HTTP2ToRawGRPCServerCodec` handler to +/// a handler containing the user-implemented logic provided by a `CallHandlerProvider`: /// /// ┌─────────────────────────────────┐ /// │ BaseCallHandler* │ /// └─▲─────────────────────────────┬─┘ -/// GRPCServerRequestPart│ │GRPCServerResponsePart -/// ┌─┴─────────────────────────────▼─┐ -/// │ HTTP1ToGRPCServerCodec │ -/// └─▲─────────────────────────────┬─┘ -/// HTTPServerRequestPart│ │HTTPServerResponsePart +/// GRPCServerRequestPart│ │GRPCServerResponsePart /// ┌─┴─────────────────────────────▼─┐ -/// │ HTTP Handlers │ +/// │ HTTP2ToRawGRPCServerCodec │ /// └─▲─────────────────────────────┬─┘ -/// ByteBuffer│ │ByteBuffer -/// ┌─┴─────────────────────────────▼─┐ -/// │ NIOSSLHandler │ -/// └─▲─────────────────────────────┬─┘ -/// ByteBuffer│ │ByteBuffer +/// HTTP2Frame.FramePayload│ │HTTP2Frame.FramePayload /// │ ▼ /// public final class Server { @@ -103,37 +95,20 @@ public final class Server { ) // Set the handlers that are applied to the accepted Channels .childChannelInitializer { channel in - var logger = configuration.logger - logger[metadataKey: MetadataKey.connectionID] = "\(UUID().uuidString)" - logger[metadataKey: MetadataKey.remoteAddress] = channel.remoteAddress + var configuration = configuration + configuration.logger[metadataKey: MetadataKey.connectionID] = "\(UUID().uuidString)" + configuration.logger[metadataKey: MetadataKey.remoteAddress] = channel.remoteAddress .map { "\($0)" } ?? "n/a" - let protocolSwitcher = HTTPProtocolSwitcher( - errorDelegate: configuration.errorDelegate, - httpTargetWindowSize: configuration.httpTargetWindowSize, - keepAlive: configuration.connectionKeepalive, - idleTimeout: configuration.connectionIdleTimeout, - scheme: configuration.tls == nil ? "http" : "https", - logger: logger - ) { (channel, logger) -> EventLoopFuture in - let handler = HTTP2ToRawGRPCServerCodec( - servicesByName: configuration.serviceProvidersByName, - encoding: configuration.messageEncoding, - errorDelegate: configuration.errorDelegate, - normalizeHeaders: true, - logger: logger - ) - return channel.pipeline.addHandler(handler) - } - var configured: EventLoopFuture + let configurator = GRPCServerPipelineConfigurator(configuration: configuration) if let tls = configuration.tls { configured = channel.configureTLS(configuration: tls).flatMap { - channel.pipeline.addHandler(protocolSwitcher) + channel.pipeline.addHandler(configurator) } } else { - configured = channel.pipeline.addHandler(protocolSwitcher) + configured = channel.pipeline.addHandler(configurator) } // Work around the zero length write issue, if needed. @@ -276,7 +251,7 @@ extension Server { /// /// This is how gRPC consumes the service providers internally. Caching this as stored data avoids /// the need to recalculate this dictionary each time we receive an rpc. - fileprivate private(set) var serviceProvidersByName: [Substring: CallHandlerProvider] + internal var serviceProvidersByName: [Substring: CallHandlerProvider] /// Create a `Configuration` with some pre-defined defaults. /// diff --git a/Sources/GRPC/ServerBuilder.swift b/Sources/GRPC/ServerBuilder.swift index 9653f8282..5bb439a51 100644 --- a/Sources/GRPC/ServerBuilder.swift +++ b/Sources/GRPC/ServerBuilder.swift @@ -124,6 +124,18 @@ extension Server.Builder.Secure { self.tls.certificateVerification = certificateVerification return self } + + /// Sets whether the server's TLS handshake requires a protocol to be negotiated via ALPN. This + /// defaults to `true` if not otherwise set. + /// + /// If this option is set to `false` and no protocol is negotiated via ALPN then the server will + /// parse the initial bytes on the connection to determine whether HTTP/2 or HTTP/1.1 (gRPC-Web) + /// is being used and configure the connection appropriately. + @discardableResult + public func withTLS(requiringALPN: Bool) -> Self { + self.tls.requireALPN = requiringALPN + return self + } } extension Server.Builder { diff --git a/Sources/GRPC/TLSConfiguration.swift b/Sources/GRPC/TLSConfiguration.swift index 05ccd71e1..746ffede1 100644 --- a/Sources/GRPC/TLSConfiguration.swift +++ b/Sources/GRPC/TLSConfiguration.swift @@ -96,7 +96,7 @@ extension ClientConnection.Configuration { trustRoots: trustRoots, certificateChain: certificateChain, privateKey: privateKey, - applicationProtocols: GRPCApplicationProtocolIdentifier.allCases.map { $0.rawValue } + applicationProtocols: GRPCApplicationProtocolIdentifier.client ) self.hostnameOverride = hostnameOverride } @@ -118,6 +118,10 @@ extension Server.Configuration { public struct TLS { public private(set) var configuration: TLSConfiguration + /// Whether ALPN is required. Disabling this option may be useful in cases where ALPN is not + /// supported. + public var requireALPN: Bool = true + /// The certificates to offer during negotiation. If not present, no certificates will be /// offered. public var certificateChain: [NIOSSLCertificateSource] { @@ -171,11 +175,13 @@ extension Server.Configuration { /// root provided by the platform. /// - Parameter certificateVerification: Whether to verify the remote certificate. Defaults to /// `.none`. + /// - Parameter requireALPN: Whether ALPN is required or not. public init( certificateChain: [NIOSSLCertificateSource], privateKey: NIOSSLPrivateKeySource, trustRoots: NIOSSLTrustRoots = .default, - certificateVerification: CertificateVerification = .none + certificateVerification: CertificateVerification = .none, + requireALPN: Bool = true ) { self.configuration = .forServer( certificateChain: certificateChain, @@ -183,13 +189,15 @@ extension Server.Configuration { minimumTLSVersion: .tlsv12, certificateVerification: certificateVerification, trustRoots: trustRoots, - applicationProtocols: GRPCApplicationProtocolIdentifier.allCases.map { $0.rawValue } + applicationProtocols: GRPCApplicationProtocolIdentifier.server ) + self.requireALPN = requireALPN } /// Creates a TLS Configuration using the given `NIOSSL.TLSConfiguration`. - public init(configuration: TLSConfiguration) { + public init(configuration: TLSConfiguration, requireALPN: Bool = true) { self.configuration = configuration + self.requireALPN = requireALPN } } } diff --git a/Sources/GRPC/TLSVerificationHandler.swift b/Sources/GRPC/TLSVerificationHandler.swift index 77cc56aea..e7e7530ad 100644 --- a/Sources/GRPC/TLSVerificationHandler.swift +++ b/Sources/GRPC/TLSVerificationHandler.swift @@ -20,11 +20,26 @@ import NIOSSL import NIOTLS /// Application protocol identifiers for ALPN. -internal enum GRPCApplicationProtocolIdentifier: String, CaseIterable { - // This is not in the IANA ALPN protocol ID registry, but may be used by servers to indicate that - // they serve only gRPC traffic. It is part of the gRPC core implementation. - case gRPC = "grpc-exp" - case h2 +internal enum GRPCApplicationProtocolIdentifier { + static let gRPC = "grpc-exp" + static let h2 = "h2" + static let http1_1 = "http/1.1" + + static let client = [gRPC, h2] + static let server = [gRPC, h2, http1_1] + + static func isHTTP2Like(_ value: String) -> Bool { + switch value { + case self.gRPC, self.h2: + return true + default: + return false + } + } + + static func isHTTP1(_ value: String) -> Bool { + return value == self.http1_1 + } } internal class TLSVerificationHandler: ChannelInboundHandler, RemovableChannelHandler { diff --git a/Tests/GRPCTests/GRPCServerPipelineConfiguratorTests.swift b/Tests/GRPCTests/GRPCServerPipelineConfiguratorTests.swift new file mode 100644 index 000000000..9ca04f350 --- /dev/null +++ b/Tests/GRPCTests/GRPCServerPipelineConfiguratorTests.swift @@ -0,0 +1,210 @@ +/* + * Copyright 2020, gRPC Authors All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +@testable import GRPC +import NIO +import NIOHTTP2 +import NIOTLS +import XCTest + +class GRPCServerPipelineConfiguratorTests: GRPCTestCase { + private var channel: EmbeddedChannel! + + private func assertConfigurator(isPresent: Bool) { + assertThat( + try self.channel.pipeline.handler(type: GRPCServerPipelineConfigurator.self).wait(), + isPresent ? .doesNotThrow() : .throws() + ) + } + + private func assertHTTP2Handler(isPresent: Bool) { + assertThat( + try self.channel.pipeline.handler(type: NIOHTTP2Handler.self).wait(), + isPresent ? .doesNotThrow() : .throws() + ) + } + + private func assertGRPCWebToHTTP2Handler(isPresent: Bool) { + assertThat( + try self.channel.pipeline.handler(type: GRPCWebToHTTP2ServerCodec.self).wait(), + isPresent ? .doesNotThrow() : .throws() + ) + } + + private func setUp(tls: Bool, requireALPN: Bool = true) { + self.channel = EmbeddedChannel() + + var configuration = Server.Configuration( + target: .unixDomainSocket("/ignored"), + eventLoopGroup: self.channel.eventLoop, + serviceProviders: [], + logger: self.serverLogger + ) + + if tls { + configuration.tls = .init( + certificateChain: [], + privateKey: .file("not used"), + requireALPN: requireALPN + ) + } + + let handler = GRPCServerPipelineConfigurator(configuration: configuration) + assertThat(try self.channel.pipeline.addHandler(handler).wait(), .doesNotThrow()) + } + + func testHTTP2SetupViaALPN() { + self.setUp(tls: true, requireALPN: true) + let event = TLSUserEvent.handshakeCompleted(negotiatedProtocol: "h2") + self.channel.pipeline.fireUserInboundEventTriggered(event) + self.assertConfigurator(isPresent: false) + self.assertHTTP2Handler(isPresent: true) + } + + func testGRPCExpSetupViaALPN() { + self.setUp(tls: true, requireALPN: true) + let event = TLSUserEvent.handshakeCompleted(negotiatedProtocol: "grpc-exp") + self.channel.pipeline.fireUserInboundEventTriggered(event) + self.assertConfigurator(isPresent: false) + self.assertHTTP2Handler(isPresent: true) + } + + func testHTTP1Dot1SetupViaALPN() { + self.setUp(tls: true, requireALPN: true) + let event = TLSUserEvent.handshakeCompleted(negotiatedProtocol: "http/1.1") + self.channel.pipeline.fireUserInboundEventTriggered(event) + self.assertConfigurator(isPresent: false) + self.assertGRPCWebToHTTP2Handler(isPresent: true) + } + + func testUnrecognisedALPNCloses() { + self.setUp(tls: true, requireALPN: true) + let event = TLSUserEvent.handshakeCompleted(negotiatedProtocol: "unsupported") + self.channel.pipeline.fireUserInboundEventTriggered(event) + self.channel.embeddedEventLoop.run() + assertThat(try self.channel.closeFuture.wait(), .doesNotThrow()) + } + + func testNoNegotiatedProtocolCloses() { + self.setUp(tls: true, requireALPN: true) + let event = TLSUserEvent.handshakeCompleted(negotiatedProtocol: nil) + self.channel.pipeline.fireUserInboundEventTriggered(event) + self.channel.embeddedEventLoop.run() + assertThat(try self.channel.closeFuture.wait(), .doesNotThrow()) + } + + func testNoNegotiatedProtocolFallbackToBytesWhenALPNNotRequired() throws { + self.setUp(tls: true, requireALPN: false) + + // Require ALPN is disabled, so this is a no-op. + let event = TLSUserEvent.handshakeCompleted(negotiatedProtocol: nil) + self.channel.pipeline.fireUserInboundEventTriggered(event) + + // Configure via bytes. + let bytes = ByteBuffer(staticString: "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n") + assertThat(try self.channel.writeInbound(bytes), .doesNotThrow()) + self.assertConfigurator(isPresent: false) + self.assertHTTP2Handler(isPresent: true) + } + + func testHTTP2SetupViaBytes() { + self.setUp(tls: false) + let bytes = ByteBuffer(staticString: "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n") + assertThat(try self.channel.writeInbound(bytes), .doesNotThrow()) + self.assertConfigurator(isPresent: false) + self.assertHTTP2Handler(isPresent: true) + } + + func testHTTP1Dot1SetupViaBytes() { + self.setUp(tls: false) + let bytes = ByteBuffer(staticString: "GET http://www.foo.bar HTTP/1.1\r\n") + assertThat(try self.channel.writeInbound(bytes), .doesNotThrow()) + self.assertConfigurator(isPresent: false) + self.assertGRPCWebToHTTP2Handler(isPresent: true) + } + + func testReadsAreUnbufferedAfterConfiguration() throws { + self.setUp(tls: false) + + var bytes = ByteBuffer(staticString: "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n") + // A SETTINGS frame MUST follow the connection preface. Append one so that the HTTP/2 handler + // responds with its initial settings (and we validate that we forward frames once configuring). + let emptySettingsFrameBytes: [UInt8] = [ + 0x00, 0x00, 0x00, // 3-byte payload length (0 bytes) + 0x04, // 1-byte frame type (SETTINGS) + 0x00, // 1-byte flags (none) + 0x00, 0x00, 0x00, 0x00, // 4-byte stream identifier + ] + bytes.writeBytes(emptySettingsFrameBytes) + + // Do the setup. + assertThat(try self.channel.writeInbound(bytes), .doesNotThrow()) + self.assertConfigurator(isPresent: false) + self.assertHTTP2Handler(isPresent: true) + + // We expect the server to respond with a SETTINGS frame now. + let ioData = try channel.readOutbound(as: IOData.self) + switch ioData { + case var .some(.byteBuffer(buffer)): + if let frame = buffer.readBytes(length: 9) { + // Just check it's a SETTINGS frame. + assertThat(frame[3], .is(0x04)) + } else { + XCTFail("Expected more bytes") + } + + default: + XCTFail("Expected ByteBuffer but got \(String(describing: ioData))") + } + } + + func testALPNIsPreferredOverBytes() throws { + self.setUp(tls: true, requireALPN: true) + + // Write in an HTTP/1 request line. This should just be buffered. + let bytes = ByteBuffer(staticString: "GET http://www.foo.bar HTTP/1.1\r\n") + assertThat(try self.channel.writeInbound(bytes), .doesNotThrow()) + + self.assertConfigurator(isPresent: true) + self.assertHTTP2Handler(isPresent: false) + self.assertGRPCWebToHTTP2Handler(isPresent: false) + + // Now configure HTTP/2 with ALPN. This should be used to configure the pipeline. + let event = TLSUserEvent.handshakeCompleted(negotiatedProtocol: "h2") + self.channel.pipeline.fireUserInboundEventTriggered(event) + + self.assertConfigurator(isPresent: false) + self.assertGRPCWebToHTTP2Handler(isPresent: false) + self.assertHTTP2Handler(isPresent: true) + } + + func testALPNFallbackToAlreadyBufferedBytes() throws { + self.setUp(tls: true, requireALPN: false) + + // Write in an HTTP/2 connection preface. This should just be buffered. + let bytes = ByteBuffer(staticString: "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n") + assertThat(try self.channel.writeInbound(bytes), .doesNotThrow()) + + self.assertConfigurator(isPresent: true) + self.assertHTTP2Handler(isPresent: false) + + // Complete the handshake with no protocol negotiated, we should fallback to the buffered bytes. + let event = TLSUserEvent.handshakeCompleted(negotiatedProtocol: nil) + self.channel.pipeline.fireUserInboundEventTriggered(event) + + self.assertConfigurator(isPresent: false) + self.assertHTTP2Handler(isPresent: true) + } +} diff --git a/Tests/GRPCTests/HTTPVersionParserTests.swift b/Tests/GRPCTests/HTTPVersionParserTests.swift new file mode 100644 index 000000000..455f85af1 --- /dev/null +++ b/Tests/GRPCTests/HTTPVersionParserTests.swift @@ -0,0 +1,79 @@ +/* + * Copyright 2020, gRPC Authors All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +@testable import GRPC +import NIO +import XCTest + +class HTTPVersionParserTests: GRPCTestCase { + private let preface = "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n" + + func testHTTP2ExactlyTheRightBytes() { + let buffer = ByteBuffer(string: self.preface) + XCTAssertTrue(HTTPVersionParser.prefixedWithHTTP2ConnectionPreface(buffer)) + } + + func testHTTP2TheRightBytesAndMore() { + var buffer = ByteBuffer(string: self.preface) + buffer.writeRepeatingByte(42, count: 1024) + XCTAssertTrue(HTTPVersionParser.prefixedWithHTTP2ConnectionPreface(buffer)) + } + + func testHTTP2NoBytes() { + let empty = ByteBuffer() + XCTAssertFalse(HTTPVersionParser.prefixedWithHTTP2ConnectionPreface(empty)) + } + + func testHTTP2NotEnoughBytes() { + var buffer = ByteBuffer(string: self.preface) + buffer.moveWriterIndex(to: buffer.writerIndex - 1) + XCTAssertFalse(HTTPVersionParser.prefixedWithHTTP2ConnectionPreface(buffer)) + } + + func testHTTP2EnoughOfTheWrongBytes() { + let buffer = ByteBuffer(string: String(self.preface.reversed())) + XCTAssertFalse(HTTPVersionParser.prefixedWithHTTP2ConnectionPreface(buffer)) + } + + func testHTTP1RequestLine() { + let buffer = ByteBuffer(staticString: "GET https://grpc.io/index.html HTTP/1.1\r\n") + XCTAssertTrue(HTTPVersionParser.prefixedWithHTTP1RequestLine(buffer)) + } + + func testHTTP1RequestLineAndMore() { + let buffer = ByteBuffer(staticString: "GET https://grpc.io/index.html HTTP/1.1\r\nMore") + XCTAssertTrue(HTTPVersionParser.prefixedWithHTTP1RequestLine(buffer)) + } + + func testHTTP1RequestLineWithoutCRLF() { + let buffer = ByteBuffer(staticString: "GET https://grpc.io/index.html HTTP/1.1") + XCTAssertFalse(HTTPVersionParser.prefixedWithHTTP1RequestLine(buffer)) + } + + func testHTTP1NoBytes() { + let empty = ByteBuffer() + XCTAssertFalse(HTTPVersionParser.prefixedWithHTTP1RequestLine(empty)) + } + + func testHTTP1IncompleteRequestLine() { + let buffer = ByteBuffer(staticString: "GET https://grpc.io/index.html") + XCTAssertFalse(HTTPVersionParser.prefixedWithHTTP1RequestLine(buffer)) + } + + func testHTTP1MalformedVersion() { + let buffer = ByteBuffer(staticString: "GET https://grpc.io/index.html ptth/1.1\r\n") + XCTAssertFalse(HTTPVersionParser.prefixedWithHTTP1RequestLine(buffer)) + } +} diff --git a/Tests/GRPCTests/XCTestManifests.swift b/Tests/GRPCTests/XCTestManifests.swift index 2df87599a..ef3c5c8f3 100644 --- a/Tests/GRPCTests/XCTestManifests.swift +++ b/Tests/GRPCTests/XCTestManifests.swift @@ -648,6 +648,25 @@ extension GRPCSecureInteroperabilityTests { ] } +extension GRPCServerPipelineConfiguratorTests { + // DO NOT MODIFY: This is autogenerated, use: + // `swift test --generate-linuxmain` + // to regenerate. + static let __allTests__GRPCServerPipelineConfiguratorTests = [ + ("testALPNFallbackToAlreadyBufferedBytes", testALPNFallbackToAlreadyBufferedBytes), + ("testALPNIsPreferredOverBytes", testALPNIsPreferredOverBytes), + ("testGRPCExpSetupViaALPN", testGRPCExpSetupViaALPN), + ("testHTTP1Dot1SetupViaALPN", testHTTP1Dot1SetupViaALPN), + ("testHTTP1Dot1SetupViaBytes", testHTTP1Dot1SetupViaBytes), + ("testHTTP2SetupViaALPN", testHTTP2SetupViaALPN), + ("testHTTP2SetupViaBytes", testHTTP2SetupViaBytes), + ("testNoNegotiatedProtocolCloses", testNoNegotiatedProtocolCloses), + ("testNoNegotiatedProtocolFallbackToBytesWhenALPNNotRequired", testNoNegotiatedProtocolFallbackToBytesWhenALPNNotRequired), + ("testReadsAreUnbufferedAfterConfiguration", testReadsAreUnbufferedAfterConfiguration), + ("testUnrecognisedALPNCloses", testUnrecognisedALPNCloses), + ] +} + extension GRPCStatusCodeTests { // DO NOT MODIFY: This is autogenerated, use: // `swift test --generate-linuxmain` @@ -764,6 +783,25 @@ extension HTTP2ToRawGRPCStateMachineTests { ] } +extension HTTPVersionParserTests { + // DO NOT MODIFY: This is autogenerated, use: + // `swift test --generate-linuxmain` + // to regenerate. + static let __allTests__HTTPVersionParserTests = [ + ("testHTTP1IncompleteRequestLine", testHTTP1IncompleteRequestLine), + ("testHTTP1MalformedVersion", testHTTP1MalformedVersion), + ("testHTTP1NoBytes", testHTTP1NoBytes), + ("testHTTP1RequestLine", testHTTP1RequestLine), + ("testHTTP1RequestLineAndMore", testHTTP1RequestLineAndMore), + ("testHTTP1RequestLineWithoutCRLF", testHTTP1RequestLineWithoutCRLF), + ("testHTTP2EnoughOfTheWrongBytes", testHTTP2EnoughOfTheWrongBytes), + ("testHTTP2ExactlyTheRightBytes", testHTTP2ExactlyTheRightBytes), + ("testHTTP2NoBytes", testHTTP2NoBytes), + ("testHTTP2NotEnoughBytes", testHTTP2NotEnoughBytes), + ("testHTTP2TheRightBytesAndMore", testHTTP2TheRightBytesAndMore), + ] +} + extension HeaderNormalizationTests { // DO NOT MODIFY: This is autogenerated, use: // `swift test --generate-linuxmain` @@ -1143,6 +1181,7 @@ public func __allTests() -> [XCTestCaseEntry] { testCase(GRPCInsecureInteroperabilityTests.__allTests__GRPCInsecureInteroperabilityTests), testCase(GRPCPingHandlerTests.__allTests__GRPCPingHandlerTests), testCase(GRPCSecureInteroperabilityTests.__allTests__GRPCSecureInteroperabilityTests), + testCase(GRPCServerPipelineConfiguratorTests.__allTests__GRPCServerPipelineConfiguratorTests), testCase(GRPCStatusCodeTests.__allTests__GRPCStatusCodeTests), testCase(GRPCStatusMessageMarshallerTests.__allTests__GRPCStatusMessageMarshallerTests), testCase(GRPCStatusTests.__allTests__GRPCStatusTests), @@ -1150,6 +1189,7 @@ public func __allTests() -> [XCTestCaseEntry] { testCase(GRPCTypeSizeTests.__allTests__GRPCTypeSizeTests), testCase(GRPCWebToHTTP2ServerCodecTests.__allTests__GRPCWebToHTTP2ServerCodecTests), testCase(HTTP2ToRawGRPCStateMachineTests.__allTests__HTTP2ToRawGRPCStateMachineTests), + testCase(HTTPVersionParserTests.__allTests__HTTPVersionParserTests), testCase(HeaderNormalizationTests.__allTests__HeaderNormalizationTests), testCase(ImmediatelyFailingProviderTests.__allTests__ImmediatelyFailingProviderTests), testCase(InterceptorsTests.__allTests__InterceptorsTests),