Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions Sources/OpenAISwift/Models/ChatMessage.swift
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@ public struct ChatMessage: Codable, Identifiable {
self.role = role
self.content = content
}

private enum CodingKeys: String, CodingKey {
case role, content
}
}

/// A structure that represents a chat conversation.
Expand Down
35 changes: 35 additions & 0 deletions Sources/OpenAISwift/OpenAIRequestHandler.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
//
// OpenAIRequestHandler.swift
// LumenateApp
//
// Created by Simon Mitchell on 28/11/2023.
//

import Foundation

public protocol OpenAIRequestHandler {

/// Function which performs the request as required from the user.
/// - Note: However the request is made, it must do a few things
/// 1. Call `completionHandler` with any errors or response data
/// 2. Data returned must be decodable to the model types defined in this library
/// - Parameters:
/// - endpoint: The endpoint to make a request to
/// - body: The body of the request to make
/// - completionHandler: A closure to be called once the request has completed
func makeRequest<BodyType: Encodable>(_ endpoint: OpenAIEndpointProvider.API, body: BodyType, completionHandler: @escaping (Result<Data, Error>) -> Void)

/// Function which streams the request as required by the user.
/// - Note: Only "chat" api is streamable for now, so this always has return type of `StreamMessageResult`
/// - Parameters:
/// - endpoint: The endpoint to stream the request from. Note: currently this is only for "chat" endpoint
/// - body: The body of the request to make
/// - eventReceived: Called Multiple times, returns an OpenAI Data Model
/// - completion: Triggers when sever complete sending the message
func streamRequest<BodyType: Encodable>(
_ endpoint: OpenAIEndpointProvider.API,
body: BodyType,
eventReceived: ((Result<OpenAI<StreamMessageResult>, OpenAIError>) -> Void)?,
completion: (() -> Void)?
)
}
102 changes: 24 additions & 78 deletions Sources/OpenAISwift/OpenAISwift.swift
Original file line number Diff line number Diff line change
Expand Up @@ -11,37 +11,23 @@ public enum OpenAIError: Error {
}

