Skip to content

Commit

Permalink
feat(image): support response_type in the OpenAI API request (#2347)
Browse files Browse the repository at this point in the history
* Change response_format type to string to match OpenAI Spec

Signed-off-by: prajwal <[email protected]>

* updated response_type type to interface

Signed-off-by: prajwal <[email protected]>

* feat: correctly parse generic struct

Signed-off-by: mudler <[email protected]>

* add tests

Signed-off-by: mudler <[email protected]>

---------

Signed-off-by: prajwal <[email protected]>
Signed-off-by: mudler <[email protected]>
Co-authored-by: Ettore Di Giacinto <[email protected]>
Co-authored-by: mudler <[email protected]>
  • Loading branch information
3 people authored May 29, 2024
1 parent 087bcec commit 4d98dd9
Show file tree
Hide file tree
Showing 7 changed files with 57 additions and 13 deletions.
8 changes: 5 additions & 3 deletions core/config/backend_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,11 @@ type BackendConfig struct {
Backend string `yaml:"backend"`
TemplateConfig TemplateConfig `yaml:"template"`

PromptStrings, InputStrings []string `yaml:"-"`
InputToken [][]int `yaml:"-"`
functionCallString, functionCallNameString string `yaml:"-"`
PromptStrings, InputStrings []string `yaml:"-"`
InputToken [][]int `yaml:"-"`
functionCallString, functionCallNameString string `yaml:"-"`
ResponseFormat string `yaml:"-"`
ResponseFormatMap map[string]interface{} `yaml:"-"`

FunctionsConfig functions.FunctionsConfig `yaml:"function"`

Expand Down
9 changes: 7 additions & 2 deletions core/http/endpoints/openai/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -183,8 +183,13 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup
noActionDescription = config.FunctionsConfig.NoActionDescriptionName
}

if input.ResponseFormat.Type == "json_object" {
input.Grammar = functions.JSONBNF
if config.ResponseFormatMap != nil {
d := schema.ChatCompletionResponseFormat{}
dat, _ := json.Marshal(config.ResponseFormatMap)
_ = json.Unmarshal(dat, &d)
if d.Type == "json_object" {
input.Grammar = functions.JSONBNF
}
}

config.Grammar = input.Grammar
Expand Down
9 changes: 7 additions & 2 deletions core/http/endpoints/openai/completion.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,13 @@ func CompletionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, a
return fmt.Errorf("failed reading parameters from request:%w", err)
}

if input.ResponseFormat.Type == "json_object" {
input.Grammar = functions.JSONBNF
if config.ResponseFormatMap != nil {
d := schema.ChatCompletionResponseFormat{}
dat, _ := json.Marshal(config.ResponseFormatMap)
_ = json.Unmarshal(dat, &d)
if d.Type == "json_object" {
input.Grammar = functions.JSONBNF
}
}

config.Grammar = input.Grammar
Expand Down
6 changes: 2 additions & 4 deletions core/http/endpoints/openai/image.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,10 +149,8 @@ func ImageEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appCon
return fmt.Errorf("invalid value for 'size'")
}

b64JSON := false
if input.ResponseFormat.Type == "b64_json" {
b64JSON = true
}
b64JSON := config.ResponseFormat == "b64_json"

// src and clip_skip
var result []schema.Item
for _, i := range config.PromptStrings {
Expand Down
9 changes: 9 additions & 0 deletions core/http/endpoints/openai/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,15 @@ func updateRequestConfig(config *config.BackendConfig, input *schema.OpenAIReque
config.Maxtokens = input.Maxtokens
}

if input.ResponseFormat != nil {
switch responseFormat := input.ResponseFormat.(type) {
case string:
config.ResponseFormat = responseFormat
case map[string]interface{}:
config.ResponseFormatMap = responseFormat
}
}

switch stop := input.Stop.(type) {
case string:
if stop != "" {
Expand Down
4 changes: 3 additions & 1 deletion core/schema/openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,8 @@ type OpenAIModel struct {
Object string `json:"object"`
}

type ImageGenerationResponseFormat string

type ChatCompletionResponseFormatType string

type ChatCompletionResponseFormat struct {
Expand All @@ -114,7 +116,7 @@ type OpenAIRequest struct {
// whisper
File string `json:"file" validate:"required"`
//whisper/image
ResponseFormat ChatCompletionResponseFormat `json:"response_format"`
ResponseFormat interface{} `json:"response_format,omitempty"`
// image
Size string `json:"size"`
// Prompt is read only by completion/image API calls
Expand Down
25 changes: 24 additions & 1 deletion tests/e2e-aio/e2e_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,13 +123,36 @@ var _ = Describe("E2E test", func() {
openai.ImageRequest{
Prompt: "test",
Size: openai.CreateImageSize512x512,
//ResponseFormat: openai.CreateImageResponseFormatURL,
},
)
Expect(err).ToNot(HaveOccurred())
Expect(len(resp.Data)).To(Equal(1), fmt.Sprint(resp))
Expect(resp.Data[0].URL).To(ContainSubstring("png"), fmt.Sprint(resp.Data[0].URL))
})
It("correctly changes the response format to url", func() {
resp, err := client.CreateImage(context.TODO(),
openai.ImageRequest{
Prompt: "test",
Size: openai.CreateImageSize512x512,
ResponseFormat: openai.CreateImageResponseFormatURL,
},
)
Expect(err).ToNot(HaveOccurred())
Expect(len(resp.Data)).To(Equal(1), fmt.Sprint(resp))
Expect(resp.Data[0].URL).To(ContainSubstring("png"), fmt.Sprint(resp.Data[0].URL))
})
It("correctly changes the response format to base64", func() {
resp, err := client.CreateImage(context.TODO(),
openai.ImageRequest{
Prompt: "test",
Size: openai.CreateImageSize512x512,
ResponseFormat: openai.CreateImageResponseFormatB64JSON,
},
)
Expect(err).ToNot(HaveOccurred())
Expect(len(resp.Data)).To(Equal(1), fmt.Sprint(resp))
Expect(resp.Data[0].B64JSON).ToNot(BeEmpty(), fmt.Sprint(resp.Data[0].B64JSON))
})
})
Context("embeddings", func() {
It("correctly", func() {
Expand Down

0 comments on commit 4d98dd9

Please sign in to comment.