Commit
Added support for audio and chat endpoints (#7)
* Added chat package

* Added chat_test.go

* Closes #5

* Updated README

* Stubbed out Audio package

* audio endpoint implementation

* Fixed audio endpoint implementation and added tests

* Closes #6

* Added relative path specifier to example audio files

* Revert "Added relative path specifier to example audio files"

This reverts commit 8020c00.
Kardbord authored Apr 8, 2023
1 parent 35c4918 commit 97b38e0
Showing 14 changed files with 511 additions and 2 deletions.
6 changes: 4 additions & 2 deletions .gitignore
@@ -23,12 +23,14 @@ examples/**/.env
 .vscode/
 
 # Example binaries
+dist/
+examples/audio/audio
+examples/chat/chat
 examples/completions/completions
 examples/edits/edits
 examples/embeddings/embeddings
 examples/files/files
 examples/finetunes/finetunes
 examples/images/images
 examples/models/models
 examples/moderations/moderations
-dist/
-examples/moderations/moderations
2 changes: 2 additions & 0 deletions README.md
@@ -9,6 +9,8 @@ Go (Golang) bindings for the [OpenAI API](https://beta.openai.com/docs/api-refer

The links below lead to examples of how to use each library package.

- [x] [Audio](./audio/README.md)
- [x] [Chat](./chat/README.md)
- [x] [Completions](./completions/README.md)
- [x] [Edits](./edits/README.md)
- [x] [Embeddings](./embeddings/README.md)
7 changes: 7 additions & 0 deletions audio/README.md
@@ -0,0 +1,7 @@
# Audio

Bindings for the [audio](https://platform.openai.com/docs/api-reference/audio) [endpoint](https://api.openai.com/v1/audio/transcriptions).

## Example

See [audio-example.go](../examples/audio/audio-example.go).
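
If the examples directory isn't handy, here is a minimal transcription sketch against the API defined in audio.go below. The file path is hypothetical and authentication mirrors the test setup; treat it as a sketch, not the repository's official example:

```go
package main

import (
	"fmt"
	"log"
	"os"

	"github.com/TannerKvarfordt/gopenai/audio"
	"github.com/TannerKvarfordt/gopenai/authentication"
)

func main() {
	// Authenticate with the key from the environment, as the tests do.
	authentication.SetAPIKey(os.Getenv("OPENAI_API_KEY"))

	// Transcribe a local audio file with the whisper-1 model.
	resp, err := audio.MakeTranscriptionRequest(&audio.TranscriptionRequest{
		File:  "./speech.m4a", // hypothetical path
		Model: "whisper-1",
	}, nil)
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(resp.Text)
}
```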
154 changes: 154 additions & 0 deletions audio/audio.go
@@ -0,0 +1,154 @@
// Package audio provides bindings for the [audio] [endpoint].
// Converts audio into text.
//
// [audio]: https://platform.openai.com/docs/api-reference/audio
// [endpoint]: https://api.openai.com/v1/audio/transcriptions
package audio

import (
"bytes"
"errors"
"mime/multipart"
"net/http"
"path/filepath"

"github.com/TannerKvarfordt/gopenai/common"
)

const (
BaseEndpoint = common.BaseURL + "audio/"
TranscriptionEndpoint = BaseEndpoint + "transcriptions"
TranslationEndpoint = BaseEndpoint + "translations"
)

type ResponseFormat = string

const (
// TODO: Support non-json return formats.
JSONResponseFormat = "json"
// TextResponseFormat = "text"
// SRTResponseFormat = "srt"
// VerboseJSONResponseFormat = "verbose_json"
// VTTResponseFormat = "vtt"
)

// Request structure for the transcription endpoint.
type TranscriptionRequest struct {
// The audio file to transcribe, in one of these formats:
// mp3, mp4, mpeg, mpga, m4a, wav, or webm.
// This can be a file path or a URL.
File string `json:"file"`

// ID of the model to use. You can use the List models API
// to see all of your available models, or see our Model
// overview for descriptions of them.
Model string `json:"model"`

// An optional text to guide the model's style or continue a
// previous audio segment. The prompt should match the audio language.
Prompt string `json:"prompt,omitempty"`

// The format of the transcript output, in one of these options:
// json, text, srt, verbose_json, or vtt.
ResponseFormat ResponseFormat `json:"response_format,omitempty"`

// The sampling temperature, between 0 and 1. Higher values like 0.8 will
// make the output more random, while lower values like 0.2 will make it
// more focused and deterministic. If set to 0, the model will use log
// probability to automatically increase the temperature until certain
// thresholds are hit.
Temperature *float64 `json:"temperature,omitempty"`

// The language of the input audio. Supplying the input language in
// ISO-639-1 format will improve accuracy and latency.
Language string `json:"language,omitempty"`
}

// Request structure for the Translations endpoint.
type TranslationRequest struct {
// The audio file to translate, in one of these formats:
// mp3, mp4, mpeg, mpga, m4a, wav, or webm.
// This can be a file path or a URL.
File string `json:"file"`

// ID of the model to use. You can use the List models API
// to see all of your available models, or see our Model
// overview for descriptions of them.
Model string `json:"model"`

// An optional text to guide the model's style or continue a
// previous audio segment. The prompt should be in English.
Prompt string `json:"prompt,omitempty"`

// The format of the transcript output, in one of these options:
// json, text, srt, verbose_json, or vtt.
ResponseFormat ResponseFormat `json:"response_format,omitempty"`

// The sampling temperature, between 0 and 1. Higher values like 0.8 will
// make the output more random, while lower values like 0.2 will make it
// more focused and deterministic. If set to 0, the model will use log
// probability to automatically increase the temperature until certain
// thresholds are hit.
Temperature *float64 `json:"temperature,omitempty"`
}

// Response structure for both Transcription and
// Translation requests.
type Response struct {
Text string `json:"text"`
Usage common.ResponseUsage `json:"usage"`
Error *common.ResponseError `json:"error,omitempty"`
}

// MakeTranscriptionRequest transcribes the audio file named in the request,
// returning the parsed API response.
func MakeTranscriptionRequest(request *TranscriptionRequest, organizationID *string) (*Response, error) {
buf := new(bytes.Buffer)
writer := multipart.NewWriter(buf)
err := common.CreateFormField("model", request.Model, writer)
if err != nil {
return nil, err
}

err = common.CreateFormFile("file", filepath.Base(request.File), request.File, writer)
if err != nil {
return nil, err
}
writer.Close()
r, err := common.MakeRequestWithForm[Response](buf, TranscriptionEndpoint, http.MethodPost, writer.FormDataContentType(), organizationID)
if err != nil {
return nil, err
}
if r == nil {
return nil, errors.New("nil response received")
}
if r.Error != nil {
return r, r.Error
}
return r, nil
}

// MakeTranslationRequest translates the audio file named in the request
// into English, returning the parsed API response.
func MakeTranslationRequest(request *TranslationRequest, organizationID *string) (*Response, error) {
buf := new(bytes.Buffer)
writer := multipart.NewWriter(buf)
err := common.CreateFormField("model", request.Model, writer)
if err != nil {
return nil, err
}

err = common.CreateFormFile("file", filepath.Base(request.File), request.File, writer)
if err != nil {
return nil, err
}
writer.Close()
r, err := common.MakeRequestWithForm[Response](buf, TranslationEndpoint, http.MethodPost, writer.FormDataContentType(), organizationID)
if err != nil {
return nil, err
}
if r == nil {
return nil, errors.New("nil response received")
}
if r.Error != nil {
return r, r.Error
}
return r, nil
}
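
Since Temperature is a *float64 (so an unset value can be distinguished from an explicit 0), callers pass it by address. A minimal translation sketch, with a hypothetical file path and the whisper-1 model assumed:

```go
package main

import (
	"fmt"
	"log"
	"os"

	"github.com/TannerKvarfordt/gopenai/audio"
	"github.com/TannerKvarfordt/gopenai/authentication"
)

func main() {
	authentication.SetAPIKey(os.Getenv("OPENAI_API_KEY"))

	// Take the address of a local to set the optional temperature.
	temp := 0.2
	resp, err := audio.MakeTranslationRequest(&audio.TranslationRequest{
		File:        "./speech-de.m4a", // hypothetical path
		Model:       "whisper-1",
		Temperature: &temp,
	}, nil)
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(resp.Text) // English translation of the audio
}
```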
51 changes: 51 additions & 0 deletions audio/audio_test.go
@@ -0,0 +1,51 @@
package audio_test

import (
"os"
"testing"

"github.com/TannerKvarfordt/gopenai/audio"
"github.com/TannerKvarfordt/gopenai/authentication"
)

const (
OpenAITokenEnv = "OPENAI_API_KEY"
transcriptionFilePath = "./test_files/transcription.m4a"
translationFilePath = "./test_files/translation.m4a"
model = "whisper-1"
)

func init() {
key := os.Getenv(OpenAITokenEnv)
authentication.SetAPIKey(key)
}

func TestTranscription(t *testing.T) {
resp, err := audio.MakeTranscriptionRequest(&audio.TranscriptionRequest{
File: transcriptionFilePath,
Model: model,
}, nil)
if err != nil {
t.Fatal(err)
}
if resp.Text == "" {
t.Fatal("no transcription returned")
}
}

func TestTranslation(t *testing.T) {
resp, err := audio.MakeTranslationRequest(&audio.TranslationRequest{
File: translationFilePath,
Model: model,
}, nil)
if err != nil {
t.Fatal(err)
}
if resp.Text == "" {
t.Fatal("no translation returned")
}
}
Binary file added audio/test_files/transcription.m4a
Binary file not shown.
Binary file added audio/test_files/translation.m4a
Binary file not shown.
7 changes: 7 additions & 0 deletions chat/README.md
@@ -0,0 +1,7 @@
# Chat

Bindings for the [chat](https://platform.openai.com/docs/api-reference/chat) [endpoint](https://api.openai.com/v1/chat/completions).

## Example

See [chat-example.go](../examples/chat/chat-example.go).
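
As a quick orientation before the package source below, a minimal chat completion sketch; the model name is an assumption, not something this commit pins down:

```go
package main

import (
	"fmt"
	"log"
	"os"

	"github.com/TannerKvarfordt/gopenai/authentication"
	"github.com/TannerKvarfordt/gopenai/chat"
)

func main() {
	authentication.SetAPIKey(os.Getenv("OPENAI_API_KEY"))

	resp, err := chat.MakeRequest(&chat.Request{
		Model: "gpt-3.5-turbo", // assumed model name
		Messages: []chat.Chat{
			{Role: chat.SystemRole, Content: "You are a helpful assistant."},
			{Role: chat.UserRole, Content: "Say hello."},
		},
	}, nil)
	if err != nil {
		log.Fatal(err)
	}
	// MakeRequest guarantees at least one choice on success.
	fmt.Println(resp.Choices[0].Message.Content)
}
```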
139 changes: 139 additions & 0 deletions chat/chat.go
@@ -0,0 +1,139 @@
// Package chat provides bindings for the [chat] [endpoint].
// Given a chat conversation, the model will return a chat
// completion response.
//
// [chat]: https://platform.openai.com/docs/api-reference/chat
// [endpoint]: https://api.openai.com/v1/chat/completions
package chat

import (
"errors"
"net/http"

"github.com/TannerKvarfordt/gopenai/common"
"github.com/TannerKvarfordt/gopenai/moderations"
)

const Endpoint = common.BaseURL + "chat/completions"

type Role = string

const (
SystemRole Role = "system"
UserRole Role = "user"
AssistantRole Role = "assistant"
)

type Chat struct {
Role Role `json:"role"`
Content string `json:"content"`
}

// Request structure for the chat API endpoint.
type Request struct {
// ID of the model to use. You can use the List models API
// to see all of your available models, or see our Model
// overview for descriptions of them.
Model string `json:"model"`

// The messages to generate chat completions for,
// in the [chat format].
//
// [chat format]: https://platform.openai.com/docs/guides/chat
Messages []Chat `json:"messages"`

// What sampling temperature to use, between 0 and 2. Higher values
// like 0.8 will make the output more random, while lower values like
// 0.2 will make it more focused and deterministic. We generally
// recommend altering this or top_p but not both.
Temperature *float64 `json:"temperature,omitempty"`

// An alternative to sampling with temperature, called nucleus sampling,
// where the model considers the results of the tokens with top_p
// probability mass. So 0.1 means only the tokens comprising the top 10%
// probability mass are considered.
// We generally recommend altering this or temperature but not both.
TopP *float64 `json:"top_p,omitempty"`

// How many chat completion choices to generate for each input message.
N *int64 `json:"n,omitempty"`

// If set, partial message deltas will be sent, like in ChatGPT. Tokens
// will be sent as data-only server-sent events as they become available,
// with the stream terminated by a data: [DONE] message. See the OpenAI
// Cookbook for example code.
// Stream bool `json:"stream,omitempty"` TODO: Add streaming support

// Up to 4 sequences where the API will stop generating further tokens.
Stop []string `json:"stop,omitempty"`

// The maximum number of tokens to generate in the chat completion.
// The total length of input tokens and generated tokens is limited
// by the model's context length.
MaxTokens *int64 `json:"max_tokens,omitempty"`

// Number between -2.0 and 2.0. Positive values penalize new tokens
// based on whether they appear in the text so far, increasing the
// model's likelihood to talk about new topics.
PresencePenalty *float64 `json:"presence_penalty,omitempty"`

// Modify the likelihood of specified tokens appearing in the completion.
// Accepts a json object that maps tokens (specified by their token ID in
// the tokenizer) to an associated bias value from -100 to 100. Mathematically,
// the bias is added to the logits generated by the model prior to sampling.
// The exact effect will vary per model, but values between -1 and 1 should decrease
// or increase likelihood of selection; values like -100 or 100 should result in a
// ban or exclusive selection of the relevant token.
LogitBias map[string]int64 `json:"logit_bias,omitempty"`

// A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse.
User string `json:"user,omitempty"`
}

type Response struct {
ID string `json:"id,omitempty"`
Object string `json:"object,omitempty"`
Created int64 `json:"created,omitempty"`
Choices []struct {
Index int64 `json:"index,omitempty"`
Message Chat `json:"message,omitempty"`
FinishReason string `json:"finish_reason,omitempty"`
} `json:"choices,omitempty"`
Usage common.ResponseUsage `json:"usage"`
Error *common.ResponseError `json:"error,omitempty"`
}

// MakeRequest sends the given chat completion request and returns the
// parsed API response.
func MakeRequest(request *Request, organizationID *string) (*Response, error) {
r, err := common.MakeRequest[Request, Response](request, Endpoint, http.MethodPost, organizationID)
if err != nil {
return nil, err
}
if r == nil {
return nil, errors.New("nil response received")
}
if r.Error != nil {
return r, r.Error
}
if len(r.Choices) == 0 {
return r, errors.New("no choices in response")
}
return r, nil
}

// MakeModeratedRequest runs every message in the request through the
// moderations endpoint before sending the chat request. It returns the
// chat response (if any) along with the moderation results.
func MakeModeratedRequest(request *Request, organizationID *string) (*Response, *moderations.Response, error) {
input := make([]string, len(request.Messages))
for i := range request.Messages {
input[i] = request.Messages[i].Content
}

modr, err := moderations.MakeModeratedRequest(&moderations.Request{
Input: input,
Model: moderations.ModelLatest,
}, organizationID)
if err != nil {
return nil, modr, err
}

r, err := MakeRequest(request, organizationID)
return r, modr, err
}
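
A sketch of the moderated variant, which surfaces both the chat response and the moderation results; again the model name is assumed:

```go
package main

import (
	"fmt"
	"log"
	"os"

	"github.com/TannerKvarfordt/gopenai/authentication"
	"github.com/TannerKvarfordt/gopenai/chat"
)

func main() {
	authentication.SetAPIKey(os.Getenv("OPENAI_API_KEY"))

	resp, modResp, err := chat.MakeModeratedRequest(&chat.Request{
		Model:    "gpt-3.5-turbo", // assumed model name
		Messages: []chat.Chat{{Role: chat.UserRole, Content: "Hello!"}},
	}, nil)
	if err != nil {
		// The error may come from the moderation pass or the chat call.
		log.Fatal(err)
	}
	_ = modResp // per-message moderation results, if needed
	fmt.Println(resp.Choices[0].Message.Content)
}
```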