Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Formatter (spec + CI step) #56

Merged
merged 4 commits into from
Dec 31, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .JuliaFormatter.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# See https://domluna.github.io/JuliaFormatter.jl/stable/ for a list of options
style = "sciml"
4 changes: 4 additions & 0 deletions .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,7 @@ jobs:
- run: cd $GITHUB_WORKSPACE
- name: Run tests
run: julia --project=. -e "using Pkg; Pkg.instantiate(); Pkg.activate(\".\"); Pkg.test()"
- name: Install JuliaFormatter and format
run: |
julia -e 'import Pkg; Pkg.add("JuliaFormatter")'
julia -e 'using JuliaFormatter; out=format(".", verbose=true); out ? exit(0) : exit(1)'
21 changes: 21 additions & 0 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,24 @@

All contributions are welcome, just open a PR and let's talk!
As this project matures, I'll come up with interface and style guidelines.

**Table of Contents**

- [Contributing](#contributing)
- [Pull Request Checklist](#pull-request-checklist)
- [Style Guide](#style-guide)


### Pull Request Checklist
- [ ] Add tests for new features
- [ ] Add documentation for new features
- [ ] Add a line to the `CHANGELOG.md` file describing the new feature
- [ ] All tests pass
- [ ] The code is formatted with `JuliaFormatter` (see [Style Guide](#style-guide) below)

### Style Guide
This repository follows SciML's [Style](https://github.com/SciML/SciMLStyle). The exact flavor is described in the file `./JuliaFormatter.toml`.

Before opening a PR, please run `using JuliaFormatter; format(".")` in the root directory of the repository to make sure that your contributions are formatted accordingly.

If your contribution is not formatted, the CI will fail and you will be asked to format your code before the PR can be merged.
6 changes: 2 additions & 4 deletions docs/make.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
using Documenter, OpenAI

makedocs(sitename="OpenAI.jl Documentation")
makedocs(sitename = "OpenAI.jl Documentation")

deploydocs(
repo = "github.com/JuliaML/OpenAI.jl.git",
)
deploydocs(repo = "github.com/JuliaML/OpenAI.jl.git")
194 changes: 133 additions & 61 deletions src/OpenAI.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ using JSON3
using HTTP
using Dates


abstract type AbstractOpenAIProvider end
Base.@kwdef struct OpenAIProvider <: AbstractOpenAIProvider
api_key::String = ""
Expand All @@ -27,7 +26,7 @@ const DEFAULT_PROVIDER = let
if api_key === nothing
OpenAIProvider()
else
OpenAIProvider(api_key=api_key)
OpenAIProvider(api_key = api_key)
end
end

Expand All @@ -41,14 +40,14 @@ function auth_header(::OpenAIProvider, api_key::AbstractString)
isempty(api_key) && throw(ArgumentError("api_key cannot be empty"))
[
"Authorization" => "Bearer $api_key",
"Content-Type" => "application/json"
"Content-Type" => "application/json",
]
end
function auth_header(::AzureProvider, api_key::AbstractString)
isempty(api_key) && throw(ArgumentError("api_key cannot be empty"))
[
"api-key" => api_key,
"Content-Type" => "application/json"
"Content-Type" => "application/json",
]
end

Expand Down Expand Up @@ -80,14 +79,12 @@ function request_body(url, method; input, headers, query, kwargs...)
input = isnothing(input) ? [] : input
query = isnothing(query) ? [] : query

resp = HTTP.request(
method,
resp = HTTP.request(method,
url;
body=input,
query=query,
headers=headers,
kwargs...
)
body = input,
query = query,
headers = headers,
kwargs...)
return resp, resp.body
end

Expand All @@ -96,7 +93,6 @@ function request_body_live(url; method, input, headers, streamcallback, kwargs..

body = sprint() do output
resp = HTTP.open("POST", url, headers) do stream

body = String(take!(input))
write(stream, body)

Expand Down Expand Up @@ -137,25 +133,23 @@ function request_body_live(url; method, input, headers, streamcallback, kwargs..
return resp, body
end

function status_error(resp, log=nothing)
function status_error(resp, log = nothing)
logs = !isnothing(log) ? ": $log" : ""
error("request status $(resp.message)$logs")
end

function _request(
api::AbstractString,
function _request(api::AbstractString,
provider::AbstractOpenAIProvider,
api_key::AbstractString=provider.api_key;
api_key::AbstractString = provider.api_key;
method,
query=nothing,
query = nothing,
http_kwargs,
streamcallback=nothing,
additional_headers::AbstractVector=Pair{String,String}[],
kwargs...
)
streamcallback = nothing,
additional_headers::AbstractVector = Pair{String, String}[],
kwargs...)
# add stream: True to the API call if a stream callback function is passed
if !isnothing(streamcallback)
kwargs = (kwargs..., stream=true)
kwargs = (kwargs..., stream = true)
end

params = build_params(kwargs)
Expand All @@ -165,24 +159,20 @@ function _request(
headers = vcat(auth_header(provider, api_key), additional_headers)

if isnothing(streamcallback)
request_body(
url,
request_body(url,
method;
input=params,
headers=headers,
query=query,
http_kwargs...
)
input = params,
headers = headers,
query = query,
http_kwargs...)
else
request_body_live(
url;
request_body_live(url;
method,
input=params,
headers=headers,
query=query,
streamcallback=streamcallback,
http_kwargs...
)
input = params,
headers = headers,
query = query,
streamcallback = streamcallback,
http_kwargs...)
end
end
if resp.status >= 400
Expand All @@ -202,17 +192,32 @@ function _request(

OpenAIResponse(resp.status, parsed)
end

end
end

function openai_request(api::AbstractString, api_key::AbstractString; method, http_kwargs, streamcallback=nothing, kwargs...)
function openai_request(api::AbstractString,
api_key::AbstractString;
method,
http_kwargs,
streamcallback = nothing,
kwargs...)
global DEFAULT_PROVIDER
_request(api, DEFAULT_PROVIDER, api_key; method, http_kwargs, streamcallback=streamcallback, kwargs...)
_request(api,
DEFAULT_PROVIDER,
api_key;
method,
http_kwargs,
streamcallback = streamcallback,
kwargs...)
end

function openai_request(api::AbstractString, provider::AbstractOpenAIProvider; method, http_kwargs, streamcallback=nothing, kwargs...)
_request(api, provider; method, http_kwargs, streamcallback=streamcallback, kwargs...)
function openai_request(api::AbstractString,
provider::AbstractOpenAIProvider;
method,
http_kwargs,
streamcallback = nothing,
kwargs...)
_request(api, provider; method, http_kwargs, streamcallback = streamcallback, kwargs...)
end

struct OpenAIResponse{R}
Expand All @@ -234,8 +239,8 @@ List models

For additional details, visit <https://platform.openai.com/docs/api-reference/models/list>
"""
function list_models(api_key::String; http_kwargs::NamedTuple=NamedTuple())
return openai_request("models", api_key; method="GET", http_kwargs=http_kwargs)
function list_models(api_key::String; http_kwargs::NamedTuple = NamedTuple())
return openai_request("models", api_key; method = "GET", http_kwargs = http_kwargs)
end

"""
Expand All @@ -247,8 +252,13 @@ Retrieve model

For additional details, visit <https://platform.openai.com/docs/api-reference/models/retrieve>
"""
function retrieve_model(api_key::String, model_id::String; http_kwargs::NamedTuple=NamedTuple())
return openai_request("models/$(model_id)", api_key; method="GET", http_kwargs=http_kwargs)
function retrieve_model(api_key::String,
model_id::String;
http_kwargs::NamedTuple = NamedTuple())
return openai_request("models/$(model_id)",
api_key;
method = "GET",
http_kwargs = http_kwargs)
end

"""
Expand All @@ -267,8 +277,16 @@ For more details about the endpoint and additional arguments, visit <https://pla
# HTTP.request keyword arguments:
- `http_kwargs::NamedTuple=NamedTuple()`: Keyword arguments to pass to HTTP.request (e. g., `http_kwargs=(connection_timeout=2,)` to set a connection timeout of 2 seconds).
"""
function create_completion(api_key::String, model_id::String; http_kwargs::NamedTuple=NamedTuple(), kwargs...)
return openai_request("completions", api_key; method="POST", http_kwargs=http_kwargs, model=model_id, kwargs...)
function create_completion(api_key::String,
model_id::String;
http_kwargs::NamedTuple = NamedTuple(),
kwargs...)
return openai_request("completions",
api_key;
method = "POST",
http_kwargs = http_kwargs,
model = model_id,
kwargs...)
end

"""
Expand Down Expand Up @@ -341,11 +359,35 @@ julia> map(r->r["choices"][1]["delta"], CC.response)
{}
```
"""
function create_chat(api_key::String, model_id::String, messages; http_kwargs::NamedTuple=NamedTuple(), streamcallback=nothing, kwargs...)
return openai_request("chat/completions", api_key; method="POST", http_kwargs=http_kwargs, model=model_id, messages=messages, streamcallback=streamcallback, kwargs...)
function create_chat(api_key::String,
model_id::String,
messages;
http_kwargs::NamedTuple = NamedTuple(),
streamcallback = nothing,
kwargs...)
return openai_request("chat/completions",
api_key;
method = "POST",
http_kwargs = http_kwargs,
model = model_id,
messages = messages,
streamcallback = streamcallback,
kwargs...)
end
function create_chat(provider::AbstractOpenAIProvider, model_id::String, messages; http_kwargs::NamedTuple=NamedTuple(), streamcallback=nothing, kwargs...)
return openai_request("chat/completions", provider; method="POST", http_kwargs=http_kwargs, model=model_id, messages=messages, streamcallback=streamcallback, kwargs...)
function create_chat(provider::AbstractOpenAIProvider,
model_id::String,
messages;
http_kwargs::NamedTuple = NamedTuple(),
streamcallback = nothing,
kwargs...)
return openai_request("chat/completions",
provider;
method = "POST",
http_kwargs = http_kwargs,
model = model_id,
messages = messages,
streamcallback = streamcallback,
kwargs...)
end

"""
Expand All @@ -363,8 +405,18 @@ Create edit

For additional details about the endpoint, visit <https://platform.openai.com/docs/api-reference/edits>
"""
function create_edit(api_key::String, model_id::String, instruction::String; http_kwargs::NamedTuple=NamedTuple(), kwargs...)
return openai_request("edits", api_key; method="POST", http_kwargs=http_kwargs, model=model_id, instruction, kwargs...)
function create_edit(api_key::String,
model_id::String,
instruction::String;
http_kwargs::NamedTuple = NamedTuple(),
kwargs...)
return openai_request("edits",
api_key;
method = "POST",
http_kwargs = http_kwargs,
model = model_id,
instruction,
kwargs...)
end

"""
Expand All @@ -382,8 +434,18 @@ Create embeddings

For additional details about the endpoint, visit <https://platform.openai.com/docs/api-reference/embeddings>
"""
function create_embeddings(api_key::String, input, model_id::String=DEFAULT_EMBEDDING_MODEL_ID; http_kwargs::NamedTuple=NamedTuple(), kwargs...)
return openai_request("embeddings", api_key; method="POST", http_kwargs=http_kwargs, model=model_id, input, kwargs...)
function create_embeddings(api_key::String,
input,
model_id::String = DEFAULT_EMBEDDING_MODEL_ID;
http_kwargs::NamedTuple = NamedTuple(),
kwargs...)
return openai_request("embeddings",
api_key;
method = "POST",
http_kwargs = http_kwargs,
model = model_id,
input,
kwargs...)
end

"""
Expand All @@ -405,8 +467,18 @@ For additional details about the endpoint, visit <https://platform.openai.com/do
download like this:
`download(r.response["data"][begin]["url"], "image.png")`
"""
function create_images(api_key::String, prompt, n::Integer=1, size::String="256x256"; http_kwargs::NamedTuple=NamedTuple(), kwargs...)
return openai_request("images/generations", api_key; method="POST", http_kwargs=http_kwargs, prompt, kwargs...)
function create_images(api_key::String,
prompt,
n::Integer = 1,
size::String = "256x256";
http_kwargs::NamedTuple = NamedTuple(),
kwargs...)
return openai_request("images/generations",
api_key;
method = "POST",
http_kwargs = http_kwargs,
prompt,
kwargs...)
end

"""
Expand Down Expand Up @@ -463,14 +535,14 @@ Each element of `daily_costs` looks like this:
}
```
"""
function get_usage_status(provider::OpenAIProvider; numofdays::Int=99)
function get_usage_status(provider::OpenAIProvider; numofdays::Int = 99)
(; base_url, api_key) = provider
isempty(api_key) && throw(ArgumentError("api_key cannot be empty"))
numofdays > 99 && throw(ArgumentError("numofdays cannot be greater than 99"))

# Get total quota from subscription_url
subscription_url = "$base_url/dashboard/billing/subscription"
subscrip = HTTP.get(subscription_url, headers=auth_header(provider))
subscrip = HTTP.get(subscription_url, headers = auth_header(provider))
resp = OpenAIResponse(subscrip.status, JSON3.read(subscrip.body))
# TODO: catch error
quota = resp.response.hard_limit_usd
Expand All @@ -479,7 +551,7 @@ function get_usage_status(provider::OpenAIProvider; numofdays::Int=99)
start_date = today()
end_date = today() + Day(numofdays)
billing_url = "$base_url/dashboard/billing/usage?start_date=$(start_date)&end_date=$(end_date)"
billing = HTTP.get(billing_url, headers=auth_header(provider))
billing = HTTP.get(billing_url, headers = auth_header(provider))
resp = OpenAIResponse(billing.status, JSON3.read(billing.body))
usage = resp.response.total_usage / 100
daily_costs = resp.response.daily_costs
Expand Down
Loading