From 85ed9ae2d7ce5a103ec86e4a6ec2e1add65c5b12 Mon Sep 17 00:00:00 2001 From: Dave Lee Date: Wed, 23 Oct 2024 22:46:46 -0400 Subject: [PATCH] centralized request middleware Signed-off-by: Dave Lee --- core/backend/llm.go | 6 +- core/backend/rerank.go | 4 +- core/backend/soundgeneration.go | 5 +- core/backend/tokenize.go | 6 +- core/backend/tts.go | 41 +- core/cli/soundgeneration.go | 3 +- core/cli/tts.go | 4 +- core/config/backend_config.go | 31 +- core/config/backend_config_loader.go | 30 +- core/config/guesser.go | 9 +- core/http/app.go | 11 +- core/http/ctx/fiber.go | 47 -- .../endpoints/elevenlabs/soundgeneration.go | 42 +- core/http/endpoints/elevenlabs/tts.go | 36 +- core/http/endpoints/jina/rerank.go | 52 +- .../endpoints/localai/get_token_metrics.go | 15 +- core/http/endpoints/localai/tokenize.go | 41 +- core/http/endpoints/localai/tts.go | 58 +- core/http/endpoints/openai/chat.go | 26 +- core/http/endpoints/openai/completion.go | 25 +- core/http/endpoints/openai/edit.go | 18 +- core/http/endpoints/openai/embeddings.go | 14 +- core/http/endpoints/openai/image.go | 21 +- core/http/endpoints/openai/inference.go | 2 +- core/http/endpoints/openai/transcription.go | 16 +- .../openai => middleware}/request.go | 759 ++++++++++-------- core/http/routes/elevenlabs.go | 13 +- core/http/routes/jina.go | 8 +- core/http/routes/localai.go | 16 +- core/http/routes/openai.go | 73 +- core/schema/elevenlabs.go | 19 +- core/schema/jina.go | 3 +- core/schema/localai.go | 14 +- core/schema/prediction.go | 2 +- core/schema/request.go | 22 + core/schema/tokenize.go | 2 +- core/services/list_models.go | 12 + core/startup/startup.go | 8 +- .../docs/getting-started/quickstart.md | 13 + go.mod | 2 +- go.sum | 2 + pkg/model/initializers.go | 2 +- pkg/model/loader_options.go | 8 + tests/e2e-aio/e2e_test.go | 7 +- 44 files changed, 834 insertions(+), 714 deletions(-) delete mode 100644 core/http/ctx/fiber.go rename core/http/{endpoints/openai => middleware}/request.go (50%) create mode 100644 core/schema/request.go diff --git a/core/backend/llm.go b/core/backend/llm.go index 199a62338c8..78d44f2f8e4 100644 --- a/core/backend/llm.go +++ b/core/backend/llm.go @@ -32,13 +32,13 @@ type TokenUsage struct { Completion int } -func ModelInference(ctx context.Context, s string, messages []schema.Message, images, videos, audios []string, loader *model.ModelLoader, c config.BackendConfig, o *config.ApplicationConfig, tokenCallback func(string, TokenUsage) bool) (func() (LLMResponse, error), error) { +func ModelInference(ctx context.Context, s string, messages []schema.Message, images, videos, audios []string, loader *model.ModelLoader, c *config.BackendConfig, o *config.ApplicationConfig, tokenCallback func(string, TokenUsage) bool) (func() (LLMResponse, error), error) { modelFile := c.Model var inferenceModel grpc.Backend var err error - opts := ModelOptions(c, o, []model.Option{}) + opts := ModelOptions(*c, o, []model.Option{}) if c.Backend != "" { opts = append(opts, model.WithBackendString(c.Backend)) @@ -96,7 +96,7 @@ func ModelInference(ctx context.Context, s string, messages []schema.Message, im // in GRPC, the backend is supposed to answer to 1 single token if stream is not supported fn := func() (LLMResponse, error) { - opts := gRPCPredictOpts(c, loader.ModelPath) + opts := gRPCPredictOpts(*c, loader.ModelPath) opts.Prompt = s opts.Messages = protoMessages opts.UseTokenizerTemplate = c.TemplateConfig.UseTokenizerTemplate diff --git a/core/backend/rerank.go b/core/backend/rerank.go index f600e2e6eaf..f665c4c64fd 100644 --- a/core/backend/rerank.go +++ b/core/backend/rerank.go @@ -9,9 +9,9 @@ import ( model "github.com/mudler/LocalAI/pkg/model" ) -func Rerank(modelFile string, request *proto.RerankRequest, loader *model.ModelLoader, appConfig *config.ApplicationConfig, backendConfig config.BackendConfig) (*proto.RerankResult, error) { +func Rerank(request *proto.RerankRequest, loader *model.ModelLoader, appConfig *config.ApplicationConfig, backendConfig config.BackendConfig) (*proto.RerankResult, error) { - opts := ModelOptions(backendConfig, appConfig, []model.Option{model.WithModel(modelFile)}) + opts := ModelOptions(backendConfig, appConfig, []model.Option{}) rerankModel, err := loader.BackendLoader(opts...) if err != nil { return nil, err diff --git a/core/backend/soundgeneration.go b/core/backend/soundgeneration.go index b1b458b447a..66674ced556 100644 --- a/core/backend/soundgeneration.go +++ b/core/backend/soundgeneration.go @@ -13,7 +13,6 @@ import ( ) func SoundGeneration( - modelFile string, text string, duration *float32, temperature *float32, @@ -25,7 +24,7 @@ func SoundGeneration( backendConfig config.BackendConfig, ) (string, *proto.Result, error) { - opts := ModelOptions(backendConfig, appConfig, []model.Option{model.WithModel(modelFile)}) + opts := ModelOptions(backendConfig, appConfig, []model.Option{}) soundGenModel, err := loader.BackendLoader(opts...) if err != nil { @@ -45,7 +44,7 @@ func SoundGeneration( res, err := soundGenModel.SoundGeneration(context.Background(), &proto.SoundGenerationRequest{ Text: text, - Model: modelFile, + Model: backendConfig.Model, Dst: filePath, Sample: doSample, Duration: duration, diff --git a/core/backend/tokenize.go b/core/backend/tokenize.go index c8ec8d1cb26..111712daf7f 100644 --- a/core/backend/tokenize.go +++ b/core/backend/tokenize.go @@ -9,14 +9,10 @@ import ( func ModelTokenize(s string, loader *model.ModelLoader, backendConfig config.BackendConfig, appConfig *config.ApplicationConfig) (schema.TokenizeResponse, error) { - modelFile := backendConfig.Model - var inferenceModel grpc.Backend var err error - opts := ModelOptions(backendConfig, appConfig, []model.Option{ - model.WithModel(modelFile), - }) + opts := ModelOptions(backendConfig, appConfig, []model.Option{}) if backendConfig.Backend == "" { inferenceModel, err = loader.GreedyLoader(opts...) diff --git a/core/backend/tts.go b/core/backend/tts.go index bac2e900883..799804a4480 100644 --- a/core/backend/tts.go +++ b/core/backend/tts.go @@ -14,23 +14,15 @@ import ( ) func ModelTTS( - backend, text, - modelFile, voice, language string, loader *model.ModelLoader, appConfig *config.ApplicationConfig, backendConfig config.BackendConfig, ) (string, *proto.Result, error) { - bb := backend - if bb == "" { - bb = model.PiperBackend - } - - opts := ModelOptions(config.BackendConfig{}, appConfig, []model.Option{ - model.WithBackendString(bb), - model.WithModel(modelFile), + opts := ModelOptions(*&backendConfig, appConfig, []model.Option{ + model.WithDefaultBackendString(model.PiperBackend), }) ttsModel, err := loader.BackendLoader(opts...) if err != nil { @@ -38,7 +30,7 @@ func ModelTTS( } if ttsModel == nil { - return "", nil, fmt.Errorf("could not load piper model") + return "", nil, fmt.Errorf("could not load tts model %q", backendConfig.Model) } if err := os.MkdirAll(appConfig.AudioDir, 0750); err != nil { @@ -48,22 +40,21 @@ func ModelTTS( fileName := utils.GenerateUniqueFileName(appConfig.AudioDir, "tts", ".wav") filePath := filepath.Join(appConfig.AudioDir, fileName) - // If the model file is not empty, we pass it joined with the model path + // We join the model name to the model path here. This seems to only be done for TTS and is HIGHLY suspect. + // This should be addressed in a follow up PR soon. + // Copying it over nearly verbatim, as TTS backends are not functional without this. modelPath := "" - if modelFile != "" { - // If the model file is not empty, we pass it joined with the model path - // Checking first that it exists and is not outside ModelPath - // TODO: we should actually first check if the modelFile is looking like - // a FS path - mp := filepath.Join(loader.ModelPath, modelFile) - if _, err := os.Stat(mp); err == nil { - if err := utils.VerifyPath(mp, appConfig.ModelPath); err != nil { - return "", nil, err - } - modelPath = mp - } else { - modelPath = modelFile + // Checking first that it exists and is not outside ModelPath + // TODO: we should actually first check if the modelFile is looking like + // a FS path + mp := filepath.Join(loader.ModelPath, backendConfig.Model) + if _, err := os.Stat(mp); err == nil { + if err := utils.VerifyPath(mp, appConfig.ModelPath); err != nil { + return "", nil, err } + modelPath = mp + } else { + modelPath = backendConfig.Model // skip this step if it fails????? } res, err := ttsModel.TTS(context.Background(), &proto.TTSRequest{ diff --git a/core/cli/soundgeneration.go b/core/cli/soundgeneration.go index 82bc03469e3..a8acd6baa90 100644 --- a/core/cli/soundgeneration.go +++ b/core/cli/soundgeneration.go @@ -86,13 +86,14 @@ func (t *SoundGenerationCMD) Run(ctx *cliContext.Context) error { options := config.BackendConfig{} options.SetDefaults() options.Backend = t.Backend + options.Model = t.Model var inputFile *string if t.InputFile != "" { inputFile = &t.InputFile } - filePath, _, err := backend.SoundGeneration(t.Model, text, + filePath, _, err := backend.SoundGeneration(text, parseToFloat32Ptr(t.Duration), parseToFloat32Ptr(t.Temperature), &t.DoSample, inputFile, parseToInt32Ptr(t.InputFileSampleDivisor), ml, opts, options) diff --git a/core/cli/tts.go b/core/cli/tts.go index 150ca3d0f43..af51ce06964 100644 --- a/core/cli/tts.go +++ b/core/cli/tts.go @@ -52,8 +52,10 @@ func (t *TTSCMD) Run(ctx *cliContext.Context) error { options := config.BackendConfig{} options.SetDefaults() + options.Backend = t.Backend + options.Model = t.Model - filePath, _, err := backend.ModelTTS(t.Backend, text, t.Model, t.Voice, t.Language, ml, opts, options) + filePath, _, err := backend.ModelTTS(text, t.Voice, t.Language, ml, opts, options) if err != nil { return err } diff --git a/core/config/backend_config.go b/core/config/backend_config.go index c3d1063dbd2..be5efc8dda1 100644 --- a/core/config/backend_config.go +++ b/core/config/backend_config.go @@ -432,19 +432,20 @@ func (c *BackendConfig) HasTemplate() bool { type BackendConfigUsecases int const ( - FLAG_ANY BackendConfigUsecases = 0b000000000 - FLAG_CHAT BackendConfigUsecases = 0b000000001 - FLAG_COMPLETION BackendConfigUsecases = 0b000000010 - FLAG_EDIT BackendConfigUsecases = 0b000000100 - FLAG_EMBEDDINGS BackendConfigUsecases = 0b000001000 - FLAG_RERANK BackendConfigUsecases = 0b000010000 - FLAG_IMAGE BackendConfigUsecases = 0b000100000 - FLAG_TRANSCRIPT BackendConfigUsecases = 0b001000000 - FLAG_TTS BackendConfigUsecases = 0b010000000 - FLAG_SOUND_GENERATION BackendConfigUsecases = 0b100000000 + FLAG_ANY BackendConfigUsecases = 0b0000000000 + FLAG_CHAT BackendConfigUsecases = 0b0000000001 + FLAG_COMPLETION BackendConfigUsecases = 0b0000000010 + FLAG_EDIT BackendConfigUsecases = 0b0000000100 + FLAG_EMBEDDINGS BackendConfigUsecases = 0b0000001000 + FLAG_RERANK BackendConfigUsecases = 0b0000010000 + FLAG_IMAGE BackendConfigUsecases = 0b0000100000 + FLAG_TRANSCRIPT BackendConfigUsecases = 0b0001000000 + FLAG_TTS BackendConfigUsecases = 0b0010000000 + FLAG_SOUND_GENERATION BackendConfigUsecases = 0b0100000000 + FLAG_TOKENIZE BackendConfigUsecases = 0b1000000000 // Common Subsets - FLAG_LLM BackendConfigUsecases = FLAG_CHAT & FLAG_COMPLETION & FLAG_EDIT + FLAG_LLM BackendConfigUsecases = FLAG_CHAT | FLAG_COMPLETION | FLAG_EDIT ) func GetAllBackendConfigUsecases() map[string]BackendConfigUsecases { @@ -459,6 +460,7 @@ func GetAllBackendConfigUsecases() map[string]BackendConfigUsecases { "FLAG_TRANSCRIPT": FLAG_TRANSCRIPT, "FLAG_TTS": FLAG_TTS, "FLAG_SOUND_GENERATION": FLAG_SOUND_GENERATION, + "FLAG_TOKENIZE": FLAG_TOKENIZE, "FLAG_LLM": FLAG_LLM, } } @@ -544,5 +546,12 @@ func (c *BackendConfig) GuessUsecases(u BackendConfigUsecases) bool { } } + if (u & FLAG_TOKENIZE) == FLAG_TOKENIZE { + tokenizeCapableBackends := []string{"llama.cpp", "rwkv"} + if !slices.Contains(tokenizeCapableBackends, c.Backend) { + return false + } + } + return true } diff --git a/core/config/backend_config_loader.go b/core/config/backend_config_loader.go index 7fe49bab322..d97e90d3c66 100644 --- a/core/config/backend_config_loader.go +++ b/core/config/backend_config_loader.go @@ -81,10 +81,10 @@ func readMultipleBackendConfigsFromFile(file string, opts ...ConfigLoaderOption) c := &[]*BackendConfig{} f, err := os.ReadFile(file) if err != nil { - return nil, fmt.Errorf("cannot read config file: %w", err) + return nil, fmt.Errorf("readMultipleBackendConfigsFromFile cannot read config file %q: %w", file, err) } if err := yaml.Unmarshal(f, c); err != nil { - return nil, fmt.Errorf("cannot unmarshal config file: %w", err) + return nil, fmt.Errorf("readMultipleBackendConfigsFromFile cannot unmarshal config file %q: %w", file, err) } for _, cc := range *c { @@ -101,10 +101,10 @@ func readBackendConfigFromFile(file string, opts ...ConfigLoaderOption) (*Backen c := &BackendConfig{} f, err := os.ReadFile(file) if err != nil { - return nil, fmt.Errorf("cannot read config file: %w", err) + return nil, fmt.Errorf("readBackendConfigFromFile cannot read config file %q: %w", file, err) } if err := yaml.Unmarshal(f, c); err != nil { - return nil, fmt.Errorf("cannot unmarshal config file: %w", err) + return nil, fmt.Errorf("readBackendConfigFromFile cannot unmarshal config file %q: %w", file, err) } c.SetDefaults(opts...) @@ -117,7 +117,9 @@ func (bcl *BackendConfigLoader) LoadBackendConfigFileByName(modelName, modelPath // Load a config file if present after the model name cfg := &BackendConfig{ PredictionOptions: schema.PredictionOptions{ - Model: modelName, + BasicModelRequest: schema.BasicModelRequest{ + Model: modelName, + }, }, } @@ -145,6 +147,15 @@ func (bcl *BackendConfigLoader) LoadBackendConfigFileByName(modelName, modelPath return cfg, nil } +func (bcl *BackendConfigLoader) LoadBackendConfigFileByNameDefaultOptions(modelName string, appConfig *ApplicationConfig) (*BackendConfig, error) { + return bcl.LoadBackendConfigFileByName(modelName, appConfig.ModelPath, + LoadOptionDebug(appConfig.Debug), + LoadOptionThreads(appConfig.Threads), + LoadOptionContextSize(appConfig.ContextSize), + LoadOptionF16(appConfig.F16), + ModelPath(appConfig.ModelPath)) +} + // This format is currently only used when reading a single file at startup, passed in via ApplicationConfig.ConfigFile func (bcl *BackendConfigLoader) LoadMultipleBackendConfigsSingleFile(file string, opts ...ConfigLoaderOption) error { bcl.Lock() @@ -167,7 +178,7 @@ func (bcl *BackendConfigLoader) LoadBackendConfig(file string, opts ...ConfigLoa defer bcl.Unlock() c, err := readBackendConfigFromFile(file, opts...) if err != nil { - return fmt.Errorf("cannot read config file: %w", err) + return fmt.Errorf("LoadBackendConfig cannot read config file %q: %w", file, err) } if c.Validate() { @@ -324,9 +335,10 @@ func (bcl *BackendConfigLoader) Preload(modelPath string) error { func (bcl *BackendConfigLoader) LoadBackendConfigsFromPath(path string, opts ...ConfigLoaderOption) error { bcl.Lock() defer bcl.Unlock() + entries, err := os.ReadDir(path) if err != nil { - return fmt.Errorf("cannot read directory '%s': %w", path, err) + return fmt.Errorf("LoadBackendConfigsFromPath cannot read directory '%s': %w", path, err) } files := make([]fs.FileInfo, 0, len(entries)) for _, entry := range entries { @@ -344,13 +356,13 @@ func (bcl *BackendConfigLoader) LoadBackendConfigsFromPath(path string, opts ... } c, err := readBackendConfigFromFile(filepath.Join(path, file.Name()), opts...) if err != nil { - log.Error().Err(err).Msgf("cannot read config file: %s", file.Name()) + log.Error().Err(err).Str("File Name", file.Name()).Msgf("LoadBackendConfigsFromPath cannot read config file") continue } if c.Validate() { bcl.configs[c.Name] = *c } else { - log.Error().Err(err).Msgf("config is not valid") + log.Error().Err(err).Str("Name", c.Name).Msgf("config is not valid") } } diff --git a/core/config/guesser.go b/core/config/guesser.go index b63dd051a32..3dea311ffe2 100644 --- a/core/config/guesser.go +++ b/core/config/guesser.go @@ -26,14 +26,14 @@ const ( type settingsConfig struct { StopWords []string TemplateConfig TemplateConfig - RepeatPenalty float64 + RepeatPenalty float64 } // default settings to adopt with a given model family var defaultsSettings map[familyType]settingsConfig = map[familyType]settingsConfig{ Gemma: { RepeatPenalty: 1.0, - StopWords: []string{"<|im_end|>", "", ""}, + StopWords: []string{"<|im_end|>", "", ""}, TemplateConfig: TemplateConfig{ Chat: "{{.Input }}\nmodel\n", ChatMessage: "{{if eq .RoleName \"assistant\" }}model{{else}}{{ .RoleName }}{{end}}\n{{ if .Content -}}\n{{.Content -}}\n{{ end -}}", @@ -161,10 +161,11 @@ func guessDefaultsFromFile(cfg *BackendConfig, modelPath string) { } // We try to guess only if we don't have a template defined already - f, err := gguf.ParseGGUFFile(filepath.Join(modelPath, cfg.ModelFileName())) + guessPath := filepath.Join(modelPath, cfg.ModelFileName()) + f, err := gguf.ParseGGUFFile(guessPath) if err != nil { // Only valid for gguf files - log.Debug().Msgf("guessDefaultsFromFile: %s", "not a GGUF file") + log.Debug().Str("filePath", guessPath).Msg("guessDefaultsFromFile: not a GGUF file") return } diff --git a/core/http/app.go b/core/http/app.go index 2ba2c2b9953..e7d4e6bbc4c 100644 --- a/core/http/app.go +++ b/core/http/app.go @@ -121,7 +121,6 @@ func App(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *confi return metricsService.Shutdown() }) } - } // Health Checks should always be exempt from auth, so register these first routes.HealthRoutes(app) @@ -158,13 +157,15 @@ func App(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *confi galleryService := services.NewGalleryService(appConfig) galleryService.Start(appConfig.Context, cl) - routes.RegisterElevenLabsRoutes(app, cl, ml, appConfig) - routes.RegisterLocalAIRoutes(app, cl, ml, appConfig, galleryService) - routes.RegisterOpenAIRoutes(app, cl, ml, appConfig) + requestExtractor := middleware.NewRequestExtractor(cl, ml, appConfig) + + routes.RegisterElevenLabsRoutes(app, requestExtractor, cl, ml, appConfig) + routes.RegisterLocalAIRoutes(app, requestExtractor, cl, ml, appConfig, galleryService) + routes.RegisterOpenAIRoutes(app, requestExtractor, cl, ml, appConfig) if !appConfig.DisableWebUI { routes.RegisterUIRoutes(app, cl, ml, appConfig, galleryService) } - routes.RegisterJINARoutes(app, cl, ml, appConfig) + routes.RegisterJINARoutes(app, requestExtractor, cl, ml, appConfig) httpFS := http.FS(embedDirStatic) diff --git a/core/http/ctx/fiber.go b/core/http/ctx/fiber.go deleted file mode 100644 index 254f070400b..00000000000 --- a/core/http/ctx/fiber.go +++ /dev/null @@ -1,47 +0,0 @@ -package fiberContext - -import ( - "fmt" - "strings" - - "github.com/gofiber/fiber/v2" - "github.com/mudler/LocalAI/core/config" - "github.com/mudler/LocalAI/core/services" - "github.com/mudler/LocalAI/pkg/model" - "github.com/rs/zerolog/log" -) - -// ModelFromContext returns the model from the context -// If no model is specified, it will take the first available -// Takes a model string as input which should be the one received from the user request. -// It returns the model name resolved from the context and an error if any. -func ModelFromContext(ctx *fiber.Ctx, cl *config.BackendConfigLoader, loader *model.ModelLoader, modelInput string, firstModel bool) (string, error) { - if ctx.Params("model") != "" { - modelInput = ctx.Params("model") - } - if ctx.Query("model") != "" { - modelInput = ctx.Query("model") - } - // Set model from bearer token, if available - bearer := strings.TrimLeft(ctx.Get("authorization"), "Bear ") // Reduced duplicate characters of Bearer - bearerExists := bearer != "" && loader.ExistsInModelPath(bearer) - - // If no model was specified, take the first available - if modelInput == "" && !bearerExists && firstModel { - models, _ := services.ListModels(cl, loader, config.NoFilterFn, services.SKIP_IF_CONFIGURED) - if len(models) > 0 { - modelInput = models[0] - log.Debug().Msgf("No model specified, using: %s", modelInput) - } else { - log.Debug().Msgf("No model specified, returning error") - return "", fmt.Errorf("no model specified") - } - } - - // If a model is found in bearer token takes precedence - if bearerExists { - log.Debug().Msgf("Using model from bearer token: %s", bearer) - modelInput = bearer - } - return modelInput, nil -} diff --git a/core/http/endpoints/elevenlabs/soundgeneration.go b/core/http/endpoints/elevenlabs/soundgeneration.go index 345df35b8a2..548716def74 100644 --- a/core/http/endpoints/elevenlabs/soundgeneration.go +++ b/core/http/endpoints/elevenlabs/soundgeneration.go @@ -4,7 +4,7 @@ import ( "github.com/gofiber/fiber/v2" "github.com/mudler/LocalAI/core/backend" "github.com/mudler/LocalAI/core/config" - fiberContext "github.com/mudler/LocalAI/core/http/ctx" + "github.com/mudler/LocalAI/core/http/middleware" "github.com/mudler/LocalAI/core/schema" "github.com/mudler/LocalAI/pkg/model" "github.com/rs/zerolog/log" @@ -17,45 +17,21 @@ import ( // @Router /v1/sound-generation [post] func SoundGenerationEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error { - input := new(schema.ElevenLabsSoundGenerationRequest) - // Get input data from the request body - if err := c.BodyParser(input); err != nil { - return err - } - modelFile, err := fiberContext.ModelFromContext(c, cl, ml, input.ModelID, false) - if err != nil { - modelFile = input.ModelID - log.Warn().Str("ModelID", input.ModelID).Msg("Model not found in context") + input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.ElevenLabsSoundGenerationRequest) + if !ok || input.ModelID == "" { + return fiber.ErrBadRequest } - cfg, err := cl.LoadBackendConfigFileByName(modelFile, appConfig.ModelPath, - config.LoadOptionDebug(appConfig.Debug), - config.LoadOptionThreads(appConfig.Threads), - config.LoadOptionContextSize(appConfig.ContextSize), - config.LoadOptionF16(appConfig.F16), - ) - if err != nil { - modelFile = input.ModelID - log.Warn().Str("Request ModelID", input.ModelID).Err(err).Msg("error during LoadBackendConfigFileByName, using request ModelID") - } else { - if input.ModelID != "" { - modelFile = input.ModelID - } else { - modelFile = cfg.Model - } + cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig) + if !ok || cfg == nil { + return fiber.ErrBadRequest } - log.Debug().Str("modelFile", "modelFile").Str("backend", cfg.Backend).Msg("Sound Generation Request about to be sent to backend") - if input.Duration != nil { - log.Debug().Float32("duration", *input.Duration).Msg("duration set") - } - if input.Temperature != nil { - log.Debug().Float32("temperature", *input.Temperature).Msg("temperature set") - } + log.Debug().Str("modelFile", "modelFile").Str("backend", cfg.Backend).Msg("Sound Generation Request about to be sent to backend") // TODO: Support uploading files? - filePath, _, err := backend.SoundGeneration(modelFile, input.Text, input.Duration, input.Temperature, input.DoSample, nil, nil, ml, appConfig, *cfg) + filePath, _, err := backend.SoundGeneration(input.Text, input.Duration, input.Temperature, input.DoSample, nil, nil, ml, appConfig, *cfg) if err != nil { return err } diff --git a/core/http/endpoints/elevenlabs/tts.go b/core/http/endpoints/elevenlabs/tts.go index bb6901be887..4845887014f 100644 --- a/core/http/endpoints/elevenlabs/tts.go +++ b/core/http/endpoints/elevenlabs/tts.go @@ -3,7 +3,7 @@ package elevenlabs import ( "github.com/mudler/LocalAI/core/backend" "github.com/mudler/LocalAI/core/config" - fiberContext "github.com/mudler/LocalAI/core/http/ctx" + "github.com/mudler/LocalAI/core/http/middleware" "github.com/mudler/LocalAI/pkg/model" "github.com/gofiber/fiber/v2" @@ -20,39 +20,21 @@ import ( func TTSEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error { - input := new(schema.ElevenLabsTTSRequest) voiceID := c.Params("voice-id") - // Get input data from the request body - if err := c.BodyParser(input); err != nil { - return err + input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.ElevenLabsTTSRequest) + if !ok || input.ModelID == "" { + return fiber.ErrBadRequest } - modelFile, err := fiberContext.ModelFromContext(c, cl, ml, input.ModelID, false) - if err != nil { - modelFile = input.ModelID - log.Warn().Msgf("Model not found in context: %s", input.ModelID) + cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig) + if !ok || cfg == nil { + return fiber.ErrBadRequest } - cfg, err := cl.LoadBackendConfigFileByName(modelFile, appConfig.ModelPath, - config.LoadOptionDebug(appConfig.Debug), - config.LoadOptionThreads(appConfig.Threads), - config.LoadOptionContextSize(appConfig.ContextSize), - config.LoadOptionF16(appConfig.F16), - ) - if err != nil { - modelFile = input.ModelID - log.Warn().Msgf("Model not found in context: %s", input.ModelID) - } else { - if input.ModelID != "" { - modelFile = input.ModelID - } else { - modelFile = cfg.Model - } - } - log.Debug().Msgf("Request for model: %s", modelFile) + log.Debug().Str("modelName", input.ModelID).Msg("elevenlabs TTS request recieved") - filePath, _, err := backend.ModelTTS(cfg.Backend, input.Text, modelFile, "", voiceID, ml, appConfig, *cfg) + filePath, _, err := backend.ModelTTS(input.Text, voiceID, input.LanguageCode, ml, appConfig, *cfg) if err != nil { return err } diff --git a/core/http/endpoints/jina/rerank.go b/core/http/endpoints/jina/rerank.go index 58c3972d655..eb2d191156f 100644 --- a/core/http/endpoints/jina/rerank.go +++ b/core/http/endpoints/jina/rerank.go @@ -3,9 +3,9 @@ package jina import ( "github.com/mudler/LocalAI/core/backend" "github.com/mudler/LocalAI/core/config" + "github.com/mudler/LocalAI/core/http/middleware" "github.com/gofiber/fiber/v2" - fiberContext "github.com/mudler/LocalAI/core/http/ctx" "github.com/mudler/LocalAI/core/schema" "github.com/mudler/LocalAI/pkg/grpc/proto" "github.com/mudler/LocalAI/pkg/model" @@ -19,58 +19,32 @@ import ( // @Router /v1/rerank [post] func JINARerankEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error { - req := new(schema.JINARerankRequest) - if err := c.BodyParser(req); err != nil { - return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{ - "error": "Cannot parse JSON", - }) - } - - input := new(schema.TTSRequest) - - // Get input data from the request body - if err := c.BodyParser(input); err != nil { - return err - } - modelFile, err := fiberContext.ModelFromContext(c, cl, ml, input.Model, false) - if err != nil { - modelFile = input.Model - log.Warn().Msgf("Model not found in context: %s", input.Model) + input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.JINARerankRequest) + if !ok || input.Model == "" { + return fiber.ErrBadRequest } - cfg, err := cl.LoadBackendConfigFileByName(modelFile, appConfig.ModelPath, - config.LoadOptionDebug(appConfig.Debug), - config.LoadOptionThreads(appConfig.Threads), - config.LoadOptionContextSize(appConfig.ContextSize), - config.LoadOptionF16(appConfig.F16), - ) - if err != nil { - modelFile = input.Model - log.Warn().Msgf("Model not found in context: %s", input.Model) - } else { - modelFile = cfg.Model + cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig) + if !ok || cfg == nil { + return fiber.ErrBadRequest } - log.Debug().Msgf("Request for model: %s", modelFile) - - if input.Backend != "" { - cfg.Backend = input.Backend - } + log.Debug().Str("model", input.Model).Msg("JINA Rerank Request recieved") request := &proto.RerankRequest{ - Query: req.Query, - TopN: int32(req.TopN), - Documents: req.Documents, + Query: input.Query, + TopN: int32(input.TopN), + Documents: input.Documents, } - results, err := backend.Rerank(modelFile, request, ml, appConfig, *cfg) + results, err := backend.Rerank(request, ml, appConfig, *cfg) if err != nil { return err } response := &schema.JINARerankResponse{ - Model: req.Model, + Model: input.Model, } for _, r := range results.Results { diff --git a/core/http/endpoints/localai/get_token_metrics.go b/core/http/endpoints/localai/get_token_metrics.go index e0e6943f129..30de2cdd5f2 100644 --- a/core/http/endpoints/localai/get_token_metrics.go +++ b/core/http/endpoints/localai/get_token_metrics.go @@ -4,13 +4,15 @@ import ( "github.com/gofiber/fiber/v2" "github.com/mudler/LocalAI/core/backend" "github.com/mudler/LocalAI/core/config" - fiberContext "github.com/mudler/LocalAI/core/http/ctx" + "github.com/mudler/LocalAI/core/http/middleware" "github.com/mudler/LocalAI/core/schema" "github.com/rs/zerolog/log" "github.com/mudler/LocalAI/pkg/model" ) +// TODO: This is not yet in use. Needs middleware rework, since it is not referenced. + // TokenMetricsEndpoint is an endpoint to get TokensProcessed Per Second for Active SlotID // // @Summary Get TokenMetrics for Active Slot. @@ -29,18 +31,13 @@ func TokenMetricsEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, return err } - modelFile, err := fiberContext.ModelFromContext(c, cl, ml, input.Model, false) - if err != nil { + modelFile, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_NAME).(string) + if !ok || modelFile != "" { modelFile = input.Model log.Warn().Msgf("Model not found in context: %s", input.Model) } - cfg, err := cl.LoadBackendConfigFileByName(modelFile, appConfig.ModelPath, - config.LoadOptionDebug(appConfig.Debug), - config.LoadOptionThreads(appConfig.Threads), - config.LoadOptionContextSize(appConfig.ContextSize), - config.LoadOptionF16(appConfig.F16), - ) + cfg, err := cl.LoadBackendConfigFileByNameDefaultOptions(modelFile, appConfig) if err != nil { log.Err(err) diff --git a/core/http/endpoints/localai/tokenize.go b/core/http/endpoints/localai/tokenize.go index da110bf864e..14da005ae4a 100644 --- a/core/http/endpoints/localai/tokenize.go +++ b/core/http/endpoints/localai/tokenize.go @@ -4,10 +4,9 @@ import ( "github.com/gofiber/fiber/v2" "github.com/mudler/LocalAI/core/backend" "github.com/mudler/LocalAI/core/config" - fiberContext "github.com/mudler/LocalAI/core/http/ctx" + "github.com/mudler/LocalAI/core/http/middleware" "github.com/mudler/LocalAI/core/schema" "github.com/mudler/LocalAI/pkg/model" - "github.com/rs/zerolog/log" ) // TokenizeEndpoint exposes a REST API to tokenize the content @@ -15,44 +14,22 @@ import ( // @Success 200 {object} schema.TokenizeResponse "Response" // @Router /v1/tokenize [post] func TokenizeEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error { - return func(c *fiber.Ctx) error { - - input := new(schema.TokenizeRequest) - - // Get input data from the request body - if err := c.BodyParser(input); err != nil { - return err + return func(ctx *fiber.Ctx) error { + input, ok := ctx.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.TokenizeRequest) + if !ok || input.Model == "" { + return fiber.ErrBadRequest } - modelFile, err := fiberContext.ModelFromContext(c, cl, ml, input.Model, false) - if err != nil { - modelFile = input.Model - log.Warn().Msgf("Model not found in context: %s", input.Model) + cfg, ok := ctx.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig) + if !ok || cfg == nil { + return fiber.ErrBadRequest } - cfg, err := cl.LoadBackendConfigFileByName(modelFile, appConfig.ModelPath, - config.LoadOptionDebug(appConfig.Debug), - config.LoadOptionThreads(appConfig.Threads), - config.LoadOptionContextSize(appConfig.ContextSize), - config.LoadOptionF16(appConfig.F16), - ) - - if err != nil { - log.Err(err) - modelFile = input.Model - log.Warn().Msgf("Model not found in context: %s", input.Model) - } else { - modelFile = cfg.Model - } - log.Debug().Msgf("Request for model: %s", modelFile) - tokenResponse, err := backend.ModelTokenize(input.Content, ml, *cfg, appConfig) if err != nil { return err } - c.JSON(tokenResponse) - return nil - + return ctx.JSON(tokenResponse) } } diff --git a/core/http/endpoints/localai/tts.go b/core/http/endpoints/localai/tts.go index ca3f58bd9e2..7f8b9aaa854 100644 --- a/core/http/endpoints/localai/tts.go +++ b/core/http/endpoints/localai/tts.go @@ -3,7 +3,7 @@ package localai import ( "github.com/mudler/LocalAI/core/backend" "github.com/mudler/LocalAI/core/config" - fiberContext "github.com/mudler/LocalAI/core/http/ctx" + "github.com/mudler/LocalAI/core/http/middleware" "github.com/mudler/LocalAI/pkg/model" "github.com/gofiber/fiber/v2" @@ -12,47 +12,35 @@ import ( ) // TTSEndpoint is the OpenAI Speech API endpoint https://platform.openai.com/docs/api-reference/audio/createSpeech -// @Summary Generates audio from the input text. -// @Accept json -// @Produce audio/x-wav -// @Param request body schema.TTSRequest true "query params" -// @Success 200 {string} binary "generated audio/wav file" -// @Router /v1/audio/speech [post] -// @Router /tts [post] +// +// @Summary Generates audio from the input text. +// @Accept json +// @Produce audio/x-wav +// @Param request body schema.TTSRequest true "query params" +// @Success 200 {string} binary "generated audio/wav file" +// @Router /v1/audio/speech [post] +// @Router /tts [post] func TTSEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error { - input := new(schema.TTSRequest) - - // Get input data from the request body - if err := c.BodyParser(input); err != nil { - return err + input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.TTSRequest) + if !ok || input.Model == "" { + return fiber.ErrBadRequest } - modelFile, err := fiberContext.ModelFromContext(c, cl, ml, input.Model, false) - if err != nil { - modelFile = input.Model - log.Warn().Msgf("Model not found in context: %s", input.Model) + cfg, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig) + if !ok || cfg == nil { + return fiber.ErrBadRequest } - cfg, err := cl.LoadBackendConfigFileByName(modelFile, appConfig.ModelPath, - config.LoadOptionDebug(appConfig.Debug), - config.LoadOptionThreads(appConfig.Threads), - config.LoadOptionContextSize(appConfig.ContextSize), - config.LoadOptionF16(appConfig.F16), - ) - - if err != nil { - log.Err(err) - modelFile = input.Model - log.Warn().Msgf("Model not found in context: %s", input.Model) - } else { - modelFile = cfg.Model - } - log.Debug().Msgf("Request for model: %s", modelFile) + log.Debug().Str("model", input.Model).Msg("LocalAI TTS Request recieved") - if input.Backend != "" { - cfg.Backend = input.Backend + if cfg.Backend == "" { + if input.Backend != "" { + cfg.Backend = input.Backend + } else { + cfg.Backend = model.PiperBackend + } } if input.Language != "" { @@ -63,7 +51,7 @@ func TTSEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfi cfg.Voice = input.Voice } - filePath, _, err := backend.ModelTTS(cfg.Backend, input.Input, modelFile, cfg.Voice, cfg.Language, ml, appConfig, *cfg) + filePath, _, err := backend.ModelTTS(input.Input, cfg.Voice, cfg.Language, ml, appConfig, *cfg) if err != nil { return err } diff --git a/core/http/endpoints/openai/chat.go b/core/http/endpoints/openai/chat.go index 1ac1387eed3..8657a239077 100644 --- a/core/http/endpoints/openai/chat.go +++ b/core/http/endpoints/openai/chat.go @@ -12,9 +12,10 @@ import ( "github.com/google/uuid" "github.com/mudler/LocalAI/core/backend" "github.com/mudler/LocalAI/core/config" + "github.com/mudler/LocalAI/core/http/middleware" "github.com/mudler/LocalAI/core/schema" "github.com/mudler/LocalAI/pkg/functions" - model "github.com/mudler/LocalAI/pkg/model" + "github.com/mudler/LocalAI/pkg/model" "github.com/rs/zerolog/log" "github.com/valyala/fasthttp" ) @@ -161,23 +162,18 @@ func ChatEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, startup textContentToReturn = "" id = uuid.New().String() created = int(time.Now().Unix()) - // Set CorrelationID - correlationID := c.Get("X-Correlation-ID") - if len(strings.TrimSpace(correlationID)) == 0 { - correlationID = id - } - c.Set("X-Correlation-ID", correlationID) - modelFile, input, err := readRequest(c, cl, ml, startupOptions, true) - if err != nil { - return fmt.Errorf("failed reading parameters from request:%w", err) + input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest) + if !ok || input.Model == "" { + return fiber.ErrBadRequest } - config, input, err := mergeRequestWithConfig(modelFile, input, cl, ml, startupOptions.Debug, startupOptions.Threads, startupOptions.ContextSize, startupOptions.F16) - if err != nil { - return fmt.Errorf("failed reading parameters from request:%w", err) + config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig) + if !ok || config == nil { + return fiber.ErrBadRequest } - log.Debug().Msgf("Configuration read: %+v", config) + + log.Debug().Msgf("Chat endpoint configuration read: %+v", config) funcs := input.Functions shouldUseFn := len(input.Functions) > 0 && config.ShouldUseFunctions() @@ -656,7 +652,7 @@ func handleQuestion(config *config.BackendConfig, input *schema.OpenAIRequest, m audios = append(audios, m.StringAudios...) } - predFunc, err := backend.ModelInference(input.Context, prompt, input.Messages, images, videos, audios, ml, *config, o, nil) + predFunc, err := backend.ModelInference(input.Context, prompt, input.Messages, images, videos, audios, ml, config, o, nil) if err != nil { log.Error().Err(err).Msg("model inference failed") return "", err diff --git a/core/http/endpoints/openai/completion.go b/core/http/endpoints/openai/completion.go index e5de1b3f029..3cbf9d18647 100644 --- a/core/http/endpoints/openai/completion.go +++ b/core/http/endpoints/openai/completion.go @@ -10,6 +10,7 @@ import ( "github.com/mudler/LocalAI/core/backend" "github.com/mudler/LocalAI/core/config" + "github.com/mudler/LocalAI/core/http/middleware" "github.com/gofiber/fiber/v2" "github.com/google/uuid" @@ -26,10 +27,9 @@ import ( // @Success 200 {object} schema.OpenAIResponse "Response" // @Router /v1/completions [post] func CompletionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error { - id := uuid.New().String() created := int(time.Now().Unix()) - process := func(s string, req *schema.OpenAIRequest, config *config.BackendConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse) { + process := func(id string, s string, req *schema.OpenAIRequest, config *config.BackendConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse) { ComputeChoices(req, s, config, appConfig, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage backend.TokenUsage) bool { resp := schema.OpenAIResponse{ ID: id, @@ -57,18 +57,17 @@ func CompletionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, a } return func(c *fiber.Ctx) error { - // Add Correlation - c.Set("X-Correlation-ID", id) - modelFile, input, err := readRequest(c, cl, ml, appConfig, true) - if err != nil { - return fmt.Errorf("failed reading parameters from request:%w", err) - } + // Handle Correlation + id := c.Get("X-Correlation-ID", uuid.New().String()) - log.Debug().Msgf("`input`: %+v", input) + input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest) + if !ok || input.Model == "" { + return fiber.ErrBadRequest + } - config, input, err := mergeRequestWithConfig(modelFile, input, cl, ml, appConfig.Debug, appConfig.Threads, appConfig.ContextSize, appConfig.F16) - if err != nil { - return fmt.Errorf("failed reading parameters from request:%w", err) + config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig) + if !ok || config == nil { + return fiber.ErrBadRequest } if config.ResponseFormatMap != nil { @@ -125,7 +124,7 @@ func CompletionEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, a responses := make(chan schema.OpenAIResponse) - go process(predInput, input, config, ml, responses) + go process(id, predInput, input, config, ml, responses) c.Context().SetBodyStreamWriter(fasthttp.StreamWriter(func(w *bufio.Writer) { diff --git a/core/http/endpoints/openai/edit.go b/core/http/endpoints/openai/edit.go index 12fb4035255..58cc06b8d66 100644 --- a/core/http/endpoints/openai/edit.go +++ b/core/http/endpoints/openai/edit.go @@ -7,11 +7,12 @@ import ( "github.com/mudler/LocalAI/core/backend" "github.com/mudler/LocalAI/core/config" + "github.com/mudler/LocalAI/core/http/middleware" "github.com/gofiber/fiber/v2" "github.com/google/uuid" "github.com/mudler/LocalAI/core/schema" - model "github.com/mudler/LocalAI/pkg/model" + "github.com/mudler/LocalAI/pkg/model" "github.com/rs/zerolog/log" ) @@ -23,17 +24,18 @@ import ( // @Router /v1/edits [post] func EditEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error { - modelFile, input, err := readRequest(c, cl, ml, appConfig, true) - if err != nil { - return fmt.Errorf("failed reading parameters from request:%w", err) + input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest) + if !ok || input.Model == "" { + return fiber.ErrBadRequest } - config, input, err := mergeRequestWithConfig(modelFile, input, cl, ml, appConfig.Debug, appConfig.Threads, appConfig.ContextSize, appConfig.F16) - if err != nil { - return fmt.Errorf("failed reading parameters from request:%w", err) + config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig) + if !ok || config == nil { + return fiber.ErrBadRequest } - log.Debug().Msgf("Parameter Config: %+v", config) + log.Debug().Msgf("Edit Endpoint Input : %+v", input) + log.Debug().Msgf("Edit Endpoint Config: %+v", *config) templateFile := "" diff --git a/core/http/endpoints/openai/embeddings.go b/core/http/endpoints/openai/embeddings.go index e247d84e332..9cbbe189457 100644 --- a/core/http/endpoints/openai/embeddings.go +++ b/core/http/endpoints/openai/embeddings.go @@ -2,11 +2,11 @@ package openai import ( "encoding/json" - "fmt" "time" "github.com/mudler/LocalAI/core/backend" "github.com/mudler/LocalAI/core/config" + "github.com/mudler/LocalAI/core/http/middleware" "github.com/mudler/LocalAI/pkg/model" "github.com/google/uuid" @@ -23,14 +23,14 @@ import ( // @Router /v1/embeddings [post] func EmbeddingsEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error { - model, input, err := readRequest(c, cl, ml, appConfig, true) - if err != nil { - return fmt.Errorf("failed reading parameters from request:%w", err) + input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest) + if !ok || input.Model == "" { + return fiber.ErrBadRequest } - config, input, err := mergeRequestWithConfig(model, input, cl, ml, appConfig.Debug, appConfig.Threads, appConfig.ContextSize, appConfig.F16) - if err != nil { - return fmt.Errorf("failed reading parameters from request:%w", err) + config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig) + if !ok || config == nil { + return fiber.ErrBadRequest } log.Debug().Msgf("Parameter Config: %+v", config) diff --git a/core/http/endpoints/openai/image.go b/core/http/endpoints/openai/image.go index 6c76ba84327..3e89ee0e071 100644 --- a/core/http/endpoints/openai/image.go +++ b/core/http/endpoints/openai/image.go @@ -15,6 +15,7 @@ import ( "github.com/google/uuid" "github.com/mudler/LocalAI/core/config" + "github.com/mudler/LocalAI/core/http/middleware" "github.com/mudler/LocalAI/core/schema" "github.com/mudler/LocalAI/core/backend" @@ -66,25 +67,23 @@ func downloadFile(url string) (string, error) { // @Router /v1/images/generations [post] func ImageEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error { - m, input, err := readRequest(c, cl, ml, appConfig, false) - if err != nil { - return fmt.Errorf("failed reading parameters from request:%w", err) - } - - if m == "" { - m = model.StableDiffusionBackend + input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest) + if !ok || input.Model == "" { + log.Error().Msg("Image Endpoint - Invalid Input") + return fiber.ErrBadRequest } - log.Debug().Msgf("Loading model: %+v", m) - config, input, err := mergeRequestWithConfig(m, input, cl, ml, appConfig.Debug, 0, 0, false) - if err != nil { - return fmt.Errorf("failed reading parameters from request:%w", err) + config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig) + if !ok || config == nil { + log.Error().Msg("Image Endpoint - Invalid Config") + return fiber.ErrBadRequest } src := "" if input.File != "" { fileData := []byte{} + var err error // check if input.File is an URL, if so download it and save it // to a temporary file if strings.HasPrefix(input.File, "http://") || strings.HasPrefix(input.File, "https://") { diff --git a/core/http/endpoints/openai/inference.go b/core/http/endpoints/openai/inference.go index da75d3a1ea5..fd85bb74d1d 100644 --- a/core/http/endpoints/openai/inference.go +++ b/core/http/endpoints/openai/inference.go @@ -37,7 +37,7 @@ func ComputeChoices( } // get the model function to call for the result - predFunc, err := backend.ModelInference(req.Context, predInput, req.Messages, images, videos, audios, loader, *config, o, tokenCallback) + predFunc, err := backend.ModelInference(req.Context, predInput, req.Messages, images, videos, audios, loader, config, o, tokenCallback) if err != nil { return result, backend.TokenUsage{}, err } diff --git a/core/http/endpoints/openai/transcription.go b/core/http/endpoints/openai/transcription.go index 4e23f8046c6..b10e06ef696 100644 --- a/core/http/endpoints/openai/transcription.go +++ b/core/http/endpoints/openai/transcription.go @@ -1,7 +1,6 @@ package openai import ( - "fmt" "io" "net/http" "os" @@ -10,6 +9,8 @@ import ( "github.com/mudler/LocalAI/core/backend" "github.com/mudler/LocalAI/core/config" + "github.com/mudler/LocalAI/core/http/middleware" + "github.com/mudler/LocalAI/core/schema" model "github.com/mudler/LocalAI/pkg/model" "github.com/gofiber/fiber/v2" @@ -25,15 +26,16 @@ import ( // @Router /v1/audio/transcriptions [post] func TranscriptEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error { - m, input, err := readRequest(c, cl, ml, appConfig, false) - if err != nil { - return fmt.Errorf("failed reading parameters from request:%w", err) + input, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest) + if !ok || input.Model == "" { + return fiber.ErrBadRequest } - config, input, err := mergeRequestWithConfig(m, input, cl, ml, appConfig.Debug, appConfig.Threads, appConfig.ContextSize, appConfig.F16) - if err != nil { - return fmt.Errorf("failed reading parameters from request: %w", err) + config, ok := c.Locals(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig) + if !ok || config == nil { + return fiber.ErrBadRequest } + // retrieve the file data from the request file, err := c.FormFile("file") if err != nil { diff --git a/core/http/endpoints/openai/request.go b/core/http/middleware/request.go similarity index 50% rename from core/http/endpoints/openai/request.go rename to core/http/middleware/request.go index 1309fa820e0..1372951e68e 100644 --- a/core/http/endpoints/openai/request.go +++ b/core/http/middleware/request.go @@ -1,318 +1,441 @@ -package openai - -import ( - "context" - "encoding/json" - "fmt" - - "github.com/gofiber/fiber/v2" - "github.com/google/uuid" - "github.com/mudler/LocalAI/core/config" - fiberContext "github.com/mudler/LocalAI/core/http/ctx" - "github.com/mudler/LocalAI/core/schema" - "github.com/mudler/LocalAI/pkg/functions" - "github.com/mudler/LocalAI/pkg/model" - "github.com/mudler/LocalAI/pkg/templates" - "github.com/mudler/LocalAI/pkg/utils" - "github.com/rs/zerolog/log" -) - -type correlationIDKeyType string - -// CorrelationIDKey to track request across process boundary -const CorrelationIDKey correlationIDKeyType = "correlationID" - -func readRequest(c *fiber.Ctx, cl *config.BackendConfigLoader, ml *model.ModelLoader, o *config.ApplicationConfig, firstModel bool) (string, *schema.OpenAIRequest, error) { - input := new(schema.OpenAIRequest) - - // Get input data from the request body - if err := c.BodyParser(input); err != nil { - return "", nil, fmt.Errorf("failed parsing request body: %w", err) - } - - received, _ := json.Marshal(input) - // Extract or generate the correlation ID - correlationID := c.Get("X-Correlation-ID", uuid.New().String()) - - ctx, cancel := context.WithCancel(o.Context) - // Add the correlation ID to the new context - ctxWithCorrelationID := context.WithValue(ctx, CorrelationIDKey, correlationID) - - input.Context = ctxWithCorrelationID - input.Cancel = cancel - - log.Debug().Msgf("Request received: %s", string(received)) - - modelFile, err := fiberContext.ModelFromContext(c, cl, ml, input.Model, firstModel) - - return modelFile, input, err -} - -func updateRequestConfig(config *config.BackendConfig, input *schema.OpenAIRequest) { - if input.Echo { - config.Echo = input.Echo - } - if input.TopK != nil { - config.TopK = input.TopK - } - if input.TopP != nil { - config.TopP = input.TopP - } - - if input.Backend != "" { - config.Backend = input.Backend - } - - if input.ClipSkip != 0 { - config.Diffusers.ClipSkip = input.ClipSkip - } - - if input.ModelBaseName != "" { - config.AutoGPTQ.ModelBaseName = input.ModelBaseName - } - - if input.NegativePromptScale != 0 { - config.NegativePromptScale = input.NegativePromptScale - } - - if input.UseFastTokenizer { - config.UseFastTokenizer = input.UseFastTokenizer - } - - if input.NegativePrompt != "" { - config.NegativePrompt = input.NegativePrompt - } - - if input.RopeFreqBase != 0 { - config.RopeFreqBase = input.RopeFreqBase - } - - if input.RopeFreqScale != 0 { - config.RopeFreqScale = input.RopeFreqScale - } - - if input.Grammar != "" { - config.Grammar = input.Grammar - } - - if input.Temperature != nil { - config.Temperature = input.Temperature - } - - if input.Maxtokens != nil { - config.Maxtokens = input.Maxtokens - } - - if input.ResponseFormat != nil { - switch responseFormat := input.ResponseFormat.(type) { - case string: - config.ResponseFormat = responseFormat - case map[string]interface{}: - config.ResponseFormatMap = responseFormat - } - } - - switch stop := input.Stop.(type) { - case string: - if stop != "" { - config.StopWords = append(config.StopWords, stop) - } - case []interface{}: - for _, pp := range stop { - if s, ok := pp.(string); ok { - config.StopWords = append(config.StopWords, s) - } - } - } - - if len(input.Tools) > 0 { - for _, tool := range input.Tools { - input.Functions = append(input.Functions, tool.Function) - } - } - - if input.ToolsChoice != nil { - var toolChoice functions.Tool - - switch content := input.ToolsChoice.(type) { - case string: - _ = json.Unmarshal([]byte(content), &toolChoice) - case map[string]interface{}: - dat, _ := json.Marshal(content) - _ = json.Unmarshal(dat, &toolChoice) - } - input.FunctionCall = map[string]interface{}{ - "name": toolChoice.Function.Name, - } - } - - // Decode each request's message content - imgIndex, vidIndex, audioIndex := 0, 0, 0 - for i, m := range input.Messages { - nrOfImgsInMessage := 0 - nrOfVideosInMessage := 0 - nrOfAudiosInMessage := 0 - - switch content := m.Content.(type) { - case string: - input.Messages[i].StringContent = content - case []interface{}: - dat, _ := json.Marshal(content) - c := []schema.Content{} - json.Unmarshal(dat, &c) - - textContent := "" - // we will template this at the end - - CONTENT: - for _, pp := range c { - switch pp.Type { - case "text": - textContent += pp.Text - //input.Messages[i].StringContent = pp.Text - case "video", "video_url": - // Decode content as base64 either if it's an URL or base64 text - base64, err := utils.GetContentURIAsBase64(pp.VideoURL.URL) - if err != nil { - log.Error().Msgf("Failed encoding video: %s", err) - continue CONTENT - } - input.Messages[i].StringVideos = append(input.Messages[i].StringVideos, base64) // TODO: make sure that we only return base64 stuff - vidIndex++ - nrOfVideosInMessage++ - case "audio_url", "audio": - // Decode content as base64 either if it's an URL or base64 text - base64, err := utils.GetContentURIAsBase64(pp.AudioURL.URL) - if err != nil { - log.Error().Msgf("Failed encoding image: %s", err) - continue CONTENT - } - input.Messages[i].StringAudios = append(input.Messages[i].StringAudios, base64) // TODO: make sure that we only return base64 stuff - audioIndex++ - nrOfAudiosInMessage++ - case "image_url", "image": - // Decode content as base64 either if it's an URL or base64 text - base64, err := utils.GetContentURIAsBase64(pp.ImageURL.URL) - if err != nil { - log.Error().Msgf("Failed encoding image: %s", err) - continue CONTENT - } - - input.Messages[i].StringImages = append(input.Messages[i].StringImages, base64) // TODO: make sure that we only return base64 stuff - - imgIndex++ - nrOfImgsInMessage++ - } - } - - input.Messages[i].StringContent, _ = templates.TemplateMultiModal(config.TemplateConfig.Multimodal, templates.MultiModalOptions{ - TotalImages: imgIndex, - TotalVideos: vidIndex, - TotalAudios: audioIndex, - ImagesInMessage: nrOfImgsInMessage, - VideosInMessage: nrOfVideosInMessage, - AudiosInMessage: nrOfAudiosInMessage, - }, textContent) - } - } - - if input.RepeatPenalty != 0 { - config.RepeatPenalty = input.RepeatPenalty - } - - if input.FrequencyPenalty != 0 { - config.FrequencyPenalty = input.FrequencyPenalty - } - - if input.PresencePenalty != 0 { - config.PresencePenalty = input.PresencePenalty - } - - if input.Keep != 0 { - config.Keep = input.Keep - } - - if input.Batch != 0 { - config.Batch = input.Batch - } - - if input.IgnoreEOS { - config.IgnoreEOS = input.IgnoreEOS - } - - if input.Seed != nil { - config.Seed = input.Seed - } - - if input.TypicalP != nil { - config.TypicalP = input.TypicalP - } - - switch inputs := input.Input.(type) { - case string: - if inputs != "" { - config.InputStrings = append(config.InputStrings, inputs) - } - case []interface{}: - for _, pp := range inputs { - switch i := pp.(type) { - case string: - config.InputStrings = append(config.InputStrings, i) - case []interface{}: - tokens := []int{} - for _, ii := range i { - tokens = append(tokens, int(ii.(float64))) - } - config.InputToken = append(config.InputToken, tokens) - } - } - } - - // Can be either a string or an object - switch fnc := input.FunctionCall.(type) { - case string: - if fnc != "" { - config.SetFunctionCallString(fnc) - } - case map[string]interface{}: - var name string - n, exists := fnc["name"] - if exists { - nn, e := n.(string) - if e { - name = nn - } - } - config.SetFunctionCallNameString(name) - } - - switch p := input.Prompt.(type) { - case string: - config.PromptStrings = append(config.PromptStrings, p) - case []interface{}: - for _, pp := range p { - if s, ok := pp.(string); ok { - config.PromptStrings = append(config.PromptStrings, s) - } - } - } -} - -func mergeRequestWithConfig(modelFile string, input *schema.OpenAIRequest, cm *config.BackendConfigLoader, loader *model.ModelLoader, debug bool, threads, ctx int, f16 bool) (*config.BackendConfig, *schema.OpenAIRequest, error) { - cfg, err := cm.LoadBackendConfigFileByName(modelFile, loader.ModelPath, - config.LoadOptionDebug(debug), - config.LoadOptionThreads(threads), - config.LoadOptionContextSize(ctx), - config.LoadOptionF16(f16), - config.ModelPath(loader.ModelPath), - ) - - // Set the parameters for the language model prediction - updateRequestConfig(cfg, input) - - if !cfg.Validate() { - return nil, nil, fmt.Errorf("failed to validate config") - } - - return cfg, input, err -} +package middleware + +import ( + "context" + "encoding/json" + "fmt" + "strings" + + "github.com/google/uuid" + "github.com/mudler/LocalAI/core/config" + "github.com/mudler/LocalAI/core/schema" + "github.com/mudler/LocalAI/core/services" + "github.com/mudler/LocalAI/pkg/functions" + "github.com/mudler/LocalAI/pkg/model" + "github.com/mudler/LocalAI/pkg/templates" + "github.com/mudler/LocalAI/pkg/utils" + + "github.com/gofiber/fiber/v2" + "github.com/rs/zerolog/log" +) + +type correlationIDKeyType string + +// CorrelationIDKey to track request across process boundary +const CorrelationIDKey correlationIDKeyType = "correlationID" + +type RequestExtractor struct { + backendConfigLoader *config.BackendConfigLoader + modelLoader *model.ModelLoader + applicationConfig *config.ApplicationConfig +} + +func NewRequestExtractor(backendConfigLoader *config.BackendConfigLoader, modelLoader *model.ModelLoader, applicationConfig *config.ApplicationConfig) *RequestExtractor { + return &RequestExtractor{ + backendConfigLoader: backendConfigLoader, + modelLoader: modelLoader, + applicationConfig: applicationConfig, + } +} + +const CONTEXT_LOCALS_KEY_MODEL_NAME = "MODEL_NAME" +const CONTEXT_LOCALS_KEY_LOCALAI_REQUEST = "LOCALAI_REQUEST" +const CONTEXT_LOCALS_KEY_MODEL_CONFIG = "MODEL_CONFIG" + +// TODO: Refactor to not return error if unchanged +func (re *RequestExtractor) setModelNameFromRequest(ctx *fiber.Ctx) { + model, ok := ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_NAME).(string) + if ok && model != "" { + return + } + model = ctx.Params("model") + + if (model == "") && ctx.Query("model") != "" { + model = ctx.Query("model") + } + + if model == "" { + // Set model from bearer token, if available + bearer := strings.TrimLeft(ctx.Get("authorization"), "Bear ") // "Bearer " => "Bear" to please go-staticcheck. It looks dumb but we might as well take free performance on something called for nearly every request. + if bearer != "" { + exists, err := services.CheckIfModelExists(re.backendConfigLoader, re.modelLoader, bearer, services.ALWAYS_INCLUDE) + if err == nil && exists { + model = bearer + } + } + } + + ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_NAME, model) +} + +func (re *RequestExtractor) BuildConstantDefaultModelNameMiddleware(defaultModelName string) fiber.Handler { + return func(ctx *fiber.Ctx) error { + re.setModelNameFromRequest(ctx) + localModelName, ok := ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_NAME).(string) + if !ok || localModelName == "" { + ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_NAME, defaultModelName) + log.Debug().Str("defaultModelName", defaultModelName).Msg("context local model name not found, setting to default") + } + return ctx.Next() + } +} + +func (re *RequestExtractor) BuildFilteredFirstAvailableDefaultModel(filterFn config.BackendConfigFilterFn) fiber.Handler { + return func(ctx *fiber.Ctx) error { + re.setModelNameFromRequest(ctx) + localModelName := ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_NAME).(string) + if localModelName != "" { // Don't overwrite existing values + return ctx.Next() + } + + modelNames, err := services.ListModels(re.backendConfigLoader, re.modelLoader, filterFn, services.SKIP_IF_CONFIGURED) + if err != nil { + log.Error().Err(err).Msg("non-fatal error calling ListModels during SetDefaultModelNameToFirstAvailable()") + return ctx.Next() + } + + if len(modelNames) == 0 { + log.Warn().Msg("SetDefaultModelNameToFirstAvailable used with no matching models installed") + // This is non-fatal - making it so was breaking the case of direct installation of raw models + // return errors.New("this endpoint requires at least one model to be installed") + return ctx.Next() + } + + ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_NAME, modelNames[0]) + log.Debug().Str("first model name", modelNames[0]).Msg("context local model name not found, setting to the first model") + return ctx.Next() + } +} + +// TODO: If context and cancel above belong on all methods, move that part of above into here! +// Otherwise, it's in its own method below for now +func (re *RequestExtractor) SetModelAndConfig(initializer func() schema.LocalAIRequest) fiber.Handler { + return func(ctx *fiber.Ctx) error { + input := initializer() + if input == nil { + return fmt.Errorf("unable to initialize body") + } + if err := ctx.BodyParser(input); err != nil { + return fmt.Errorf("failed parsing request body: %w", err) + } + + // If this request doesn't have an associated model name, fetch it from earlier in the middleware chain + if input.ModelName(nil) == "" { + localModelName, ok := ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_NAME).(string) + if ok && localModelName != "" { + log.Debug().Str("context localModelName", localModelName).Msg("overriding empty model name in request body with value found earlier in middleware chain") + input.ModelName(&localModelName) + } + } + + cfg, err := re.backendConfigLoader.LoadBackendConfigFileByNameDefaultOptions(input.ModelName(nil), re.applicationConfig) + + if err != nil { + log.Err(err) + log.Warn().Msgf("Model Configuration File not found for %q", input.ModelName(nil)) + } else if cfg.Model == "" && input.ModelName(nil) != "" { + log.Debug().Str("input.ModelName", input.ModelName(nil)).Msg("config does not include model, using input") + cfg.Model = input.ModelName(nil) + } + + ctx.Locals(CONTEXT_LOCALS_KEY_LOCALAI_REQUEST, input) + ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_CONFIG, cfg) + + return ctx.Next() + } +} + +func (re *RequestExtractor) SetOpenAIRequest(ctx *fiber.Ctx) error { + input, ok := ctx.Locals(CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest) + if !ok || input.Model == "" { + return fiber.ErrBadRequest + } + + cfg, ok := ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.BackendConfig) + if !ok || cfg == nil { + return fiber.ErrBadRequest + } + + // Extract or generate the correlation ID + correlationID := ctx.Get("X-Correlation-ID", uuid.New().String()) + ctx.Set("X-Correlation-ID", correlationID) + + c1, cancel := context.WithCancel(re.applicationConfig.Context) + // Add the correlation ID to the new context + context := context.WithValue(c1, CorrelationIDKey, correlationID) + + input.Context = context + input.Cancel = cancel + + err := mergeOpenAIRequestAndBackendConfig(cfg, input) + if err != nil { + return err + } + + if cfg.Model == "" { + log.Debug().Str("input.Model", input.Model).Msg("replacing empty cfg.Model with input value") + cfg.Model = input.Model + } + + ctx.Locals(CONTEXT_LOCALS_KEY_LOCALAI_REQUEST, input) + ctx.Locals(CONTEXT_LOCALS_KEY_MODEL_CONFIG, cfg) + + return ctx.Next() +} + +func mergeOpenAIRequestAndBackendConfig(config *config.BackendConfig, input *schema.OpenAIRequest) error { + if input.Echo { + config.Echo = input.Echo + } + if input.TopK != nil { + config.TopK = input.TopK + } + if input.TopP != nil { + config.TopP = input.TopP + } + + if input.Backend != "" { + config.Backend = input.Backend + } + + if input.ClipSkip != 0 { + config.Diffusers.ClipSkip = input.ClipSkip + } + + if input.ModelBaseName != "" { + config.AutoGPTQ.ModelBaseName = input.ModelBaseName + } + + if input.NegativePromptScale != 0 { + config.NegativePromptScale = input.NegativePromptScale + } + + if input.UseFastTokenizer { + config.UseFastTokenizer = input.UseFastTokenizer + } + + if input.NegativePrompt != "" { + config.NegativePrompt = input.NegativePrompt + } + + if input.RopeFreqBase != 0 { + config.RopeFreqBase = input.RopeFreqBase + } + + if input.RopeFreqScale != 0 { + config.RopeFreqScale = input.RopeFreqScale + } + + if input.Grammar != "" { + config.Grammar = input.Grammar + } + + if input.Temperature != nil { + config.Temperature = input.Temperature + } + + if input.Maxtokens != nil { + config.Maxtokens = input.Maxtokens + } + + if input.ResponseFormat != nil { + switch responseFormat := input.ResponseFormat.(type) { + case string: + config.ResponseFormat = responseFormat + case map[string]interface{}: + config.ResponseFormatMap = responseFormat + } + } + + switch stop := input.Stop.(type) { + case string: + if stop != "" { + config.StopWords = append(config.StopWords, stop) + } + case []interface{}: + for _, pp := range stop { + if s, ok := pp.(string); ok { + config.StopWords = append(config.StopWords, s) + } + } + } + + if len(input.Tools) > 0 { + for _, tool := range input.Tools { + input.Functions = append(input.Functions, tool.Function) + } + } + + if input.ToolsChoice != nil { + var toolChoice functions.Tool + + switch content := input.ToolsChoice.(type) { + case string: + _ = json.Unmarshal([]byte(content), &toolChoice) + case map[string]interface{}: + dat, _ := json.Marshal(content) + _ = json.Unmarshal(dat, &toolChoice) + } + input.FunctionCall = map[string]interface{}{ + "name": toolChoice.Function.Name, + } + } + + // Decode each request's message content + imgIndex, vidIndex, audioIndex := 0, 0, 0 + for i, m := range input.Messages { + nrOfImgsInMessage := 0 + nrOfVideosInMessage := 0 + nrOfAudiosInMessage := 0 + + switch content := m.Content.(type) { + case string: + input.Messages[i].StringContent = content + case []interface{}: + dat, _ := json.Marshal(content) + c := []schema.Content{} + json.Unmarshal(dat, &c) + + textContent := "" + // we will template this at the end + + CONTENT: + for _, pp := range c { + switch pp.Type { + case "text": + textContent += pp.Text + //input.Messages[i].StringContent = pp.Text + case "video", "video_url": + // Decode content as base64 either if it's an URL or base64 text + base64, err := utils.GetContentURIAsBase64(pp.VideoURL.URL) + if err != nil { + log.Error().Msgf("Failed encoding video: %s", err) + continue CONTENT + } + input.Messages[i].StringVideos = append(input.Messages[i].StringVideos, base64) // TODO: make sure that we only return base64 stuff + vidIndex++ + nrOfVideosInMessage++ + case "audio_url", "audio": + // Decode content as base64 either if it's an URL or base64 text + base64, err := utils.GetContentURIAsBase64(pp.AudioURL.URL) + if err != nil { + log.Error().Msgf("Failed encoding image: %s", err) + continue CONTENT + } + input.Messages[i].StringAudios = append(input.Messages[i].StringAudios, base64) // TODO: make sure that we only return base64 stuff + audioIndex++ + nrOfAudiosInMessage++ + case "image_url", "image": + // Decode content as base64 either if it's an URL or base64 text + base64, err := utils.GetContentURIAsBase64(pp.ImageURL.URL) + if err != nil { + log.Error().Msgf("Failed encoding image: %s", err) + continue CONTENT + } + + input.Messages[i].StringImages = append(input.Messages[i].StringImages, base64) // TODO: make sure that we only return base64 stuff + + imgIndex++ + nrOfImgsInMessage++ + } + } + + input.Messages[i].StringContent, _ = templates.TemplateMultiModal(config.TemplateConfig.Multimodal, templates.MultiModalOptions{ + TotalImages: imgIndex, + TotalVideos: vidIndex, + TotalAudios: audioIndex, + ImagesInMessage: nrOfImgsInMessage, + VideosInMessage: nrOfVideosInMessage, + AudiosInMessage: nrOfAudiosInMessage, + }, textContent) + } + } + + if input.RepeatPenalty != 0 { + config.RepeatPenalty = input.RepeatPenalty + } + + if input.FrequencyPenalty != 0 { + config.FrequencyPenalty = input.FrequencyPenalty + } + + if input.PresencePenalty != 0 { + config.PresencePenalty = input.PresencePenalty + } + + if input.Keep != 0 { + config.Keep = input.Keep + } + + if input.Batch != 0 { + config.Batch = input.Batch + } + + if input.IgnoreEOS { + config.IgnoreEOS = input.IgnoreEOS + } + + if input.Seed != nil { + config.Seed = input.Seed + } + + if input.TypicalP != nil { + config.TypicalP = input.TypicalP + } + + log.Debug().Str("input.Input", fmt.Sprintf("%+v", input.Input)) + + switch inputs := input.Input.(type) { + case string: + if inputs != "" { + config.InputStrings = append(config.InputStrings, inputs) + } + case []interface{}: + for _, pp := range inputs { + switch i := pp.(type) { + case string: + config.InputStrings = append(config.InputStrings, i) + case []interface{}: + tokens := []int{} + for _, ii := range i { + tokens = append(tokens, int(ii.(float64))) + } + config.InputToken = append(config.InputToken, tokens) + } + } + } + + // Can be either a string or an object + switch fnc := input.FunctionCall.(type) { + case string: + if fnc != "" { + config.SetFunctionCallString(fnc) + } + case map[string]interface{}: + var name string + n, exists := fnc["name"] + if exists { + nn, e := n.(string) + if e { + name = nn + } + } + config.SetFunctionCallNameString(name) + } + + switch p := input.Prompt.(type) { + case string: + config.PromptStrings = append(config.PromptStrings, p) + case []interface{}: + for _, pp := range p { + if s, ok := pp.(string); ok { + config.PromptStrings = append(config.PromptStrings, s) + } + } + } + + if config.Validate() { + return nil + } + return fmt.Errorf("unable to validate configuration after merging") +} diff --git a/core/http/routes/elevenlabs.go b/core/http/routes/elevenlabs.go index 73387c7bb76..9e735bb11b6 100644 --- a/core/http/routes/elevenlabs.go +++ b/core/http/routes/elevenlabs.go @@ -4,17 +4,26 @@ import ( "github.com/gofiber/fiber/v2" "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/http/endpoints/elevenlabs" + "github.com/mudler/LocalAI/core/http/middleware" + "github.com/mudler/LocalAI/core/schema" "github.com/mudler/LocalAI/pkg/model" ) func RegisterElevenLabsRoutes(app *fiber.App, + re *middleware.RequestExtractor, cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) { // Elevenlabs - app.Post("/v1/text-to-speech/:voice-id", elevenlabs.TTSEndpoint(cl, ml, appConfig)) + app.Post("/v1/text-to-speech/:voice-id", + re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_TTS)), + re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.ElevenLabsTTSRequest) }), + elevenlabs.TTSEndpoint(cl, ml, appConfig)) - app.Post("/v1/sound-generation", elevenlabs.SoundGenerationEndpoint(cl, ml, appConfig)) + app.Post("/v1/sound-generation", + re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_SOUND_GENERATION)), + re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.ElevenLabsSoundGenerationRequest) }), + elevenlabs.SoundGenerationEndpoint(cl, ml, appConfig)) } diff --git a/core/http/routes/jina.go b/core/http/routes/jina.go index 93125e6cb91..1f7a1a7c3e4 100644 --- a/core/http/routes/jina.go +++ b/core/http/routes/jina.go @@ -3,16 +3,22 @@ package routes import ( "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/http/endpoints/jina" + "github.com/mudler/LocalAI/core/http/middleware" + "github.com/mudler/LocalAI/core/schema" "github.com/gofiber/fiber/v2" "github.com/mudler/LocalAI/pkg/model" ) func RegisterJINARoutes(app *fiber.App, + re *middleware.RequestExtractor, cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) { // POST endpoint to mimic the reranking - app.Post("/v1/rerank", jina.JINARerankEndpoint(cl, ml, appConfig)) + app.Post("/v1/rerank", + re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_RERANK)), + re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.JINARerankRequest) }), + jina.JINARerankEndpoint(cl, ml, appConfig)) } diff --git a/core/http/routes/localai.go b/core/http/routes/localai.go index a2ef16a5aa9..ff970fbab04 100644 --- a/core/http/routes/localai.go +++ b/core/http/routes/localai.go @@ -5,13 +5,16 @@ import ( "github.com/gofiber/swagger" "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/http/endpoints/localai" + "github.com/mudler/LocalAI/core/http/middleware" "github.com/mudler/LocalAI/core/p2p" + "github.com/mudler/LocalAI/core/schema" "github.com/mudler/LocalAI/core/services" "github.com/mudler/LocalAI/internal" "github.com/mudler/LocalAI/pkg/model" ) func RegisterLocalAIRoutes(app *fiber.App, + requestExtractor *middleware.RequestExtractor, cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig, @@ -33,7 +36,10 @@ func RegisterLocalAIRoutes(app *fiber.App, app.Get("/models/jobs", modelGalleryEndpointService.GetAllStatusEndpoint()) } - app.Post("/tts", localai.TTSEndpoint(cl, ml, appConfig)) + app.Post("/tts", + requestExtractor.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_TTS)), + requestExtractor.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.TTSRequest) }), + localai.TTSEndpoint(cl, ml, appConfig)) // Stores sl := model.NewModelLoader("") @@ -46,7 +52,8 @@ func RegisterLocalAIRoutes(app *fiber.App, app.Get("/metrics", localai.LocalAIMetricsEndpoint()) } - // Experimental Backend Statistics Module + // Backend Statistics Module + // TODO: Should these use standard middlewares? Refactor later, they are extremely simple. backendMonitorService := services.NewBackendMonitorService(ml, cl, appConfig) // Split out for now app.Get("/backend/monitor", localai.BackendMonitorEndpoint(backendMonitorService)) app.Post("/backend/shutdown", localai.BackendShutdownEndpoint(backendMonitorService)) @@ -66,6 +73,9 @@ func RegisterLocalAIRoutes(app *fiber.App, app.Get("/system", localai.SystemInformations(ml, appConfig)) // misc - app.Post("/v1/tokenize", localai.TokenizeEndpoint(cl, ml, appConfig)) + app.Post("/v1/tokenize", + requestExtractor.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_TOKENIZE)), + requestExtractor.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.TokenizeRequest) }), + localai.TokenizeEndpoint(cl, ml, appConfig)) } diff --git a/core/http/routes/openai.go b/core/http/routes/openai.go index 081daf70d80..91aff3adb10 100644 --- a/core/http/routes/openai.go +++ b/core/http/routes/openai.go @@ -5,22 +5,50 @@ import ( "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/http/endpoints/localai" "github.com/mudler/LocalAI/core/http/endpoints/openai" + "github.com/mudler/LocalAI/core/http/middleware" + "github.com/mudler/LocalAI/core/schema" "github.com/mudler/LocalAI/pkg/model" ) func RegisterOpenAIRoutes(app *fiber.App, + re *middleware.RequestExtractor, cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) { // openAI compatible API endpoint // chat - app.Post("/v1/chat/completions", openai.ChatEndpoint(cl, ml, appConfig)) - app.Post("/chat/completions", openai.ChatEndpoint(cl, ml, appConfig)) + chatChain := []fiber.Handler{ + re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_CHAT)), + re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.OpenAIRequest) }), + re.SetOpenAIRequest, + openai.ChatEndpoint(cl, ml, appConfig), + } + app.Post("/v1/chat/completions", chatChain...) + app.Post("/chat/completions", chatChain...) // edit - app.Post("/v1/edits", openai.EditEndpoint(cl, ml, appConfig)) - app.Post("/edits", openai.EditEndpoint(cl, ml, appConfig)) + editChain := []fiber.Handler{ + re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_EDIT)), + re.BuildConstantDefaultModelNameMiddleware("gpt-4o"), + re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.OpenAIRequest) }), + re.SetOpenAIRequest, + openai.EditEndpoint(cl, ml, appConfig), + } + app.Post("/v1/edits", editChain...) + app.Post("/edits", editChain...) + + // completion + completionChain := []fiber.Handler{ + re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_COMPLETION)), + re.BuildConstantDefaultModelNameMiddleware("gpt-4o"), + re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.OpenAIRequest) }), + re.SetOpenAIRequest, + openai.CompletionEndpoint(cl, ml, appConfig), + } + app.Post("/v1/completions", completionChain...) + app.Post("/completions", completionChain...) + app.Post("/v1/engines/:model/completions", completionChain...) // assistant app.Get("/v1/assistants", openai.ListAssistantsEndpoint(cl, ml, appConfig)) @@ -54,22 +82,37 @@ func RegisterOpenAIRoutes(app *fiber.App, app.Get("/v1/files/:file_id/content", openai.GetFilesContentsEndpoint(cl, appConfig)) app.Get("/files/:file_id/content", openai.GetFilesContentsEndpoint(cl, appConfig)) - // completion - app.Post("/v1/completions", openai.CompletionEndpoint(cl, ml, appConfig)) - app.Post("/completions", openai.CompletionEndpoint(cl, ml, appConfig)) - app.Post("/v1/engines/:model/completions", openai.CompletionEndpoint(cl, ml, appConfig)) - // embeddings - app.Post("/v1/embeddings", openai.EmbeddingsEndpoint(cl, ml, appConfig)) - app.Post("/embeddings", openai.EmbeddingsEndpoint(cl, ml, appConfig)) - app.Post("/v1/engines/:model/embeddings", openai.EmbeddingsEndpoint(cl, ml, appConfig)) + embeddingChain := []fiber.Handler{ + re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_EMBEDDINGS)), + re.BuildConstantDefaultModelNameMiddleware("gpt-4o"), + re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.OpenAIRequest) }), + re.SetOpenAIRequest, + openai.EmbeddingsEndpoint(cl, ml, appConfig), + } + app.Post("/v1/embeddings", embeddingChain...) + app.Post("/embeddings", embeddingChain...) + app.Post("/v1/engines/:model/embeddings", embeddingChain...) // audio - app.Post("/v1/audio/transcriptions", openai.TranscriptEndpoint(cl, ml, appConfig)) - app.Post("/v1/audio/speech", localai.TTSEndpoint(cl, ml, appConfig)) + app.Post("/v1/audio/transcriptions", + re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_TRANSCRIPT)), + re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.OpenAIRequest) }), + re.SetOpenAIRequest, + openai.TranscriptEndpoint(cl, ml, appConfig), + ) + + app.Post("/v1/audio/speech", + re.BuildFilteredFirstAvailableDefaultModel(config.BuildUsecaseFilterFn(config.FLAG_TTS)), + re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.TTSRequest) }), + localai.TTSEndpoint(cl, ml, appConfig)) // images - app.Post("/v1/images/generations", openai.ImageEndpoint(cl, ml, appConfig)) + app.Post("/v1/images/generations", + re.BuildConstantDefaultModelNameMiddleware(model.StableDiffusionBackend), + re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.OpenAIRequest) }), + re.SetOpenAIRequest, + openai.ImageEndpoint(cl, ml, appConfig)) if appConfig.ImageDir != "" { app.Static("/generated-images", appConfig.ImageDir) diff --git a/core/schema/elevenlabs.go b/core/schema/elevenlabs.go index 119e0a58886..df8a8d7c4a8 100644 --- a/core/schema/elevenlabs.go +++ b/core/schema/elevenlabs.go @@ -1,8 +1,9 @@ package schema type ElevenLabsTTSRequest struct { - Text string `json:"text" yaml:"text"` - ModelID string `json:"model_id" yaml:"model_id"` + Text string `json:"text" yaml:"text"` + ModelID string `json:"model_id" yaml:"model_id"` + LanguageCode string `json:"language_code" yaml:"language_code"` } type ElevenLabsSoundGenerationRequest struct { @@ -12,3 +13,17 @@ type ElevenLabsSoundGenerationRequest struct { Temperature *float32 `json:"prompt_influence,omitempty" yaml:"prompt_influence,omitempty"` DoSample *bool `json:"do_sample,omitempty" yaml:"do_sample,omitempty"` } + +func (elttsr *ElevenLabsTTSRequest) ModelName(s *string) string { + if s != nil { + elttsr.ModelID = *s + } + return elttsr.ModelID +} + +func (elsgr *ElevenLabsSoundGenerationRequest) ModelName(s *string) string { + if s != nil { + elsgr.ModelID = *s + } + return elsgr.ModelID +} diff --git a/core/schema/jina.go b/core/schema/jina.go index 7f80689cb2e..63d24556fe9 100644 --- a/core/schema/jina.go +++ b/core/schema/jina.go @@ -2,10 +2,11 @@ package schema // RerankRequest defines the structure of the request payload type JINARerankRequest struct { - Model string `json:"model"` + BasicModelRequest Query string `json:"query"` Documents []string `json:"documents"` TopN int `json:"top_n"` + Backend string `json:"backend"` } // DocumentResult represents a single document result diff --git a/core/schema/localai.go b/core/schema/localai.go index cdc3e5b0784..7cb4a0cb971 100644 --- a/core/schema/localai.go +++ b/core/schema/localai.go @@ -7,11 +7,11 @@ import ( ) type BackendMonitorRequest struct { - Model string `json:"model" yaml:"model"` + BasicModelRequest } type TokenMetricsRequest struct { - Model string `json:"model" yaml:"model"` + BasicModelRequest } type BackendMonitorResponse struct { @@ -27,11 +27,11 @@ type GalleryResponse struct { // @Description TTS request body type TTSRequest struct { - Model string `json:"model" yaml:"model"` // model name or full path - Input string `json:"input" yaml:"input"` // text input - Voice string `json:"voice" yaml:"voice"` // voice audio file or speaker id - Backend string `json:"backend" yaml:"backend"` - Language string `json:"language,omitempty" yaml:"language,omitempty"` // (optional) language to use with TTS model + BasicModelRequest // model name or full path + Input string `json:"input" yaml:"input"` // text input + Voice string `json:"voice" yaml:"voice"` // voice audio file or speaker id + Backend string `json:"backend" yaml:"backend"` + Language string `json:"language,omitempty" yaml:"language,omitempty"` // (optional) language to use with TTS model } type StoresSet struct { diff --git a/core/schema/prediction.go b/core/schema/prediction.go index 18d2782bb4b..15785f1916d 100644 --- a/core/schema/prediction.go +++ b/core/schema/prediction.go @@ -3,7 +3,7 @@ package schema type PredictionOptions struct { // Also part of the OpenAI official spec - Model string `json:"model" yaml:"model"` + BasicModelRequest `yaml:",inline"` // Also part of the OpenAI official spec Language string `json:"language"` diff --git a/core/schema/request.go b/core/schema/request.go new file mode 100644 index 00000000000..f55f39200ca --- /dev/null +++ b/core/schema/request.go @@ -0,0 +1,22 @@ +package schema + +// This file and type represent a generic request to LocalAI - as opposed to requests to LocalAI-specific endpoints, which live in localai.go +type LocalAIRequest interface { + ModelName(*string) string +} + +type BasicModelRequest struct { + Model string `json:"model" yaml:"model"` + // TODO: Should this also include the following fields from the OpenAI side of the world? + // If so, changes should be made to core/http/middleware/request.go to match + + // Context context.Context `json:"-"` + // Cancel context.CancelFunc `json:"-"` +} + +func (bmr *BasicModelRequest) ModelName(s *string) string { + if s != nil { + bmr.Model = *s + } + return bmr.Model +} diff --git a/core/schema/tokenize.go b/core/schema/tokenize.go index 3770cc5affd..e481f186333 100644 --- a/core/schema/tokenize.go +++ b/core/schema/tokenize.go @@ -1,8 +1,8 @@ package schema type TokenizeRequest struct { + BasicModelRequest Content string `json:"content"` - Model string `json:"model"` } type TokenizeResponse struct { diff --git a/core/services/list_models.go b/core/services/list_models.go index ef555d22196..45c05f5828c 100644 --- a/core/services/list_models.go +++ b/core/services/list_models.go @@ -49,3 +49,15 @@ func ListModels(bcl *config.BackendConfigLoader, ml *model.ModelLoader, filter c return dataModels, nil } + +func CheckIfModelExists(bcl *config.BackendConfigLoader, ml *model.ModelLoader, modelName string, looseFilePolicy LooseFilePolicy) (bool, error) { + filter, err := config.BuildNameFilterFn(modelName) + if err != nil { + return false, err + } + models, err := ListModels(bcl, ml, filter, looseFilePolicy) + if err != nil { + return false, err + } + return (len(models) > 0), nil +} diff --git a/core/startup/startup.go b/core/startup/startup.go index 17e54bc0603..60c2e04caed 100644 --- a/core/startup/startup.go +++ b/core/startup/startup.go @@ -147,13 +147,7 @@ func Startup(opts ...config.AppOption) (*config.BackendConfigLoader, *model.Mode if options.LoadToMemory != nil { for _, m := range options.LoadToMemory { - cfg, err := cl.LoadBackendConfigFileByName(m, options.ModelPath, - config.LoadOptionDebug(options.Debug), - config.LoadOptionThreads(options.Threads), - config.LoadOptionContextSize(options.ContextSize), - config.LoadOptionF16(options.F16), - config.ModelPath(options.ModelPath), - ) + cfg, err := cl.LoadBackendConfigFileByNameDefaultOptions(m, options) if err != nil { return nil, nil, nil, err } diff --git a/docs/content/docs/getting-started/quickstart.md b/docs/content/docs/getting-started/quickstart.md index 9ccc0faaaea..4e14c505ee9 100644 --- a/docs/content/docs/getting-started/quickstart.md +++ b/docs/content/docs/getting-started/quickstart.md @@ -30,6 +30,19 @@ For a full list of options, refer to the [Installer Options]({{% relref "docs/ad Binaries can also be [manually downloaded]({{% relref "docs/reference/binaries" %}}). +## Using Homebrew on MacOS + +{{% alert icon="⚠️" %}} +The Homebrew formula currently doesn't have the same options than the bash script +{{% /alert %}} + +You can install Homebrew's [LocalAI](https://formulae.brew.sh/formula/localai) with the following command: + +``` +brew install localai +``` + + ## Using Container Images or Kubernetes LocalAI is available as a container image compatible with various container engines such as Docker, Podman, and Kubernetes. Container images are published on [quay.io](https://quay.io/repository/go-skynet/local-ai?tab=tags&tag=latest) and [Docker Hub](https://hub.docker.com/r/localai/localai). diff --git a/go.mod b/go.mod index e969d50895e..bd13f8c9671 100644 --- a/go.mod +++ b/go.mod @@ -231,7 +231,7 @@ require ( github.com/moby/sys/sequential v0.5.0 // indirect github.com/moby/term v0.5.0 // indirect github.com/mr-tron/base58 v1.2.0 // indirect - github.com/mudler/go-piper v0.0.0-20241022074816-3854e0221ffb + github.com/mudler/go-piper v0.0.0-20241023091659-2494246fd9fc github.com/mudler/water v0.0.0-20221010214108-8c7313014ce0 // indirect github.com/muesli/reflow v0.3.0 // indirect github.com/muesli/termenv v0.15.2 // indirect diff --git a/go.sum b/go.sum index 9047d4eb72c..05a59952edb 100644 --- a/go.sum +++ b/go.sum @@ -500,6 +500,8 @@ github.com/mudler/go-piper v0.0.0-20240315144837-9d0100873a7d h1:8udOFrDf/I83JL0 github.com/mudler/go-piper v0.0.0-20240315144837-9d0100873a7d/go.mod h1:O7SwdSWMilAWhBZMK9N9Y/oBDyMMzshE3ju8Xkexwig= github.com/mudler/go-piper v0.0.0-20241022074816-3854e0221ffb h1:5qcuxQEpAqeV4ftV5nUt3/hB/RoTXq3MaaauOAedyXo= github.com/mudler/go-piper v0.0.0-20241022074816-3854e0221ffb/go.mod h1:O7SwdSWMilAWhBZMK9N9Y/oBDyMMzshE3ju8Xkexwig= +github.com/mudler/go-piper v0.0.0-20241023091659-2494246fd9fc h1:RxwneJl1VgvikiX28EkpdAyL4yQVnJMrbquKospjHyA= +github.com/mudler/go-piper v0.0.0-20241023091659-2494246fd9fc/go.mod h1:O7SwdSWMilAWhBZMK9N9Y/oBDyMMzshE3ju8Xkexwig= github.com/mudler/go-processmanager v0.0.0-20240820160718-8b802d3ecf82 h1:FVT07EI8njvsD4tC2Hw8Xhactp5AWhsQWD4oTeQuSAU= github.com/mudler/go-processmanager v0.0.0-20240820160718-8b802d3ecf82/go.mod h1:Urp7LG5jylKoDq0663qeBh0pINGcRl35nXdKx82PSoU= github.com/mudler/go-stable-diffusion v0.0.0-20240429204715-4a3cd6aeae6f h1:cxtMSRkUfy+mjIQ3yMrU0txwQ4It913NEN4m1H8WWgo= diff --git a/pkg/model/initializers.go b/pkg/model/initializers.go index bd668ec25a2..4988de25722 100644 --- a/pkg/model/initializers.go +++ b/pkg/model/initializers.go @@ -457,7 +457,7 @@ func (ml *ModelLoader) ListAvailableBackends(assetdir string) ([]string, error) func (ml *ModelLoader) BackendLoader(opts ...Option) (client grpc.Backend, err error) { o := NewOptions(opts...) - log.Info().Msgf("Loading model '%s' with backend %s", o.modelID, o.backendString) + log.Info().Str("modelID", o.modelID).Str("backend", o.backendString).Str("o.model", o.model).Msg("BackendLoader starting") backend := strings.ToLower(o.backendString) if realBackend, exists := Aliases[backend]; exists { diff --git a/pkg/model/loader_options.go b/pkg/model/loader_options.go index e7fd06de9f4..c151d53b1e8 100644 --- a/pkg/model/loader_options.go +++ b/pkg/model/loader_options.go @@ -56,6 +56,14 @@ func WithBackendString(backend string) Option { } } +func WithDefaultBackendString(backend string) Option { + return func(o *Options) { + if o.backendString == "" { + o.backendString = backend + } + } +} + func WithModel(modelFile string) Option { return func(o *Options) { o.model = modelFile diff --git a/tests/e2e-aio/e2e_test.go b/tests/e2e-aio/e2e_test.go index a9c5549742a..9156428b30e 100644 --- a/tests/e2e-aio/e2e_test.go +++ b/tests/e2e-aio/e2e_test.go @@ -235,7 +235,9 @@ var _ = Describe("E2E test", func() { modelName := "jina-reranker-v1-base-en" req := schema.JINARerankRequest{ - Model: modelName, + BasicModelRequest: schema.BasicModelRequest{ + Model: modelName, + }, Query: "Organic skincare products for sensitive skin", Documents: []string{ "Eco-friendly kitchenware for modern homes", @@ -256,12 +258,15 @@ var _ = Describe("E2E test", func() { Expect(err).To(BeNil()) Expect(serialized).ToNot(BeNil()) + GinkgoWriter.Printf("Reranker Request Body JSON: %q\n", string(serialized)) + rerankerEndpoint := apiEndpoint + "/rerank" resp, err := http.Post(rerankerEndpoint, "application/json", bytes.NewReader(serialized)) Expect(err).To(BeNil()) Expect(resp).ToNot(BeNil()) body, err := io.ReadAll(resp.Body) Expect(err).ToNot(HaveOccurred()) + GinkgoWriter.Printf("Reranker Response Body JSON: %q\n", string(body)) Expect(resp.StatusCode).To(Equal(200), fmt.Sprintf("body: %s, response: %+v", body, resp)) deserializedResponse := schema.JINARerankResponse{}