Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added the isjsonvalid() and the wordstream() functions #52

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 55 additions & 4 deletions src/OpenAI.jl
Original file line number Diff line number Diff line change
@@ -21,7 +21,7 @@ end
Default provider for OpenAI API requests.
"""
const DEFAULT_PROVIDER = let
const DEFAULT_PROVIDER = let
api_key = get(ENV, "OPENAI_API_KEY", nothing)
if api_key === nothing
OpenAIProvider()
@@ -157,7 +157,7 @@ function _request(api::AbstractString, provider::AbstractOpenAIProvider, api_key
lines = split(body, "\n") # split body into lines

# throw out empty lines, skip "data: [DONE] bits
lines = filter(x -> !isempty(x) && !occursin("[DONE]", x), lines)
lines = filter(x -> !isempty(x) && !occursin("[DONE]", x), lines)

# read each line, which looks like "data: {<json elements>}"
parsed = map(line -> JSON3.read(line[6:end]), lines)
@@ -233,6 +233,57 @@ function create_completion(api_key::String, model_id::String; http_kwargs::Named
return openai_request("completions", api_key; method="POST", http_kwargs=http_kwargs, model=model_id, kwargs...)
end


"""
Checking for JSON validity of the response body used in the stream
"""
isvalidjson(str) =
try
JSON3.read(str)
true
catch
false
end
Comment on lines +240 to +246
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Minor note but I typically prefer function ... end syntax for multiline functions.

Suggested change
isvalidjson(str) =
try
JSON3.read(str)
true
catch
false
end
function isvalidjson(str)
try
JSON3.read(str)
true
catch
false
end
end



"""
Default streamcallback function for create_<action> functions.
"""
Comment on lines +249 to +251
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you add a little usage section here to show how people should use this?

function wordstream(chunk)

# Regex to match the content value if the string is not valid JSON
regex = r""""content":"(.*?)"(?=[,}])"""

# Process each line of the string
for line in split(strip(chunk), "\n")
# Remove 'data: ' prefix before checking if the line is JSON
json_str = replace(line, r"^data:\s+" => "")

if isvalidjson(json_str)
# If it's valid JSON, parse it and extract the content value
json_obj = JSON3.read(json_str)

if typeof(json_obj) == JSON3.Object{Base.CodeUnits{UInt8,String},Vector{UInt64}} && haskey(json_obj, "choices")
choices = json_obj.choices
if haskey(choices[1], "delta") && haskey(choices[1].delta, "content")
content = choices[1].delta.content
print("$content")
else
print("END")
end
end
else
# Otherwise, apply the regex pattern
m = match(regex, json_str)
if m !== nothing
content = m.captures[1]
print("$content")
end
end
end
end


"""
Create chat
@@ -434,7 +485,7 @@ function get_usage_status(provider::OpenAIProvider; numofdays::Int=99)

# Get total quota from subscription_url
subscription_url = "$base_url/dashboard/billing/subscription"
subscrip = HTTP.get(subscription_url, headers = auth_header(provider))
subscrip = HTTP.get(subscription_url, headers=auth_header(provider))
resp = OpenAIResponse(subscrip.status, JSON3.read(subscrip.body))
# TODO: catch error
quota = resp.response.hard_limit_usd
@@ -443,7 +494,7 @@ function get_usage_status(provider::OpenAIProvider; numofdays::Int=99)
start_date = today()
end_date = today() + Day(numofdays)
billing_url = "$base_url/dashboard/billing/usage?start_date=$(start_date)&end_date=$(end_date)"
billing = HTTP.get(billing_url, headers = auth_header(provider))
billing = HTTP.get(billing_url, headers=auth_header(provider))
resp = OpenAIResponse(billing.status, JSON3.read(billing.body))
usage = resp.response.total_usage / 100
daily_costs = resp.response.daily_costs