Skip to content

Commit

Permalink
Use MiddlewareProtocol for middleware
Browse files Browse the repository at this point in the history
This is a stepping stone to including a separate result builder router. We need the middleware format to be the same between the default router and the result builder version.
  • Loading branch information
adam-fowler committed Dec 5, 2023
1 parent fadc621 commit 9207774
Show file tree
Hide file tree
Showing 14 changed files with 105 additions and 92 deletions.
10 changes: 5 additions & 5 deletions Sources/Hummingbird/Middleware/CORSMiddleware.swift
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
//
// This source file is part of the Hummingbird server framework project
//
// Copyright (c) 2021-2021 the Hummingbird authors
// Copyright (c) 2021-2023 the Hummingbird authors
// Licensed under Apache License v2.0
//
// See LICENSE.txt for license information
Expand All @@ -21,7 +21,7 @@ import NIOCore
/// then return an empty body with all the standard CORS headers otherwise send
/// request onto the next handler and when you receive the response add a
/// "access-control-allow-origin" header
public struct HBCORSMiddleware<Context: HBBaseRequestContext>: HBMiddleware {
public struct HBCORSMiddleware<Context: HBBaseRequestContext>: HBMiddlewareProtocol {
/// Defines what origins are allowed
public enum AllowOrigin: Sendable {
case none
Expand Down Expand Up @@ -84,10 +84,10 @@ public struct HBCORSMiddleware<Context: HBBaseRequestContext>: HBMiddleware {
}

/// apply CORS middleware
public func apply(to request: HBRequest, context: Context, next: any HBResponder<Context>) async throws -> HBResponse {
public func handle(_ request: HBRequest, context: Context, next: (HBRequest, Context) async throws -> HBResponse) async throws -> HBResponse {
// if no origin header then don't apply CORS
guard request.headers[.origin] != nil else {
return try await next.respond(to: request, context: context)
return try await next(request, context)
}

if request.method == .options {
Expand All @@ -113,7 +113,7 @@ public struct HBCORSMiddleware<Context: HBBaseRequestContext>: HBMiddleware {
return HBResponse(status: .noContent, headers: headers, body: .init())
} else {
// if not OPTIONS then run rest of middleware chain and add origin value at the end
var response = try await next.respond(to: request, context: context)
var response = try await next(request, context)
response.headers[.accessControlAllowOrigin] = self.allowOrigin.value(for: request) ?? ""
if self.allowCredentials {
response.headers[.accessControlAllowCredentials] = "true"
Expand Down
8 changes: 4 additions & 4 deletions Sources/Hummingbird/Middleware/LogRequestMiddleware.swift
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
//
// This source file is part of the Hummingbird server framework project
//
// Copyright (c) 2021-2021 the Hummingbird authors
// Copyright (c) 2021-2023 the Hummingbird authors
// Licensed under Apache License v2.0
//
// See LICENSE.txt for license information
Expand All @@ -15,7 +15,7 @@
import Logging

/// Middleware outputting to log for every call to server
public struct HBLogRequestsMiddleware<Context: HBBaseRequestContext>: HBMiddleware {
public struct HBLogRequestsMiddleware<Context: HBBaseRequestContext>: HBMiddlewareProtocol {
let logLevel: Logger.Level
let includeHeaders: Bool

Expand All @@ -24,7 +24,7 @@ public struct HBLogRequestsMiddleware<Context: HBBaseRequestContext>: HBMiddlewa
self.includeHeaders = includeHeaders
}

public func apply(to request: HBRequest, context: Context, next: any HBResponder<Context>) async throws -> HBResponse {
public func handle(_ request: HBRequest, context: Context, next: (HBRequest, Context) async throws -> HBResponse) async throws -> HBResponse {
if self.includeHeaders {
context.logger.log(
level: self.logLevel,
Expand All @@ -38,6 +38,6 @@ public struct HBLogRequestsMiddleware<Context: HBBaseRequestContext>: HBMiddlewa
metadata: ["hb_uri": .stringConvertible(request.uri), "hb_method": .string(request.method.rawValue)]
)
}
return try await next.respond(to: request, context: context)
return try await next(request, context)
}
}
8 changes: 4 additions & 4 deletions Sources/Hummingbird/Middleware/MetricsMiddleware.swift
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
//
// This source file is part of the Hummingbird server framework project
//
// Copyright (c) 2021-2021 the Hummingbird authors
// Copyright (c) 2021-2023 the Hummingbird authors
// Licensed under Apache License v2.0
//
// See LICENSE.txt for license information
Expand All @@ -19,14 +19,14 @@ import Metrics
///
/// Records the number of requests, the request duration and how many errors were thrown. Each metric has additional
/// dimensions URI and method.
public struct HBMetricsMiddleware<Context: HBBaseRequestContext>: HBMiddleware {
public struct HBMetricsMiddleware<Context: HBBaseRequestContext>: HBMiddlewareProtocol {
public init() {}

public func apply(to request: HBRequest, context: Context, next: any HBResponder<Context>) async throws -> HBResponse {
public func handle(_ request: HBRequest, context: Context, next: (HBRequest, Context) async throws -> HBResponse) async throws -> HBResponse {
let startTime = DispatchTime.now().uptimeNanoseconds

do {
let response = try await next.respond(to: request, context: context)
let response = try await next(request, context)
// need to create dimensions once request has been responded to ensure
// we have the correct endpoint path
let dimensions: [(String, String)] = [
Expand Down
41 changes: 27 additions & 14 deletions Sources/Hummingbird/Middleware/Middleware.swift
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
//
// This source file is part of the Hummingbird server framework project
//
// Copyright (c) 2021-2021 the Hummingbird authors
// Copyright (c) 2021-2023 the Hummingbird authors
// Licensed under Apache License v2.0
//
// See LICENSE.txt for license information
Expand All @@ -14,41 +14,54 @@

import NIOCore

/// Middleware Handler with generic input, context and output types
public typealias Middleware<Input, Output, Context> = @Sendable (Input, Context, _ next: (Input, Context) async throws -> Output) async throws -> Output

/// Middleware protocol with generic input, context and output types
public protocol MiddlewareProtocol<Input, Output, Context>: Sendable {
associatedtype Input
associatedtype Output
associatedtype Context

func handle(_ input: Input, context: Context, next: (Input, Context) async throws -> Output) async throws -> Output
}

/// Applied to `HBRequest` before it is dealt with by the router. Middleware passes the processed request onto the next responder
/// (either the next middleware or the router) by calling `next.apply(to: request)`. If you want to shortcut the request you
/// (either the next middleware or a route) by calling `next(request, context)`. If you want to shortcut the request you
/// can return a response immediately
///
/// Middleware is added to the application by calling `app.middleware.add(MyMiddleware()`.
/// Middleware is added to the application by calling `router.middlewares.add(MyMiddleware()`.
///
/// Middleware allows you to process a request before it reaches your request handler and then process the response
/// returned by that handler.
/// ```
/// func apply(to request: HBRequest, next: HBResponder) async throws -> HBResponse {
/// func handle(_ request: HBRequest, context: Context, next: (HBRequest, Context) async throws -> HBResponse) async throws -> HBResponse
/// let request = processRequest(request)
/// let response = try await next.respond(to: request)
/// let response = try await next(request, context)
/// return processResponse(response)
/// }
/// ```
/// Middleware also allows you to shortcut the whole process and not pass on the request to the handler
/// ```
/// func apply(to request: HBRequest, next: HBResponder) async throws -> HBResponse {
/// func handle(_ request: HBRequest, context: Context, next: (HBRequest, Context) async throws -> HBResponse) async throws -> HBResponse
/// if request.method == .OPTIONS {
/// return HBResponse(status: .noContent)
/// } else {
/// return try await next.respond(to: request)
/// return try await next(request, context)
/// }
/// }
/// ```
public protocol HBMiddleware<Context>: Sendable {
associatedtype Context
func apply(to request: HBRequest, context: Context, next: any HBResponder<Context>) async throws -> HBResponse
}

/// Middleware protocol with HBRequest as input and HBResponse as output
public protocol HBMiddlewareProtocol<Context>: MiddlewareProtocol where Input == HBRequest, Output == HBResponse {}

struct MiddlewareResponder<Context>: HBResponder {
let middleware: any HBMiddleware<Context>
let next: any HBResponder<Context>
let middleware: any HBMiddlewareProtocol<Context>
let next: @Sendable (HBRequest, Context) async throws -> HBResponse

func respond(to request: HBRequest, context: Context) async throws -> HBResponse {
return try await self.middleware.apply(to: request, context: context, next: self.next)
return try await self.middleware.handle(request, context: context) { request, context in
try await self.next(request, context)
}
}
}
10 changes: 5 additions & 5 deletions Sources/Hummingbird/Middleware/MiddlewareGroup.swift
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
//
// This source file is part of the Hummingbird server framework project
//
// Copyright (c) 2021-2021 the Hummingbird authors
// Copyright (c) 2021-2023 the Hummingbird authors
// Licensed under Apache License v2.0
//
// See LICENSE.txt for license information
Expand All @@ -14,7 +14,7 @@

/// Group of middleware that can be used to create a responder chain. Each middleware calls the next one
public final class HBMiddlewareGroup<Context> {
var middlewares: [any HBMiddleware<Context>]
var middlewares: [any HBMiddlewareProtocol<Context>]

/// Initialize `HBMiddlewareGroup`
///
Expand All @@ -24,12 +24,12 @@ public final class HBMiddlewareGroup<Context> {
self.middlewares = []
}

init(middlewares: [any HBMiddleware<Context>]) {
init(middlewares: [any HBMiddlewareProtocol<Context>]) {
self.middlewares = middlewares
}

/// Add middleware to group
public func add(_ middleware: any HBMiddleware<Context>) {
public func add(_ middleware: any HBMiddlewareProtocol<Context>) {
self.middlewares.append(middleware)
}

Expand All @@ -39,7 +39,7 @@ public final class HBMiddlewareGroup<Context> {
public func constructResponder(finalResponder: any HBResponder<Context>) -> any HBResponder<Context> {
var currentResponser = finalResponder
for i in (0..<self.middlewares.count).reversed() {
let responder = MiddlewareResponder(middleware: middlewares[i], next: currentResponser)
let responder = MiddlewareResponder(middleware: middlewares[i], next: currentResponser.respond(to:context:))
currentResponser = responder
}
return currentResponser
Expand Down
6 changes: 3 additions & 3 deletions Sources/Hummingbird/Middleware/SetCodableMiddleware.swift
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
//
//===----------------------------------------------------------------------===//

public struct HBSetCodableMiddleware<Decoder: HBRequestDecoder, Encoder: HBResponseEncoder, Context: HBBaseRequestContext>: HBMiddleware {
public struct HBSetCodableMiddleware<Decoder: HBRequestDecoder, Encoder: HBResponseEncoder, Context: HBBaseRequestContext>: HBMiddlewareProtocol {
let decoder: @Sendable () -> Decoder
let encoder: @Sendable () -> Encoder

Expand All @@ -21,10 +21,10 @@ public struct HBSetCodableMiddleware<Decoder: HBRequestDecoder, Encoder: HBRespo
self.encoder = encoder
}

public func apply(to request: HBRequest, context: Context, next: any HBResponder<Context>) async throws -> HBResponse {
public func handle(_ request: HBRequest, context: Context, next: (HBRequest, Context) async throws -> HBResponse) async throws -> HBResponse {
var context = context
context.coreContext.requestDecoder = self.decoder()
context.coreContext.responseEncoder = self.encoder()
return try await next.respond(to: request, context: context)
return try await next(request, context)
}
}
6 changes: 3 additions & 3 deletions Sources/Hummingbird/Middleware/TracingMiddleware.swift
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import Tracing
/// You may opt in to recording a specific subset of HTTP request/response header values by passing
/// a set of header names to ``init(recordingHeaders:)``.
@available(macOS 10.15, iOS 13.0, tvOS 13.0, watchOS 6.0, *)
public struct HBTracingMiddleware<Context: HBBaseRequestContext>: HBMiddleware {
public struct HBTracingMiddleware<Context: HBBaseRequestContext>: HBMiddlewareProtocol {
private let headerNamesToRecord: Set<RecordingHeader>

/// Intialize a new HBTracingMiddleware.
Expand All @@ -39,7 +39,7 @@ public struct HBTracingMiddleware<Context: HBBaseRequestContext>: HBMiddleware {
self.init(recordingHeaders: [])
}

public func apply(to request: HBRequest, context: Context, next: any HBResponder<Context>) async throws -> HBResponse {
public func handle(_ request: HBRequest, context: Context, next: (HBRequest, Context) async throws -> HBResponse) async throws -> HBResponse {
var serviceContext = ServiceContext.current ?? ServiceContext.topLevel
InstrumentationSystem.instrument.extract(request.headers, into: &serviceContext, using: HTTPHeadersExtractor())

Expand Down Expand Up @@ -83,7 +83,7 @@ public struct HBTracingMiddleware<Context: HBBaseRequestContext>: HBMiddleware {
}

do {
let response = try await next.respond(to: request, context: context)
let response = try await next(request, context)
span.updateAttributes { attributes in
attributes = self.recordHeaders(response.headers, toSpanAttributes: attributes, withPrefix: "http.response.header.")

Expand Down
6 changes: 3 additions & 3 deletions Sources/Hummingbird/Router/RouterGroup.swift
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ public struct HBRouterGroup<Context: HBBaseRequestContext>: HBRouterMethods {
}

/// Add middleware to RouterEndpoint
@discardableResult public func add(middleware: any HBMiddleware<Context>) -> HBRouterGroup<Context> {
@discardableResult public func add(middleware: any HBMiddlewareProtocol<Context>) -> HBRouterGroup<Context> {
self.middlewares.add(middleware)
return self
}
Expand All @@ -58,11 +58,11 @@ public struct HBRouterGroup<Context: HBBaseRequestContext>: HBRouterMethods {
}

/// Add path for closure returning type using async/await
@discardableResult public func on<Output: HBResponseGenerator>(
@discardableResult public func on(
_ path: String = "",
method: HTTPRequest.Method,
options: HBRouterMethodOptions = [],
use closure: @Sendable @escaping (HBRequest, Context) async throws -> Output
use closure: @Sendable @escaping (HBRequest, Context) async throws -> some HBResponseGenerator
) -> Self {
let responder = constructResponder(options: options, use: closure)
let path = self.combinePaths(self.path, path)
Expand Down
4 changes: 2 additions & 2 deletions Sources/Hummingbird/Server/Responder.swift
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
//
// This source file is part of the Hummingbird server framework project
//
// Copyright (c) 2021-2021 the Hummingbird authors
// Copyright (c) 2021-2023 the Hummingbird authors
// Licensed under Apache License v2.0
//
// See LICENSE.txt for license information
Expand All @@ -21,7 +21,7 @@ import ServiceContextModule
public protocol HBResponder<Context>: Sendable {
associatedtype Context
/// Return EventLoopFuture that will be fulfilled with response to the request supplied
func respond(to request: HBRequest, context: Context) async throws -> HBResponse
@Sendable func respond(to request: HBRequest, context: Context) async throws -> HBResponse
}

/// Responder that calls supplied closure
Expand Down
20 changes: 10 additions & 10 deletions Sources/HummingbirdFoundation/Files/FileMiddleware.swift
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
//
// This source file is part of the Hummingbird server framework project
//
// Copyright (c) 2021-2021 the Hummingbird authors
// Copyright (c) 2021-2023 the Hummingbird authors
// Licensed under Apache License v2.0
//
// See LICENSE.txt for license information
Expand All @@ -29,7 +29,7 @@ import NIOPosix
/// "if-modified-since", "if-none-match", "if-range" and 'range" headers. It will output "content-length",
/// "modified-date", "eTag", "content-type", "cache-control" and "content-range" headers where
/// they are relevant.
public struct HBFileMiddleware<Context: HBBaseRequestContext>: HBMiddleware {
public struct HBFileMiddleware<Context: HBBaseRequestContext>: HBMiddlewareProtocol {
struct IsDirectoryError: Error {}

let rootFolder: URL
Expand Down Expand Up @@ -71,9 +71,9 @@ public struct HBFileMiddleware<Context: HBBaseRequestContext>: HBMiddleware {
logger.info("FileMiddleware serving from \(workingFolder)\(rootFolder)")
}

public func apply(to request: HBRequest, context: Context, next: any HBResponder<Context>) async throws -> HBResponse {
public func handle(_ request: HBRequest, context: Context, next: (HBRequest, Context) async throws -> HBResponse) async throws -> HBResponse {
do {
return try await next.respond(to: request, context: context)
return try await next(request, context)
} catch {
guard let httpError = error as? HBHTTPError, httpError.status == .notFound else {
throw error
Expand Down Expand Up @@ -118,12 +118,12 @@ public struct HBFileMiddleware<Context: HBBaseRequestContext>: HBMiddleware {
var headers = HTTPFields()

// content-length
if let contentSize = contentSize {
if let contentSize {
headers[.contentLength] = String(describing: contentSize)
}
// modified-date
var modificationDateString: String?
if let modificationDate = modificationDate {
if let modificationDate {
modificationDateString = HBDateCache.rfc1123Formatter.string(from: modificationDate)
headers[.lastModified] = modificationDateString!
}
Expand Down Expand Up @@ -158,7 +158,7 @@ public struct HBFileMiddleware<Context: HBBaseRequestContext>: HBMiddleware {
}
// verify if-modified-since
else if let ifModifiedSince = request.headers[.ifModifiedSince],
let modificationDate = modificationDate
let modificationDate
{
if let ifModifiedSinceDate = HBDateCache.rfc1123Formatter.date(from: ifModifiedSince) {
// round modification date of file down to seconds for comparison
Expand All @@ -178,7 +178,7 @@ public struct HBFileMiddleware<Context: HBBaseRequestContext>: HBMiddleware {
if let ifRange = request.headers[.ifRange], ifRange != headers[.eTag], ifRange != headers[.lastModified] {
// do nothing and drop down to returning full file
} else {
if let contentSize = contentSize {
if let contentSize {
let lowerBound = max(range.lowerBound, 0)
let upperBound = min(range.upperBound, contentSize - 1)
headers[.contentRange] = "bytes \(lowerBound)-\(upperBound)/\(contentSize)"
Expand All @@ -197,7 +197,7 @@ public struct HBFileMiddleware<Context: HBBaseRequestContext>: HBMiddleware {
case .loadFile(let fullPath, let headers, let range):
switch request.method {
case .get:
if let range = range {
if let range {
let (body, _) = try await self.fileIO.loadFile(path: fullPath, range: range, context: context, logger: context.logger)
return HBResponse(status: .partialContent, headers: headers, body: body)
}
Expand Down Expand Up @@ -281,7 +281,7 @@ extension HBFileMiddleware {
}
}

extension Sequence where Element == UInt8 {
extension Sequence<UInt8> {
/// return a hexEncoded string buffer from an array of bytes
func hexDigest() -> String {
return self.map { String(format: "%02x", $0) }.joined(separator: "")
Expand Down
Loading

0 comments on commit 9207774

Please sign in to comment.