Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Throwing RequestContext init from parent RequestContext #596

Merged
merged 9 commits into from
Oct 25, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ internal enum URLEncodedForm {

/// ASCII characters that will not be percent encoded in URL encoded form data
static let unreservedCharacters = CharacterSet(
charactersIn: "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_.~")
charactersIn: "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_.~"
)

/// ISO8601 data formatter used throughout URL encoded form code
static var iso8601Formatter: ISO8601DateFormatter {
Expand Down
28 changes: 27 additions & 1 deletion Sources/Hummingbird/Router/RouterMethods.swift
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ extension RouterMethods {
/// - path: path prefix to add to routes inside this group
/// - convertContext: Function converting context
@discardableResult public func group<TargetContext>(
_ path: RouterPath,
_ path: RouterPath = "",
context: TargetContext.Type
) -> RouterGroup<TargetContext> where TargetContext.Source == Context {
return RouterGroup(
Expand All @@ -83,6 +83,32 @@ extension RouterMethods {
)
}

/// Return a group inside the current group that transforms the ``RequestContext``
///
/// For the transform to work the `Source` of the transformed `RequestContext` needs
/// to be the original `RequestContext` eg
/// ```
/// struct TransformedRequestContext: ChildRequestContext {
/// typealias ParentContext = BasicRequestContext
/// var coreContext: CoreRequestContextStorage
/// init(context: ParentContext) throws {
/// self.coreContext = .init(source: source)
/// }
/// }
/// ```
/// - Parameters
/// - path: path prefix to add to routes inside this group
/// - convertContext: Function converting context
@discardableResult public func group<TargetContext: ChildRequestContext>(
_ path: RouterPath = "",
context: TargetContext.Type
) -> RouterGroup<TargetContext> where TargetContext.ParentContext == Context {
return RouterGroup(
path: path,
parent: ThrowingTransformingRouterGroup(parent: self)
)
}

/// Add middleware stack to router
///
/// Add multiple middleware to the router using the middleware stack result builder
Expand Down
50 changes: 47 additions & 3 deletions Sources/Hummingbird/Router/TransformingRouterGroup.swift
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,10 @@ import HummingbirdCore
import NIOCore

/// Internally used to transform RequestContext
struct TransformingRouterGroup<InputContext: RequestContext, Context: RequestContext>: RouterMethods where Context.Source == InputContext {
struct TransformingRouterGroup<Context: RequestContext, Parent: RouterMethods<Context.Source>>: RouterMethods {
typealias TransformContext = Context
let parent: any RouterMethods<InputContext>
typealias InputContext = Context.Source
let parent: Parent

struct ContextTransformingResponder: HTTPResponder {
typealias Context = InputContext
Expand All @@ -31,7 +32,7 @@ struct TransformingRouterGroup<InputContext: RequestContext, Context: RequestCon
}
}

init(parent: any RouterMethods<InputContext>) {
init(parent: Parent) {
self.parent = parent
}

Expand All @@ -57,3 +58,46 @@ struct TransformingRouterGroup<InputContext: RequestContext, Context: RequestCon
return self
}
}

/// Internally used to transform RequestContext
struct ThrowingTransformingRouterGroup<Context: ChildRequestContext, Parent: RouterMethods<Context.ParentContext>>: RouterMethods {
typealias TransformContext = Context
typealias InputContext = Context.ParentContext
let parent: Parent

struct ContextTransformingResponder: HTTPResponder {
typealias Context = InputContext
let responder: any HTTPResponder<TransformContext>

func respond(to request: Request, context: InputContext) async throws -> Response {
let newContext = try TransformContext(context: context)
return try await self.responder.respond(to: request, context: newContext)
}
}

init(parent: Parent) {
self.parent = parent
}

/// Add middleware (Stub function as it isn't used)
@discardableResult func add(middleware: any MiddlewareProtocol<Request, Response, Context>) -> Self {
preconditionFailure("Cannot add middleware to ThrowingTransformingRouterGroup")
}

/// Add responder to call when path and method are matched
///
/// - Parameters:
/// - path: Path to match
/// - method: Request method to match
/// - responder: Responder to call if match is made
/// - Returns: self
@discardableResult func on<Responder: HTTPResponder>(
_ path: RouterPath,
method: HTTPRequest.Method,
responder: Responder
) -> Self where Responder.Context == Context {
let transformResponder = ContextTransformingResponder(responder: responder)
self.parent.on(path, method: method, responder: transformResponder)
return self
}
}
36 changes: 36 additions & 0 deletions Sources/Hummingbird/Server/ChildRequestContext.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
//===----------------------------------------------------------------------===//
//
// This source file is part of the Hummingbird server framework project
//
// Copyright (c) 2021-2024 the Hummingbird authors
// Licensed under Apache License v2.0
//
// See LICENSE.txt for license information
// See hummingbird/CONTRIBUTORS.txt for the list of Hummingbird authors
//
// SPDX-License-Identifier: Apache-2.0
//
//===----------------------------------------------------------------------===//
import Logging

/// A RequestContext that can be initialized from another RequestContext where initialization
/// can throw errors which will be passed back up the middleware stack
public protocol ChildRequestContext<ParentContext>: RequestContext where Source == Never {
associatedtype ParentContext: RequestContext
/// Initialise RequestContext from source
init(context: ParentContext) throws
}

extension ChildRequestContext {
/// ChildRequestContext can never to created from it Source ``Never`` so add preconditionFailure
public init(source: Source) {
preconditionFailure("Cannot reach this.")
}
}

/// Extend Never to conform to ``RequestContextSource``
extension Never: RequestContextSource {
public var logger: Logger {
preconditionFailure("Cannot reach this.")
}
}
36 changes: 36 additions & 0 deletions Sources/HummingbirdRouter/ContextTransform.swift
Original file line number Diff line number Diff line change
Expand Up @@ -61,3 +61,39 @@ public struct ContextTransform<Context: RouterRequestContext, HandlerContext: Ro
}
}
}

