diff --git a/Sources/protoc-gen-grpc-swift/Generator-Client+AsyncAwait.swift b/Sources/protoc-gen-grpc-swift/Generator-Client+AsyncAwait.swift index 9788c5def..eca9a8ea1 100644 --- a/Sources/protoc-gen-grpc-swift/Generator-Client+AsyncAwait.swift +++ b/Sources/protoc-gen-grpc-swift/Generator-Client+AsyncAwait.swift @@ -95,49 +95,79 @@ extension Generator { self.method = method let rpcType = streamingType(self.method) - let callType = Types.call(for: rpcType) - let callTypeWithoutPrefix = Types.call(for: rpcType, withGRPCPrefix: false) + printRpcFunctionImplementation(rpcType: rpcType) + printRpcFunctionWrapper(rpcType: rpcType) + } + } + } - switch rpcType { - case .unary, .serverStreaming: - self.printFunction( - name: self.methodMakeFunctionCallName, - arguments: [ - "_ request: \(self.methodInputName)", - "callOptions: \(Types.clientCallOptions)? = nil", - ], - returnType: "\(callType)<\(self.methodInputName), \(self.methodOutputName)>", - access: self.access - ) { - self.withIndentation("return self.make\(callTypeWithoutPrefix)", braces: .round) { - self.println("path: \(self.methodPathUsingClientMetadata),") - self.println("request: request,") - self.println("callOptions: callOptions ?? self.defaultCallOptions,") - self.println( - "interceptors: self.interceptors?.\(self.methodInterceptorFactoryName)() ?? []" - ) - } - } + private func printRpcFunctionImplementation(rpcType: StreamingType) { + let argumentsBuilder: (() -> Void)? + switch rpcType { + case .unary, .serverStreaming: + argumentsBuilder = { + self.println("request: request,") + } + default: + argumentsBuilder = nil + } + let callTypeWithoutPrefix = Types.call(for: rpcType, withGRPCPrefix: false) + printRpcFunction(rpcType: rpcType, name: self.methodMakeFunctionCallName) { + self.withIndentation("return self.make\(callTypeWithoutPrefix)", braces: .round) { + self.println("path: \(self.methodPathUsingClientMetadata),") + argumentsBuilder?() + self.println("callOptions: callOptions ?? self.defaultCallOptions,") + self.println( + "interceptors: self.interceptors?.\(self.methodInterceptorFactoryName)() ?? []" + ) + } + } + } - case .clientStreaming, .bidirectionalStreaming: - self.printFunction( - name: self.methodMakeFunctionCallName, - arguments: ["callOptions: \(Types.clientCallOptions)? = nil"], - returnType: "\(callType)<\(self.methodInputName), \(self.methodOutputName)>", - access: self.access - ) { - self.withIndentation("return self.make\(callTypeWithoutPrefix)", braces: .round) { - self.println("path: \(self.methodPathUsingClientMetadata),") - self.println("callOptions: callOptions ?? self.defaultCallOptions,") - self.println( - "interceptors: self.interceptors?.\(self.methodInterceptorFactoryName)() ?? []" - ) - } - } - } + private func printRpcFunctionWrapper(rpcType: StreamingType) { + let functionName = methodMakeFunctionCallName + let functionWrapperName = methodMakeFunctionCallWrapperName + guard functionName != functionWrapperName else { return } + self.println() + + let argumentsBuilder: (() -> Void)? + switch rpcType { + case .unary, .serverStreaming: + argumentsBuilder = { + self.println("request,") + } + default: + argumentsBuilder = nil + } + printRpcFunction(rpcType: rpcType, name: functionWrapperName) { + self.withIndentation("return self.\(functionName)", braces: .round) { + argumentsBuilder?() + self.println("callOptions: callOptions") } } } + + private func printRpcFunction(rpcType: StreamingType, name: String, bodyBuilder: (() -> Void)?) { + let callType = Types.call(for: rpcType) + self.printFunction( + name: name, + arguments: rpcFunctionArguments(rpcType: rpcType), + returnType: "\(callType)<\(self.methodInputName), \(self.methodOutputName)>", + access: self.access, + bodyBuilder: bodyBuilder + ) + } + + private func rpcFunctionArguments(rpcType: StreamingType) -> [String] { + var arguments = ["callOptions: \(Types.clientCallOptions)? = nil"] + switch rpcType { + case .unary, .serverStreaming: + arguments.insert("_ request: \(self.methodInputName)", at: .zero) + default: + break + } + return arguments + } } // MARK: - Client protocol extension: "Simple, but safe" call wrappers. diff --git a/Sources/protoc-gen-grpc-swift/Generator-Names.swift b/Sources/protoc-gen-grpc-swift/Generator-Names.swift index 115693cd4..be6cf2ed5 100644 --- a/Sources/protoc-gen-grpc-swift/Generator-Names.swift +++ b/Sources/protoc-gen-grpc-swift/Generator-Names.swift @@ -129,7 +129,6 @@ extension Generator { internal var methodMakeFunctionCallName: String { let name: String - if self.options.keepMethodCasing { name = self.method.name } else { @@ -140,6 +139,18 @@ extension Generator { return self.sanitize(fieldName: fnName) } + internal var methodMakeFunctionCallWrapperName: String { + return "make\(methodComposableName)Call" + } + + internal var methodComposableName: String { + var name = method.name + if !options.keepMethodCasing { + name = name.prefix(1).uppercased() + name.dropFirst() + } + return name + } + internal func sanitize(fieldName string: String) -> String { if quotableFieldNames.contains(string) { return "`\(string)`"