diff --git a/Sources/OpenAISwift/OpenAISwift.swift b/Sources/OpenAISwift/OpenAISwift.swift index 67090ac..3fe6ad6 100644 --- a/Sources/OpenAISwift/OpenAISwift.swift +++ b/Sources/OpenAISwift/OpenAISwift.swift @@ -11,6 +11,8 @@ public enum OpenAIError: Error { } public class OpenAISwift { + fileprivate(set) var token: String? + fileprivate(set) var baseURL: String? fileprivate let config: Config fileprivate let handler = ServerSentEventsHandler() @@ -40,8 +42,10 @@ public class OpenAISwift { } } - public init(config: Config) { - self.config = config + public init(authToken: String, baseURL: String? = nil, config: Config = Config()) { + self.token = authToken + self.baseURL = baseURL + self.config = Config() } } @@ -303,7 +307,7 @@ extension OpenAISwift { } private func prepareRequest(_ endpoint: OpenAIEndpointProvider.API, body: BodyType) -> URLRequest { - var urlComponents = URLComponents(url: URL(string: config.baseURL)!, resolvingAgainstBaseURL: true) + var urlComponents = URLComponents(url: prepareBaseURL(endpoint), resolvingAgainstBaseURL: true) urlComponents?.path = config.endpointProvider.getPath(api: endpoint) var request = URLRequest(url: urlComponents!.url!) request.httpMethod = config.endpointProvider.getMethod(api: endpoint) @@ -319,6 +323,28 @@ extension OpenAISwift { return request } + + private func prepareBaseURL(_ endpoint: Endpoint) -> URL{ + + if var baseURL = baseURL, !baseURL.isEmpty { + // needs https:// for request + if baseURL.lowercased().hasPrefix("http://") { + // change http to https + baseURL = baseURL.replacingOccurrences(of: "http://", with: "https://", options: .caseInsensitive) + } + if !baseURL.lowercased().hasPrefix("https://") { + // add https:// prefix + baseURL = "https://" + baseURL + } + while baseURL.hasSuffix("/") { + baseURL.removeLast() + } + return URL(string: baseURL)! + } + + return URL(string: endpoint.baseURL())! + + } } extension OpenAISwift {