/// Router middleware that transforms the ``Hummingbird/RequestContext`` and uses it with the contained
/// Middleware chain. ``Used by RouteGroup/init(_:context:builder:)``
public struct ThrowingContextTransform<Context: RouterRequestContext, HandlerContext: RouterRequestContext & ChildRequestContext, Handler: MiddlewareProtocol>: RouterMiddleware where Handler.Input == Request, Handler.Output == Response, Handler.Context == HandlerContext, HandlerContext.ParentContext == Context {
public typealias Input = Request
public typealias Output = Response

/// Group handler
@usableFromInline
let handler: Handler

/// Create RouteGroup from result builder
/// - Parameters:
/// - context: RequestContext to convert to
/// - builder: RouteGroup builder
public init(
to context: HandlerContext.Type,
@MiddlewareFixedTypeBuilder<Request, Response, HandlerContext> builder: () -> Handler
) {
self.handler = builder()
}

/// Process HTTP request and return an HTTP response
/// - Parameters:
/// - input: Request
/// - context: Request context
/// - next: Next middleware to run, if no route handler is found
/// - Returns: Response
@inlinable
public func handle(_ input: Input, context: Context, next: (Input, Context) async throws -> Output) async throws -> Output {
let handlerContext = try Handler.Context(context: context)
return try await self.handler.handle(input, context: handlerContext) { input, _ in
try await next(input, context)
}
}
}
36 changes: 36 additions & 0 deletions Sources/HummingbirdRouter/RouteGroup.swift
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,42 @@ public struct RouteGroup<Context: RouterRequestContext, Handler: MiddlewareProto
self.routerPath = routerPath
}

/// Create RouteGroup that transforms the RequestContext from result builder
/// - Parameters:
/// - routerPath: Path local to group route this group is defined in
/// - context: RequestContext to convert to
/// - builder: RouteGroup builder
///
/// The RequestContext that the group uses must conform to ``Hummingbird/ChildRequestContext`` eg
/// ```
/// struct TransformedRequestContext: ChildRequestContext {
/// typealias ParentContext = BasicRequestContext
/// var coreContext: CoreRequestContextStorage
/// init(context: ParentContext) throws {
/// self.coreContext = .init(source: context)
/// }
/// }
/// ```
public init<ChildHandler: MiddlewareProtocol, ChildContext: ChildRequestContext & RouterRequestContext>(
_ routerPath: RouterPath,
context: ChildContext.Type,
@MiddlewareFixedTypeBuilder<Request, Response, ChildContext> builder: () -> ChildHandler
) where ChildContext == ChildContext, Handler == ThrowingContextTransform<Context, ChildContext, ChildHandler> {
var routerPath = routerPath
// Get builder state from service context
var routerBuildState = RouterBuilderState.current ?? .init(options: [])
if routerBuildState.options.contains(.caseInsensitive) {
routerPath = routerPath.lowercased()
}
let parentGroupPath = routerBuildState.routeGroupPath
self.fullPath = parentGroupPath.appendingPath(routerPath)
routerBuildState.routeGroupPath = self.fullPath
self.handler = RouterBuilderState.$current.withValue(routerBuildState) {
ThrowingContextTransform(to: ChildHandler.Context.self, builder: builder)
}
self.routerPath = routerPath
}

