Skip to content

Commit

Permalink
feat(templates): use a single template for multimodals messages (mudl…
Browse files Browse the repository at this point in the history
…er#3892)

Signed-off-by: Ettore Di Giacinto <[email protected]>
  • Loading branch information
mudler authored Oct 22, 2024
1 parent a1d6cc9 commit ccc7cb0
Show file tree
Hide file tree
Showing 4 changed files with 140 additions and 29 deletions.
4 changes: 1 addition & 3 deletions core/config/backend_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -197,9 +197,7 @@ type TemplateConfig struct {
// It defaults to \n
JoinChatMessagesByCharacter *string `yaml:"join_chat_messages_by_character"`

Video string `yaml:"video"`
Image string `yaml:"image"`
Audio string `yaml:"audio"`
Multimodal string `yaml:"multimodal"`
}

func (c *BackendConfig) UnmarshalYAML(value *yaml.Node) error {
Expand Down
43 changes: 23 additions & 20 deletions core/http/endpoints/openai/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,18 +149,27 @@ func updateRequestConfig(config *config.BackendConfig, input *schema.OpenAIReque
// Decode each request's message content
imgIndex, vidIndex, audioIndex := 0, 0, 0
for i, m := range input.Messages {
nrOfImgsInMessage := 0
nrOfVideosInMessage := 0
nrOfAudiosInMessage := 0

switch content := m.Content.(type) {
case string:
input.Messages[i].StringContent = content
case []interface{}:
dat, _ := json.Marshal(content)
c := []schema.Content{}
json.Unmarshal(dat, &c)

textContent := ""
// we will template this at the end

CONTENT:
for _, pp := range c {
switch pp.Type {
case "text":
input.Messages[i].StringContent = pp.Text
textContent += pp.Text
//input.Messages[i].StringContent = pp.Text
case "video", "video_url":
// Decode content as base64 either if it's an URL or base64 text
base64, err := utils.GetContentURIAsBase64(pp.VideoURL.URL)
Expand All @@ -169,14 +178,8 @@ func updateRequestConfig(config *config.BackendConfig, input *schema.OpenAIReque
continue CONTENT
}
input.Messages[i].StringVideos = append(input.Messages[i].StringVideos, base64) // TODO: make sure that we only return base64 stuff

t := "[vid-{{.ID}}]{{.Text}}"
if config.TemplateConfig.Video != "" {
t = config.TemplateConfig.Video
}
// set a placeholder for each image
input.Messages[i].StringContent, _ = templates.TemplateMultiModal(t, vidIndex, input.Messages[i].StringContent)
vidIndex++
nrOfVideosInMessage++
case "audio_url", "audio":
// Decode content as base64 either if it's an URL or base64 text
base64, err := utils.GetContentURIAsBase64(pp.AudioURL.URL)
Expand All @@ -185,13 +188,8 @@ func updateRequestConfig(config *config.BackendConfig, input *schema.OpenAIReque
continue CONTENT
}
input.Messages[i].StringAudios = append(input.Messages[i].StringAudios, base64) // TODO: make sure that we only return base64 stuff
// set a placeholder for each image
t := "[audio-{{.ID}}]{{.Text}}"
if config.TemplateConfig.Audio != "" {
t = config.TemplateConfig.Audio
}
input.Messages[i].StringContent, _ = templates.TemplateMultiModal(t, audioIndex, input.Messages[i].StringContent)
audioIndex++
nrOfAudiosInMessage++
case "image_url", "image":
// Decode content as base64 either if it's an URL or base64 text
base64, err := utils.GetContentURIAsBase64(pp.ImageURL.URL)
Expand All @@ -200,16 +198,21 @@ func updateRequestConfig(config *config.BackendConfig, input *schema.OpenAIReque
continue CONTENT
}

t := "[img-{{.ID}}]{{.Text}}"
if config.TemplateConfig.Image != "" {
t = config.TemplateConfig.Image
}
input.Messages[i].StringImages = append(input.Messages[i].StringImages, base64) // TODO: make sure that we only return base64 stuff
// set a placeholder for each image
input.Messages[i].StringContent, _ = templates.TemplateMultiModal(t, imgIndex, input.Messages[i].StringContent)

imgIndex++
nrOfImgsInMessage++
}
}

input.Messages[i].StringContent, _ = templates.TemplateMultiModal(config.TemplateConfig.Multimodal, templates.MultiModalOptions{
TotalImages: imgIndex,
TotalVideos: vidIndex,
TotalAudios: audioIndex,
ImagesInMessage: nrOfImgsInMessage,
VideosInMessage: nrOfVideosInMessage,
AudiosInMessage: nrOfAudiosInMessage,
}, textContent)
}
}

Expand Down
50 changes: 45 additions & 5 deletions pkg/templates/multimodal.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,60 @@ import (
"github.com/Masterminds/sprig/v3"
)

func TemplateMultiModal(templateString string, templateID int, text string) (string, error) {
type MultiModalOptions struct {
TotalImages int
TotalAudios int
TotalVideos int

ImagesInMessage int
AudiosInMessage int
VideosInMessage int
}

type MultimodalContent struct {
ID int
}

const DefaultMultiModalTemplate = "{{ range .Audio }}[audio-{{.ID}}]{{end}}{{ range .Images }}[img-{{.ID}}]{{end}}{{ range .Video }}[vid-{{.ID}}]{{end}}{{.Text}}"

func TemplateMultiModal(templateString string, opts MultiModalOptions, text string) (string, error) {
if templateString == "" {
templateString = DefaultMultiModalTemplate
}

// compile the template
tmpl, err := template.New("template").Funcs(sprig.FuncMap()).Parse(templateString)
if err != nil {
return "", err
}

videos := []MultimodalContent{}
for i := 0; i < opts.VideosInMessage; i++ {
videos = append(videos, MultimodalContent{ID: i + (opts.TotalVideos - opts.VideosInMessage)})
}

audios := []MultimodalContent{}
for i := 0; i < opts.AudiosInMessage; i++ {
audios = append(audios, MultimodalContent{ID: i + (opts.TotalAudios - opts.AudiosInMessage)})
}

images := []MultimodalContent{}
for i := 0; i < opts.ImagesInMessage; i++ {
images = append(images, MultimodalContent{ID: i + (opts.TotalImages - opts.ImagesInMessage)})
}

result := bytes.NewBuffer(nil)
// execute the template
err = tmpl.Execute(result, struct {
ID int
Text string
Audio []MultimodalContent
Images []MultimodalContent
Video []MultimodalContent
Text string
}{
ID: templateID,
Text: text,
Audio: audios,
Images: images,
Video: videos,
Text: text,
})
return result.String(), err
}
72 changes: 71 additions & 1 deletion pkg/templates/multimodal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,77 @@ import (
var _ = Describe("EvaluateTemplate", func() {
Context("templating simple strings for multimodal chat", func() {
It("should template messages correctly", func() {
result, err := TemplateMultiModal("[img-{{.ID}}]{{.Text}}", 1, "bar")
result, err := TemplateMultiModal("", MultiModalOptions{
TotalImages: 1,
TotalAudios: 0,
TotalVideos: 0,
ImagesInMessage: 1,
AudiosInMessage: 0,
VideosInMessage: 0,
}, "bar")
Expect(err).NotTo(HaveOccurred())
Expect(result).To(Equal("[img-0]bar"))
})

It("should handle messages with more images correctly", func() {
result, err := TemplateMultiModal("", MultiModalOptions{
TotalImages: 2,
TotalAudios: 0,
TotalVideos: 0,
ImagesInMessage: 2,
AudiosInMessage: 0,
VideosInMessage: 0,
}, "bar")
Expect(err).NotTo(HaveOccurred())
Expect(result).To(Equal("[img-0][img-1]bar"))
})
It("should handle messages with more images correctly", func() {
result, err := TemplateMultiModal("", MultiModalOptions{
TotalImages: 4,
TotalAudios: 1,
TotalVideos: 0,
ImagesInMessage: 2,
AudiosInMessage: 1,
VideosInMessage: 0,
}, "bar")
Expect(err).NotTo(HaveOccurred())
Expect(result).To(Equal("[audio-0][img-2][img-3]bar"))
})
It("should handle messages with more images correctly", func() {
result, err := TemplateMultiModal("", MultiModalOptions{
TotalImages: 3,
TotalAudios: 1,
TotalVideos: 0,
ImagesInMessage: 1,
AudiosInMessage: 1,
VideosInMessage: 0,
}, "bar")
Expect(err).NotTo(HaveOccurred())
Expect(result).To(Equal("[audio-0][img-2]bar"))
})
It("should handle messages with more images correctly", func() {
result, err := TemplateMultiModal("", MultiModalOptions{
TotalImages: 0,
TotalAudios: 0,
TotalVideos: 0,
ImagesInMessage: 0,
AudiosInMessage: 0,
VideosInMessage: 0,
}, "bar")
Expect(err).NotTo(HaveOccurred())
Expect(result).To(Equal("bar"))
})
})
Context("templating with custom defaults", func() {
It("should handle messages with more images correctly", func() {
result, err := TemplateMultiModal("{{ range .Audio }}[audio-{{ add1 .ID}}]{{end}}{{ range .Images }}[img-{{ add1 .ID}}]{{end}}{{ range .Video }}[vid-{{ add1 .ID}}]{{end}}{{.Text}}", MultiModalOptions{
TotalImages: 1,
TotalAudios: 0,
TotalVideos: 0,
ImagesInMessage: 1,
AudiosInMessage: 0,
VideosInMessage: 0,
}, "bar")
Expect(err).NotTo(HaveOccurred())
Expect(result).To(Equal("[img-1]bar"))
})
Expand Down

0 comments on commit ccc7cb0

Please sign in to comment.