From 3836ef8d5ec7552441c6db3dbcce9e8d9f819f99 Mon Sep 17 00:00:00 2001 From: Simon Mitchell Date: Tue, 28 Nov 2023 16:35:04 +0000 Subject: [PATCH 1/3] Refactors OpenAISwift slightly to allow not only for custom endpoint providers but also custom request handlers/creation. Moves Config to URLSessionRequestHandler. --- .../OpenAISwift/OpenAIRequestHandler.swift | 35 +++++ Sources/OpenAISwift/OpenAISwift.swift | 101 +++--------- .../URLSessionRequestHandler.swift | 144 ++++++++++++++++++ .../OpenAISwift/ServerSentEventsHandler.swift | 65 -------- 4 files changed, 202 insertions(+), 143 deletions(-) create mode 100644 Sources/OpenAISwift/OpenAIRequestHandler.swift create mode 100644 Sources/OpenAISwift/Request Handlers/URLSessionRequestHandler.swift delete mode 100644 Sources/OpenAISwift/ServerSentEventsHandler.swift diff --git a/Sources/OpenAISwift/OpenAIRequestHandler.swift b/Sources/OpenAISwift/OpenAIRequestHandler.swift new file mode 100644 index 0000000..2e760ac --- /dev/null +++ b/Sources/OpenAISwift/OpenAIRequestHandler.swift @@ -0,0 +1,35 @@ +// +// OpenAIRequestHandler.swift +// LumenateApp +// +// Created by Simon Mitchell on 28/11/2023. +// + +import Foundation + +public protocol OpenAIRequestHandler { + + /// Function which performs the request as required from the user. + /// - Note: However the request is made, it must do a few things + /// 1. Call `completionHandler` with any errors or response data + /// 2. Data returned must be decodable to the model types defined in this library + /// - Parameters: + /// - endpoint: The endpoint to make a request to + /// - body: The body of the request to make + /// - completionHandler: A closure to be called once the request has completed + func makeRequest(_ endpoint: OpenAIEndpointProvider.API, body: BodyType, completionHandler: @escaping (Result) -> Void) + + /// Function which streams the request as required by the user. + /// - Note: Only "chat" api is streamable for now, so this always has return type of `StreamMessageResult` + /// - Parameters: + /// - endpoint: The endpoint to stream the request from. Note: currently this is only for "chat" endpoint + /// - body: The body of the request to make + /// - eventReceived: Called Multiple times, returns an OpenAI Data Model + /// - completion: Triggers when sever complete sending the message + func streamRequest( + _ endpoint: OpenAIEndpointProvider.API, + body: BodyType, + eventReceived: ((Result, OpenAIError>) -> Void)?, + completion: (() -> Void)? + ) +} diff --git a/Sources/OpenAISwift/OpenAISwift.swift b/Sources/OpenAISwift/OpenAISwift.swift index 67090ac..68d788a 100644 --- a/Sources/OpenAISwift/OpenAISwift.swift +++ b/Sources/OpenAISwift/OpenAISwift.swift @@ -4,6 +4,10 @@ import FoundationNetworking import FoundationXML #endif +// Typealias for backward compatibility so allowing custom request makers +// doesn't introduce breaking changes to the public API +public typealias Config = URLSessionRequestHandler + public enum OpenAIError: Error { case genericError(error: Error) case decodingError(error: Error) @@ -11,37 +15,18 @@ public enum OpenAIError: Error { } public class OpenAISwift { - fileprivate let config: Config - fileprivate let handler = ServerSentEventsHandler() - - /// Configuration object for the client - public struct Config { - - /// Initialiser - /// - Parameter session: the session to use for network requests. - public init(baseURL: String, endpointPrivider: OpenAIEndpointProvider, session: URLSession, authorizeRequest: @escaping (inout URLRequest) -> Void) { - self.baseURL = baseURL - self.endpointProvider = endpointPrivider - self.authorizeRequest = authorizeRequest - self.session = session - } - let baseURL: String - let endpointProvider: OpenAIEndpointProvider - let session:URLSession - let authorizeRequest: (inout URLRequest) -> Void - - public static func makeDefaultOpenAI(apiKey: String) -> Self { - .init(baseURL: "https://api.openai.com", - endpointPrivider: OpenAIEndpointProvider(source: .openAI), - session: .shared, - authorizeRequest: { request in - request.setValue("Bearer \(apiKey)", forHTTPHeaderField: "Authorization") - }) - } + fileprivate let requestHandler: OpenAIRequestHandler + + /// Initialises OpenAISwift with a given request handler + /// - Parameter requestHandler: The request handler to make requests with + public init(requestHandler: OpenAIRequestHandler) { + self.requestHandler = requestHandler } + /// Deprecated initialiser for backwards API support to remove breaking change when introducing OpenAIRequestHandler protocol + /// - Parameter config: The config to initialise with public init(config: Config) { - self.config = config + self.requestHandler = config } } @@ -55,9 +40,8 @@ extension OpenAISwift { public func sendCompletion(with prompt: String, model: OpenAIModelType = .gpt3(.davinci), maxTokens: Int = 16, temperature: Double = 1, completionHandler: @escaping (Result, OpenAIError>) -> Void) { let endpoint = OpenAIEndpointProvider.API.completions let body = Command(prompt: prompt, model: model.modelName, maxTokens: maxTokens, temperature: temperature) - let request = prepareRequest(endpoint, body: body) - makeRequest(request: request) { result in + requestHandler.makeRequest(endpoint, body: body) { result in switch result { case .success(let success): do { @@ -81,9 +65,8 @@ extension OpenAISwift { public func sendEdits(with instruction: String, model: OpenAIModelType = .feature(.davinci), input: String = "", completionHandler: @escaping (Result, OpenAIError>) -> Void) { let endpoint = OpenAIEndpointProvider.API.edits let body = Instruction(instruction: instruction, model: model.modelName, input: input) - let request = prepareRequest(endpoint, body: body) - makeRequest(request: request) { result in + requestHandler.makeRequest(endpoint, body: body) { result in switch result { case .success(let success): do { @@ -106,9 +89,8 @@ extension OpenAISwift { public func sendModerations(with input: String, model: OpenAIModelType = .moderation(.latest), completionHandler: @escaping (Result, OpenAIError>) -> Void) { let endpoint = OpenAIEndpointProvider.API.moderations let body = Moderation(input: input, model: model.modelName) - let request = prepareRequest(endpoint, body: body) - makeRequest(request: request) { result in + requestHandler.makeRequest(endpoint, body: body) { result in switch result { case .success(let success): do { @@ -162,10 +144,8 @@ extension OpenAISwift { frequencyPenalty: frequencyPenalty, logitBias: logitBias, stream: false) - - let request = prepareRequest(endpoint, body: body) - makeRequest(request: request) { result in + requestHandler.makeRequest(endpoint, body: body) { result in switch result { case .success(let success): if let chatErr = try? JSONDecoder().decode(ChatError.self, from: success) as ChatError { @@ -197,9 +177,8 @@ extension OpenAISwift { let endpoint = OpenAIEndpointProvider.API.embeddings let body = EmbeddingsInput(input: input, model: model.modelName) - - let request = prepareRequest(endpoint, body: body) - makeRequest(request: request) { result in + + requestHandler.makeRequest(endpoint, body: body) { result in switch result { case .success(let success): do { @@ -255,10 +234,8 @@ extension OpenAISwift { frequencyPenalty: frequencyPenalty, logitBias: logitBias, stream: true) - let request = prepareRequest(endpoint, body: body) - handler.onEventReceived = onEventReceived - handler.onComplete = onComplete - handler.connect(with: request) + + requestHandler.streamRequest(endpoint, body: body, eventReceived: onEventReceived, completion: onComplete) } @@ -272,9 +249,8 @@ extension OpenAISwift { public func sendImages(with prompt: String, numImages: Int = 1, size: ImageSize = .size1024, user: String? = nil, completionHandler: @escaping (Result, OpenAIError>) -> Void) { let endpoint = OpenAIEndpointProvider.API.images let body = ImageGeneration(prompt: prompt, n: numImages, size: size, user: user) - let request = prepareRequest(endpoint, body: body) - - makeRequest(request: request) { result in + + requestHandler.makeRequest(endpoint, body: body) { result in switch result { case .success(let success): do { @@ -288,37 +264,6 @@ extension OpenAISwift { } } } - - private func makeRequest(request: URLRequest, completionHandler: @escaping (Result) -> Void) { - let session = config.session - let task = session.dataTask(with: request) { (data, response, error) in - if let error = error { - completionHandler(.failure(error)) - } else if let data = data { - completionHandler(.success(data)) - } - } - - task.resume() - } - - private func prepareRequest(_ endpoint: OpenAIEndpointProvider.API, body: BodyType) -> URLRequest { - var urlComponents = URLComponents(url: URL(string: config.baseURL)!, resolvingAgainstBaseURL: true) - urlComponents?.path = config.endpointProvider.getPath(api: endpoint) - var request = URLRequest(url: urlComponents!.url!) - request.httpMethod = config.endpointProvider.getMethod(api: endpoint) - - config.authorizeRequest(&request) - - request.setValue("application/json", forHTTPHeaderField: "content-type") - - let encoder = JSONEncoder() - if let encoded = try? encoder.encode(body) { - request.httpBody = encoded - } - - return request - } } extension OpenAISwift { diff --git a/Sources/OpenAISwift/Request Handlers/URLSessionRequestHandler.swift b/Sources/OpenAISwift/Request Handlers/URLSessionRequestHandler.swift new file mode 100644 index 0000000..a2eaf3c --- /dev/null +++ b/Sources/OpenAISwift/Request Handlers/URLSessionRequestHandler.swift @@ -0,0 +1,144 @@ +// +// URLSessionRequestHandler.swift +// LumenateApp +// +// Created by Simon Mitchell on 28/11/2023. +// + +import Foundation + +public final class URLSessionRequestHandler: NSObject, OpenAIRequestHandler { + + let baseURL: String + + let endpointProvider: OpenAIEndpointProvider + + let session: URLSession + + let authorizeRequest: (inout URLRequest) -> Void + + var onEventReceived: ((Result, OpenAIError>) -> Void)? + + var onComplete: (() -> Void)? + + private lazy var streamingSession: URLSession = URLSession(configuration: .default, delegate: self, delegateQueue: nil) + + private var streamingTask: URLSessionDataTask? + + /// Default memberwise initialiser + /// - Parameters: + /// - baseURL: The base url to load data from + /// - endpointPrivider: An endpoint provider for generating full urls for each request + /// - session: The session to use for network requests + /// - authorizeRequest: A closure to authenticate a specific `URLRequest` + public init(baseURL: String, endpointPrivider: OpenAIEndpointProvider, session: URLSession, authorizeRequest: @escaping (inout URLRequest) -> Void) { + self.session = session + self.endpointProvider = endpointPrivider + self.authorizeRequest = authorizeRequest + self.baseURL = baseURL + } + + // MARK: Protocol Conformance + + public func makeRequest(_ endpoint: OpenAIEndpointProvider.API, body: BodyType, completionHandler: @escaping (Result) -> Void) where BodyType : Encodable { + let request = prepareRequest(endpoint, body: body) + makeRequest(request: request, completionHandler: completionHandler) + } + + public func streamRequest(_ endpoint: OpenAIEndpointProvider.API, body: BodyType, eventReceived: ((Result, OpenAIError>) -> Void)?, completion: (() -> Void)?) where BodyType : Encodable { + + let request = prepareRequest(endpoint, body: body) + self.onEventReceived = eventReceived + self.onComplete = completion + connect(with: request) + } + + private func makeRequest(request: URLRequest, completionHandler: @escaping (Result) -> Void) { + let task = session.dataTask(with: request) { (data, response, error) in + if let error = error { + completionHandler(.failure(error)) + } else if let data = data { + completionHandler(.success(data)) + } + } + + task.resume() + } + + private func prepareRequest(_ endpoint: OpenAIEndpointProvider.API, body: BodyType) -> URLRequest { + var urlComponents = URLComponents(url: URL(string: baseURL)!, resolvingAgainstBaseURL: true) + urlComponents?.path = endpointProvider.getPath(api: endpoint) + var request = URLRequest(url: urlComponents!.url!) + request.httpMethod = endpointProvider.getMethod(api: endpoint) + + authorizeRequest(&request) + + request.setValue("application/json", forHTTPHeaderField: "content-type") + + let encoder = JSONEncoder() + if let encoded = try? encoder.encode(body) { + request.httpBody = encoded + } + + return request + } + + private func connect(with request: URLRequest) { + streamingTask = session.dataTask(with: request) + streamingTask?.resume() + } + + fileprivate func disconnect() { + streamingTask?.cancel() + } + + fileprivate func processEvent(_ eventData: Data) { + do { + let res = try JSONDecoder().decode(OpenAI.self, from: eventData) + onEventReceived?(.success(res)) + } catch { + onEventReceived?(.failure(.decodingError(error: error))) + } + } +} + +extension URLSessionRequestHandler: URLSessionDataDelegate { + /// It will be called several times, each time could return one chunk of data or multiple chunk of data + /// The JSON look liks this: + /// `data: {"id":"chatcmpl-6yVTvD6UAXsE9uG2SmW4Tc2iuFnnT","object":"chat.completion.chunk","created":1679878715,"model":"gpt-3.5-turbo-0301","choices":[{"delta":{"role":"assistant"},"index":0,"finish_reason":null}]}` + /// `data: {"id":"chatcmpl-6yVTvD6UAXsE9uG2SmW4Tc2iuFnnT","object":"chat.completion.chunk","created":1679878715,"model":"gpt-3.5-turbo-0301","choices":[{"delta":{"content":"Once"},"index":0,"finish_reason":null}]}` + public func urlSession(_ session: URLSession, dataTask: URLSessionDataTask, didReceive data: Data) { + if let eventString = String(data: data, encoding: .utf8) { + let lines = eventString.split(separator: "\n") + for line in lines { + if line.hasPrefix("data:") && line != "data: [DONE]" { + if let eventData = String(line.dropFirst(5)).data(using: .utf8) { + processEvent(eventData) + } else { + disconnect() + } + } + } + } + } + + public func urlSession(_ session: URLSession, task: URLSessionTask, didCompleteWithError error: Error?) { + if let error = error { + onEventReceived?(.failure(.genericError(error: error))) + } else { + onComplete?() + } + } +} + +public extension Config { + + static func makeDefaultOpenAI(apiKey: String) -> URLSessionRequestHandler { + return URLSessionRequestHandler(baseURL: "https://api.openai.com", + endpointPrivider: OpenAIEndpointProvider(source: .openAI), + session: .shared, + authorizeRequest: { request in + request.setValue("Bearer \(apiKey)", forHTTPHeaderField: "Authorization") + }) + } +} diff --git a/Sources/OpenAISwift/ServerSentEventsHandler.swift b/Sources/OpenAISwift/ServerSentEventsHandler.swift deleted file mode 100644 index e3f92db..0000000 --- a/Sources/OpenAISwift/ServerSentEventsHandler.swift +++ /dev/null @@ -1,65 +0,0 @@ -// -// ServerSentEventsHandler.swift -// -// -// Created by Vic on 2023-03-25. -// - -import Foundation - -class ServerSentEventsHandler: NSObject { - - var onEventReceived: ((Result, OpenAIError>) -> Void)? - var onComplete: (() -> Void)? - - private lazy var session: URLSession = URLSession(configuration: .default, delegate: self, delegateQueue: nil) - private var task: URLSessionDataTask? - - func connect(with request: URLRequest) { - task = session.dataTask(with: request) - task?.resume() - } - - func disconnect() { - task?.cancel() - } - - func processEvent(_ eventData: Data) { - do { - let res = try JSONDecoder().decode(OpenAI.self, from: eventData) - onEventReceived?(.success(res)) - } catch { - onEventReceived?(.failure(.decodingError(error: error))) - } - } -} - -extension ServerSentEventsHandler: URLSessionDataDelegate { - - /// It will be called several times, each time could return one chunk of data or multiple chunk of data - /// The JSON look liks this: - /// `data: {"id":"chatcmpl-6yVTvD6UAXsE9uG2SmW4Tc2iuFnnT","object":"chat.completion.chunk","created":1679878715,"model":"gpt-3.5-turbo-0301","choices":[{"delta":{"role":"assistant"},"index":0,"finish_reason":null}]}` - /// `data: {"id":"chatcmpl-6yVTvD6UAXsE9uG2SmW4Tc2iuFnnT","object":"chat.completion.chunk","created":1679878715,"model":"gpt-3.5-turbo-0301","choices":[{"delta":{"content":"Once"},"index":0,"finish_reason":null}]}` - func urlSession(_ session: URLSession, dataTask: URLSessionDataTask, didReceive data: Data) { - if let eventString = String(data: data, encoding: .utf8) { - let lines = eventString.split(separator: "\n") - for line in lines { - if line.hasPrefix("data:") && line != "data: [DONE]" { - if let eventData = String(line.dropFirst(5)).data(using: .utf8) { - processEvent(eventData) - } else { - disconnect() - } - } - } - } - } - - func urlSession(_ session: URLSession, task: URLSessionTask, didCompleteWithError error: Error?) { - if let error = error { - onEventReceived?(.failure(.genericError(error: error))) - } else { - onComplete?() - } - } -} From 11447dafe70916f978000c77229d9d77594ac616 Mon Sep 17 00:00:00 2001 From: Simon Mitchell Date: Tue, 28 Nov 2023 16:36:18 +0000 Subject: [PATCH 2/3] Pulls in CodingKey fix for message error since `id` added --- Sources/OpenAISwift/Models/ChatMessage.swift | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/Sources/OpenAISwift/Models/ChatMessage.swift b/Sources/OpenAISwift/Models/ChatMessage.swift index c987cf9..9cd3daf 100644 --- a/Sources/OpenAISwift/Models/ChatMessage.swift +++ b/Sources/OpenAISwift/Models/ChatMessage.swift @@ -34,6 +34,10 @@ public struct ChatMessage: Codable, Identifiable { self.role = role self.content = content } + + private enum CodingKeys: String, CodingKey { + case role, content + } } /// A structure that represents a chat conversation. From 005b3b6664f20cea9c4dc74a982a418fe44fd91d Mon Sep 17 00:00:00 2001 From: Simon Mitchell Date: Tue, 28 Nov 2023 16:50:30 +0000 Subject: [PATCH 3/3] Moves Config typealias back to OpenAISwift context --- Sources/OpenAISwift/OpenAISwift.swift | 9 +++++---- .../Request Handlers/URLSessionRequestHandler.swift | 2 +- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/Sources/OpenAISwift/OpenAISwift.swift b/Sources/OpenAISwift/OpenAISwift.swift index 68d788a..e5a1a8d 100644 --- a/Sources/OpenAISwift/OpenAISwift.swift +++ b/Sources/OpenAISwift/OpenAISwift.swift @@ -4,10 +4,6 @@ import FoundationNetworking import FoundationXML #endif -// Typealias for backward compatibility so allowing custom request makers -// doesn't introduce breaking changes to the public API -public typealias Config = URLSessionRequestHandler - public enum OpenAIError: Error { case genericError(error: Error) case decodingError(error: Error) @@ -15,6 +11,11 @@ public enum OpenAIError: Error { } public class OpenAISwift { + + // Typealias for backward compatibility so allowing custom request makers + // doesn't introduce breaking changes to the public API + public typealias Config = URLSessionRequestHandler + fileprivate let requestHandler: OpenAIRequestHandler /// Initialises OpenAISwift with a given request handler diff --git a/Sources/OpenAISwift/Request Handlers/URLSessionRequestHandler.swift b/Sources/OpenAISwift/Request Handlers/URLSessionRequestHandler.swift index a2eaf3c..5269ac0 100644 --- a/Sources/OpenAISwift/Request Handlers/URLSessionRequestHandler.swift +++ b/Sources/OpenAISwift/Request Handlers/URLSessionRequestHandler.swift @@ -131,7 +131,7 @@ extension URLSessionRequestHandler: URLSessionDataDelegate { } } -public extension Config { +public extension OpenAISwift.Config { static func makeDefaultOpenAI(apiKey: String) -> URLSessionRequestHandler { return URLSessionRequestHandler(baseURL: "https://api.openai.com",