/// Process HTTP request and return an HTTP response
/// - Parameters:
/// - input: Request
Expand Down
60 changes: 59 additions & 1 deletion Tests/HummingbirdRouterTests/RouterTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@ final class RouterTests: XCTestCase {
}
}

/// Test middleware in parent group is applied to routes in child group
/// Test context transform
func testGroupTransformingGroupMiddleware() async throws {
struct TestRouterContext2: RequestContext, RouterRequestContext {
/// router context
Expand Down Expand Up @@ -317,6 +317,64 @@ final class RouterTests: XCTestCase {
}
}

/// Test throwing context transform
func testThrowingTransformingGroupMiddleware() async throws {
struct TestRouterContext: RequestContext, RouterRequestContext {
/// router context
var routerContext: RouterBuilderContext
/// parameters
var coreContext: CoreRequestContextStorage
/// additional data
var string: String?

init(source: Source) {
self.coreContext = .init(source: source)
self.routerContext = .init()
self.string = nil
}
}
struct TestRouterContext2: RequestContext, RouterRequestContext, ChildRequestContext {
/// router context
var routerContext: RouterBuilderContext
/// parameters
var coreContext: CoreRequestContextStorage
/// additional data
var string: String

init(context: TestRouterContext) throws {
self.coreContext = .init(source: context)
self.routerContext = context.routerContext
guard let string = context.string else { throw HTTPError(.badRequest) }
self.string = string
}
}
struct TestTransformMiddleware: RouterMiddleware {
typealias Context = TestRouterContext
func handle(_ request: Request, context: Context, next: (Request, Context) async throws -> Response) async throws -> Response {
var context = context
context.string = request.headers[.middleware2]
return try await next(request, context)
}
}
let router = RouterBuilder(context: TestRouterContext.self) {
TestTransformMiddleware()
RouteGroup("/group", context: TestRouterContext2.self) {
Get { _, context in
return Response(status: .ok, headers: [.middleware2: context.string])
}
}
}
let app = Application(responder: router)
try await app.test(.router) { client in
try await client.execute(uri: "/group", method: .get, headers: [.middleware2: "Transforming"]) { response in
XCTAssertEqual(response.headers[.middleware2], "Transforming")
}
try await client.execute(uri: "/group", method: .get) { response in
XCTAssertEqual(response.status, .badRequest)
}
}
}

/// Test adding middleware to group doesn't affect middleware in parent groups
func testRouteBuilder() async throws {
struct TestGroupMiddleware: RouterMiddleware {
Expand Down
52 changes: 52 additions & 0 deletions Tests/HummingbirdTests/RouterTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,58 @@ final class RouterTests: XCTestCase {
}
}

/// Test middleware in parent group is applied to routes in child group
func testThrowingTransformingGroupMiddleware() async throws {
struct TestRouterContext: RequestContext {
init(source: Source) {
self.coreContext = .init(source: source)
self.string = nil
}

/// parameters
var coreContext: CoreRequestContextStorage
/// additional data
var string: String?
}
struct TestRouterContext2: ChildRequestContext {
typealias ParentContext = TestRouterContext
init(context: ParentContext) throws {
self.coreContext = .init(source: context)
guard let string = context.string else { throw HTTPError(.badRequest) }
self.string = string
}

/// parameters
var coreContext: CoreRequestContextStorage
/// additional data
var string: String
}
struct TestTransformMiddleware: RouterMiddleware {
typealias Context = TestRouterContext
func handle(_ request: Request, context: Context, next: (Request, Context) async throws -> Response) async throws -> Response {
var context = context
context.string = request.headers[.test]
return try await next(request, context)
}
}
let router = Router(context: TestRouterContext.self)
router
.add(middleware: TestTransformMiddleware())
.group("/group", context: TestRouterContext2.self)
.get { _, context in
return EditedResponse(headers: [.test: context.string], response: "hello")
}
let app = Application(responder: router.buildResponder())
try await app.test(.router) { client in
try await client.execute(uri: "/group", method: .get, headers: [.test: "test"]) { response in
XCTAssertEqual(response.headers[.test], "test")
}
try await client.execute(uri: "/group", method: .get) { response in
XCTAssertEqual(response.status, .badRequest)
}
}
}

func testParameters() async throws {
let router = Router()
router
Expand Down
Loading