diff --git a/src/OpenAI.jl b/src/OpenAI.jl index f059fc4..4845d95 100644 --- a/src/OpenAI.jl +++ b/src/OpenAI.jl @@ -21,7 +21,7 @@ end Default provider for OpenAI API requests. """ -const DEFAULT_PROVIDER = let +const DEFAULT_PROVIDER = let api_key = get(ENV, "OPENAI_API_KEY", nothing) if api_key === nothing OpenAIProvider() @@ -157,7 +157,7 @@ function _request(api::AbstractString, provider::AbstractOpenAIProvider, api_key lines = split(body, "\n") # split body into lines # throw out empty lines, skip "data: [DONE] bits - lines = filter(x -> !isempty(x) && !occursin("[DONE]", x), lines) + lines = filter(x -> !isempty(x) && !occursin("[DONE]", x), lines) # read each line, which looks like "data: {}" parsed = map(line -> JSON3.read(line[6:end]), lines) @@ -233,6 +233,57 @@ function create_completion(api_key::String, model_id::String; http_kwargs::Named return openai_request("completions", api_key; method="POST", http_kwargs=http_kwargs, model=model_id, kwargs...) end + +""" +Checking for JSON validity of the response body used in the stream +""" +isvalidjson(str) = + try + JSON3.read(str) + true + catch + false + end + + +""" +Default streamcallback function for create_ functions. +""" +function wordstream(chunk) + + # Regex to match the content value if the string is not valid JSON + regex = r""""content":"(.*?)"(?=[,}])""" + + # Process each line of the string + for line in split(strip(chunk), "\n") + # Remove 'data: ' prefix before checking if the line is JSON + json_str = replace(line, r"^data:\s+" => "") + + if isvalidjson(json_str) + # If it's valid JSON, parse it and extract the content value + json_obj = JSON3.read(json_str) + + if typeof(json_obj) == JSON3.Object{Base.CodeUnits{UInt8,String},Vector{UInt64}} && haskey(json_obj, "choices") + choices = json_obj.choices + if haskey(choices[1], "delta") && haskey(choices[1].delta, "content") + content = choices[1].delta.content + print("$content") + else + print("END") + end + end + else + # Otherwise, apply the regex pattern + m = match(regex, json_str) + if m !== nothing + content = m.captures[1] + print("$content") + end + end + end +end + + """ Create chat @@ -434,7 +485,7 @@ function get_usage_status(provider::OpenAIProvider; numofdays::Int=99) # Get total quota from subscription_url subscription_url = "$base_url/dashboard/billing/subscription" - subscrip = HTTP.get(subscription_url, headers = auth_header(provider)) + subscrip = HTTP.get(subscription_url, headers=auth_header(provider)) resp = OpenAIResponse(subscrip.status, JSON3.read(subscrip.body)) # TODO: catch error quota = resp.response.hard_limit_usd @@ -443,7 +494,7 @@ function get_usage_status(provider::OpenAIProvider; numofdays::Int=99) start_date = today() end_date = today() + Day(numofdays) billing_url = "$base_url/dashboard/billing/usage?start_date=$(start_date)&end_date=$(end_date)" - billing = HTTP.get(billing_url, headers = auth_header(provider)) + billing = HTTP.get(billing_url, headers=auth_header(provider)) resp = OpenAIResponse(billing.status, JSON3.read(billing.body)) usage = resp.response.total_usage / 100 daily_costs = resp.response.daily_costs