public class OpenAISwift {
fileprivate let config: Config
fileprivate let handler = ServerSentEventsHandler()

/// Configuration object for the client
public struct Config {

/// Initialiser
/// - Parameter session: the session to use for network requests.
public init(baseURL: String, endpointPrivider: OpenAIEndpointProvider, session: URLSession, authorizeRequest: @escaping (inout URLRequest) -> Void) {
self.baseURL = baseURL
self.endpointProvider = endpointPrivider
self.authorizeRequest = authorizeRequest
self.session = session
}
let baseURL: String
let endpointProvider: OpenAIEndpointProvider
let session:URLSession
let authorizeRequest: (inout URLRequest) -> Void

public static func makeDefaultOpenAI(apiKey: String) -> Self {
.init(baseURL: "https://api.openai.com",
endpointPrivider: OpenAIEndpointProvider(source: .openAI),
session: .shared,
authorizeRequest: { request in
request.setValue("Bearer \(apiKey)", forHTTPHeaderField: "Authorization")
})
}

// Typealias for backward compatibility so allowing custom request makers
// doesn't introduce breaking changes to the public API
public typealias Config = URLSessionRequestHandler

fileprivate let requestHandler: OpenAIRequestHandler

/// Initialises OpenAISwift with a given request handler
/// - Parameter requestHandler: The request handler to make requests with
public init(requestHandler: OpenAIRequestHandler) {
self.requestHandler = requestHandler
}

/// Deprecated initialiser for backwards API support to remove breaking change when introducing OpenAIRequestHandler protocol
/// - Parameter config: The config to initialise with
public init(config: Config) {
self.config = config
self.requestHandler = config
}
}

Expand All @@ -55,9 +41,8 @@ extension OpenAISwift {
public func sendCompletion(with prompt: String, model: OpenAIModelType = .gpt3(.davinci), maxTokens: Int = 16, temperature: Double = 1, completionHandler: @escaping (Result<OpenAI<TextResult>, OpenAIError>) -> Void) {
let endpoint = OpenAIEndpointProvider.API.completions
let body = Command(prompt: prompt, model: model.modelName, maxTokens: maxTokens, temperature: temperature)
let request = prepareRequest(endpoint, body: body)

makeRequest(request: request) { result in
requestHandler.makeRequest(endpoint, body: body) { result in
switch result {
case .success(let success):
do {
Expand All @@ -81,9 +66,8 @@ extension OpenAISwift {
public func sendEdits(with instruction: String, model: OpenAIModelType = .feature(.davinci), input: String = "", completionHandler: @escaping (Result<OpenAI<TextResult>, OpenAIError>) -> Void) {
let endpoint = OpenAIEndpointProvider.API.edits
let body = Instruction(instruction: instruction, model: model.modelName, input: input)
let request = prepareRequest(endpoint, body: body)

makeRequest(request: request) { result in
requestHandler.makeRequest(endpoint, body: body) { result in
switch result {
case .success(let success):
do {
Expand All @@ -106,9 +90,8 @@ extension OpenAISwift {
public func sendModerations(with input: String, model: OpenAIModelType = .moderation(.latest), completionHandler: @escaping (Result<OpenAI<ModerationResult>, OpenAIError>) -> Void) {
let endpoint = OpenAIEndpointProvider.API.moderations
let body = Moderation(input: input, model: model.modelName)
let request = prepareRequest(endpoint, body: body)

makeRequest(request: request) { result in
requestHandler.makeRequest(endpoint, body: body) { result in
switch result {
case .success(let success):
do {
Expand Down Expand Up @@ -162,10 +145,8 @@ extension OpenAISwift {
frequencyPenalty: frequencyPenalty,
logitBias: logitBias,
stream: false)

let request = prepareRequest(endpoint, body: body)

makeRequest(request: request) { result in
requestHandler.makeRequest(endpoint, body: body) { result in
switch result {
case .success(let success):
if let chatErr = try? JSONDecoder().decode(ChatError.self, from: success) as ChatError {
Expand Down Expand Up @@ -197,9 +178,8 @@ extension OpenAISwift {
let endpoint = OpenAIEndpointProvider.API.embeddings
let body = EmbeddingsInput(input: input,
model: model.modelName)

let request = prepareRequest(endpoint, body: body)
makeRequest(request: request) { result in

requestHandler.makeRequest(endpoint, body: body) { result in
switch result {
case .success(let success):
do {
Expand Down Expand Up @@ -255,10 +235,8 @@ extension OpenAISwift {
frequencyPenalty: frequencyPenalty,
logitBias: logitBias,
stream: true)
let request = prepareRequest(endpoint, body: body)
handler.onEventReceived = onEventReceived
handler.onComplete = onComplete
handler.connect(with: request)

requestHandler.streamRequest(endpoint, body: body, eventReceived: onEventReceived, completion: onComplete)
}


Expand All @@ -272,9 +250,8 @@ extension OpenAISwift {
public func sendImages(with prompt: String, numImages: Int = 1, size: ImageSize = .size1024, user: String? = nil, completionHandler: @escaping (Result<OpenAI<UrlResult>, OpenAIError>) -> Void) {
let endpoint = OpenAIEndpointProvider.API.images
let body = ImageGeneration(prompt: prompt, n: numImages, size: size, user: user)
let request = prepareRequest(endpoint, body: body)

makeRequest(request: request) { result in

requestHandler.makeRequest(endpoint, body: body) { result in
switch result {
case .success(let success):
do {
Expand All @@ -288,37 +265,6 @@ extension OpenAISwift {
}
}
}

private func makeRequest(request: URLRequest, completionHandler: @escaping (Result<Data, Error>) -> Void) {
let session = config.session
let task = session.dataTask(with: request) { (data, response, error) in
if let error = error {
completionHandler(.failure(error))
} else if let data = data {
completionHandler(.success(data))
}
}

task.resume()
}

private func prepareRequest<BodyType: Encodable>(_ endpoint: OpenAIEndpointProvider.API, body: BodyType) -> URLRequest {
var urlComponents = URLComponents(url: URL(string: config.baseURL)!, resolvingAgainstBaseURL: true)
urlComponents?.path = config.endpointProvider.getPath(api: endpoint)
var request = URLRequest(url: urlComponents!.url!)
request.httpMethod = config.endpointProvider.getMethod(api: endpoint)

config.authorizeRequest(&request)

request.setValue("application/json", forHTTPHeaderField: "content-type")

let encoder = JSONEncoder()
if let encoded = try? encoder.encode(body) {
request.httpBody = encoded
}

return request
}
}

extension OpenAISwift {
Expand Down
144 changes: 144 additions & 0 deletions Sources/OpenAISwift/Request Handlers/URLSessionRequestHandler.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
//
// URLSessionRequestHandler.swift
// LumenateApp
//
// Created by Simon Mitchell on 28/11/2023.
//

import Foundation

public final class URLSessionRequestHandler: NSObject, OpenAIRequestHandler {

let baseURL: String

let endpointProvider: OpenAIEndpointProvider

let session: URLSession

let authorizeRequest: (inout URLRequest) -> Void

var onEventReceived: ((Result<OpenAI<StreamMessageResult>, OpenAIError>) -> Void)?

var onComplete: (() -> Void)?

private lazy var streamingSession: URLSession = URLSession(configuration: .default, delegate: self, delegateQueue: nil)

private var streamingTask: URLSessionDataTask?

/// Default memberwise initialiser
/// - Parameters:
/// - baseURL: The base url to load data from
/// - endpointPrivider: An endpoint provider for generating full urls for each request
/// - session: The session to use for network requests
/// - authorizeRequest: A closure to authenticate a specific `URLRequest`
public init(baseURL: String, endpointPrivider: OpenAIEndpointProvider, session: URLSession, authorizeRequest: @escaping (inout URLRequest) -> Void) {
self.session = session
self.endpointProvider = endpointPrivider
self.authorizeRequest = authorizeRequest
self.baseURL = baseURL
}

// MARK: Protocol Conformance

public func makeRequest<BodyType>(_ endpoint: OpenAIEndpointProvider.API, body: BodyType, completionHandler: @escaping (Result<Data, Error>) -> Void) where BodyType : Encodable {
let request = prepareRequest(endpoint, body: body)
makeRequest(request: request, completionHandler: completionHandler)
}

public func streamRequest<BodyType>(_ endpoint: OpenAIEndpointProvider.API, body: BodyType, eventReceived: ((Result<OpenAI<StreamMessageResult>, OpenAIError>) -> Void)?, completion: (() -> Void)?) where BodyType : Encodable {

let request = prepareRequest(endpoint, body: body)
self.onEventReceived = eventReceived
self.onComplete = completion
connect(with: request)
}

private func makeRequest(request: URLRequest, completionHandler: @escaping (Result<Data, Error>) -> Void) {
let task = session.dataTask(with: request) { (data, response, error) in
if let error = error {
completionHandler(.failure(error))
} else if let data = data {
completionHandler(.success(data))
}
}

task.resume()
}

private func prepareRequest<BodyType: Encodable>(_ endpoint: OpenAIEndpointProvider.API, body: BodyType) -> URLRequest {
var urlComponents = URLComponents(url: URL(string: baseURL)!, resolvingAgainstBaseURL: true)
urlComponents?.path = endpointProvider.getPath(api: endpoint)
var request = URLRequest(url: urlComponents!.url!)
request.httpMethod = endpointProvider.getMethod(api: endpoint)

authorizeRequest(&request)

request.setValue("application/json", forHTTPHeaderField: "content-type")

let encoder = JSONEncoder()
if let encoded = try? encoder.encode(body) {
request.httpBody = encoded
}

return request
}

private func connect(with request: URLRequest) {
streamingTask = session.dataTask(with: request)
streamingTask?.resume()
}

fileprivate func disconnect() {
streamingTask?.cancel()
}

fileprivate func processEvent(_ eventData: Data) {
do {
let res = try JSONDecoder().decode(OpenAI<StreamMessageResult>.self, from: eventData)
onEventReceived?(.success(res))
} catch {
onEventReceived?(.failure(.decodingError(error: error)))
}
}
}

extension URLSessionRequestHandler: URLSessionDataDelegate {
/// It will be called several times, each time could return one chunk of data or multiple chunk of data
/// The JSON look liks this:
/// `data: {"id":"chatcmpl-6yVTvD6UAXsE9uG2SmW4Tc2iuFnnT","object":"chat.completion.chunk","created":1679878715,"model":"gpt-3.5-turbo-0301","choices":[{"delta":{"role":"assistant"},"index":0,"finish_reason":null}]}`
/// `data: {"id":"chatcmpl-6yVTvD6UAXsE9uG2SmW4Tc2iuFnnT","object":"chat.completion.chunk","created":1679878715,"model":"gpt-3.5-turbo-0301","choices":[{"delta":{"content":"Once"},"index":0,"finish_reason":null}]}`
public func urlSession(_ session: URLSession, dataTask: URLSessionDataTask, didReceive data: Data) {
if let eventString = String(data: data, encoding: .utf8) {
let lines = eventString.split(separator: "\n")
for line in lines {
if line.hasPrefix("data:") && line != "data: [DONE]" {
if let eventData = String(line.dropFirst(5)).data(using: .utf8) {
processEvent(eventData)
} else {
disconnect()
}
}
}
}
}

public func urlSession(_ session: URLSession, task: URLSessionTask, didCompleteWithError error: Error?) {
if let error = error {
onEventReceived?(.failure(.genericError(error: error)))
} else {
onComplete?()
}
}
}

public extension OpenAISwift.Config {

static func makeDefaultOpenAI(apiKey: String) -> URLSessionRequestHandler {
return URLSessionRequestHandler(baseURL: "https://api.openai.com",
endpointPrivider: OpenAIEndpointProvider(source: .openAI),
session: .shared,
authorizeRequest: { request in
request.setValue("Bearer \(apiKey)", forHTTPHeaderField: "Authorization")
})
}
}
Loading