diff --git a/.gitignore b/.gitignore index 07b8dbfff74..075639cee17 100644 --- a/.gitignore +++ b/.gitignore @@ -35,7 +35,7 @@ release/ .idea # Generated during build -backend-assets/* +**/backend-assets/* !backend-assets/.keep prepare /ggml-metal.metal diff --git a/configuration/roles.json b/configuration/roles.json new file mode 100644 index 00000000000..c6145546502 --- /dev/null +++ b/configuration/roles.json @@ -0,0 +1,12 @@ +{ + "admin": ["*"], + "llm-user": ["POST|/chat/completions", "POST|/edits", "POST|/completions", "POST|/embeddings", "POST|/rerank", "GET|/models"], + "audio-user": ["POST|/audio/transcriptions", "POST|/audio/speech", "POST|/tts", "POST|/text-to-speech"], + "image-user": ["POST|/images/generations"], + "ui": ["GET|/", "GET|/browse", "GET|/browse/", "POST|/browse/search/models", + "GET|/browse/job/progress", "GET|/browse/job", + "GET|/chat", "GET|/chat/", "GET|/chat/:model", + "GET|/text2image", "GET|/text2image/", "GET|/text2image/:model", + "GET|/tts", "GET|/tts/", "GET|/tts/:model"], + "user": ["ui", "llm-user", "audio-user", "image-user"] +} \ No newline at end of file diff --git a/core/cli/run.go b/core/cli/run.go index 6185627d6c9..56eae518bf7 100644 --- a/core/cli/run.go +++ b/core/cli/run.go @@ -37,12 +37,12 @@ type RunCMD struct { Threads int `env:"LOCALAI_THREADS,THREADS" short:"t" default:"4" help:"Number of threads used for parallel computation. Usage of the number of physical cores in the system is suggested" group:"performance"` ContextSize int `env:"LOCALAI_CONTEXT_SIZE,CONTEXT_SIZE" default:"512" help:"Default context size for models" group:"performance"` - Address string `env:"LOCALAI_ADDRESS,ADDRESS" default:":8080" help:"Bind address for the API server" group:"api"` - CORS bool `env:"LOCALAI_CORS,CORS" help:"" group:"api"` - CORSAllowOrigins string `env:"LOCALAI_CORS_ALLOW_ORIGINS,CORS_ALLOW_ORIGINS" group:"api"` - UploadLimit int `env:"LOCALAI_UPLOAD_LIMIT,UPLOAD_LIMIT" default:"15" help:"Default upload-limit in MB" group:"api"` - APIKeys []string `env:"LOCALAI_API_KEY,API_KEY" help:"List of API Keys to enable API authentication. When this is set, all the requests must be authenticated with one of these API keys" group:"api"` - DisableWebUI bool `env:"LOCALAI_DISABLE_WEBUI,DISABLE_WEBUI" default:"false" help:"Disable webui" group:"api"` + Address string `env:"LOCALAI_ADDRESS,ADDRESS" default:":8080" help:"Bind address for the API server" group:"api"` + CORS bool `env:"LOCALAI_CORS,CORS" help:"" group:"api"` + CORSAllowOrigins string `env:"LOCALAI_CORS_ALLOW_ORIGINS,CORS_ALLOW_ORIGINS" group:"api"` + UploadLimit int `env:"LOCALAI_UPLOAD_LIMIT,UPLOAD_LIMIT" default:"15" help:"Default upload-limit in MB" group:"api"` + APIKeys map[string][]string `env:"LOCALAI_API_KEY,API_KEY" help:"List of API Keys to enable API authentication. When this is set, all the requests must be authenticated with one of these API keys" group:"api"` + DisableWebUI bool `env:"LOCALAI_DISABLE_WEBUI,DISABLE_WEBUI" default:"false" help:"Disable webui" group:"api"` ParallelRequests bool `env:"LOCALAI_PARALLEL_REQUESTS,PARALLEL_REQUESTS" help:"Enable backends to handle multiple requests in parallel if they support it (e.g.: llama.cpp or vllm)" group:"backends"` SingleActiveBackend bool `env:"LOCALAI_SINGLE_ACTIVE_BACKEND,SINGLE_ACTIVE_BACKEND" help:"Allow only one backend to be run at a time" group:"backends"` diff --git a/core/config/application_config.go b/core/config/application_config.go index 398418adade..6a6c35e7fa3 100644 --- a/core/config/application_config.go +++ b/core/config/application_config.go @@ -28,7 +28,8 @@ type ApplicationConfig struct { PreloadJSONModels string PreloadModelsFromPath string CORSAllowOrigins string - ApiKeys []string + ApiKeys map[string][]string // ApiKeys maps the key itself to a list of endpoints [or roles] that the key should be permitted to access + Roles map[string][]string // Roles is a simple "shortcut" mapping a name to a list of endpoints ModelLibraryURL string @@ -271,7 +272,7 @@ func WithDynamicConfigDirPollInterval(interval time.Duration) AppOption { } } -func WithApiKeys(apiKeys []string) AppOption { +func WithApiKeys(apiKeys map[string][]string) AppOption { return func(o *ApplicationConfig) { o.ApiKeys = apiKeys } diff --git a/core/http/app.go b/core/http/app.go index de31346b4b9..0a4b12764f5 100644 --- a/core/http/app.go +++ b/core/http/app.go @@ -3,7 +3,9 @@ package http import ( "embed" "errors" + "fmt" "net/http" + "slices" "strings" "github.com/go-skynet/LocalAI/pkg/utils" @@ -127,33 +129,48 @@ func App(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *confi // Auth middleware checking if API key is valid. If no API key is set, no auth is required. auth := func(c *fiber.Ctx) error { - if len(appConfig.ApiKeys) == 0 { - return c.Next() - } if len(appConfig.ApiKeys) == 0 { return c.Next() } + defaultCaseExists := len(appConfig.ApiKeys["_"]) > 0 + fmtPath := fmt.Sprintf("%s|%s", c.Route().Method, strings.Replace(c.Route().Path, "/v1", "", -1)) + authHeader := readAuthHeader(c) - if authHeader == "" { + if !defaultCaseExists && authHeader == "" { return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"message": "Authorization header missing"}) } // If it's a bearer token authHeaderParts := strings.Split(authHeader, " ") if len(authHeaderParts) != 2 || authHeaderParts[0] != "Bearer" { - return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"message": "Invalid Authorization header format"}) + if !defaultCaseExists { + return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"message": "Invalid Authorization header format"}) + } else { + authHeaderParts = []string{"", ""} + } } apiKey := authHeaderParts[1] - for _, key := range appConfig.ApiKeys { - if apiKey == key { - return c.Next() + if apiKey != "" { + for key, endpoints := range appConfig.ApiKeys { + if apiKey == key { + log.Trace().Str("key", key).Str("fmtPath", fmtPath).Msg("found a matching api key, checking permissions for fmtPath") + if slices.Contains(endpoints, "*") || slices.Contains(endpoints, fmtPath) { + return c.Next() + } + } } } - return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"message": "Invalid API key"}) + // Check if this is a default-allow endpoint + if defaultCaseExists && slices.Contains(appConfig.ApiKeys["_"], fmtPath) { + log.Trace().Str("fmtPath", fmtPath).Msg("matching authorization key not found, but fmtPath is on the default allow list") + return c.Next() + } + + return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"message": "Invalid API key", "fmtPath": fmtPath, "apiKey": apiKey}) } if appConfig.CORS { diff --git a/core/startup/config_file_watcher.go b/core/startup/config_file_watcher.go index 259446f1b32..d169ff6030c 100644 --- a/core/startup/config_file_watcher.go +++ b/core/startup/config_file_watcher.go @@ -6,6 +6,7 @@ import ( "os" "path" "path/filepath" + "slices" "time" "github.com/fsnotify/fsnotify" @@ -31,7 +32,12 @@ func newConfigFileHandler(appConfig *config.ApplicationConfig) configFileHandler handlers: make(map[string]fileHandler), appConfig: appConfig, } - err := c.Register("api_keys.json", readApiKeysJson(*appConfig), true) + + err := c.Register("roles.json", readRolesJson(*appConfig), true) + if err != nil { + log.Error().Err(err).Str("file", "roles.json").Msg("unable to register config file handler") + } + err = c.Register("api_keys.json", readApiKeysJson(*appConfig), true) if err != nil { log.Error().Err(err).Str("file", "api_keys.json").Msg("unable to register config file handler") } @@ -39,6 +45,7 @@ func newConfigFileHandler(appConfig *config.ApplicationConfig) configFileHandler if err != nil { log.Error().Err(err).Str("file", "external_backends.json").Msg("unable to register config file handler") } + return c } @@ -135,19 +142,69 @@ func readApiKeysJson(startupAppConfig config.ApplicationConfig) fileHandler { if len(fileContent) > 0 { // Parse JSON content from the file - var fileKeys []string + var fileKeys map[string][]string err := json.Unmarshal(fileContent, &fileKeys) if err != nil { - return err + // Try to deserialize the old, flat list format + var oldFileFormat []string + err := json.Unmarshal(fileContent, &oldFileFormat) + if err != nil { + log.Error().Err(err).Msg("unable to parse api_keys.json as any known format") + return err + } + log.Warn().Msg("unable to parse api_keys.json in modern format, defaulting all api keys to [\"ui\", \"user\"]") + for _, k := range oldFileFormat { + fileKeys[k] = []string{"ui", "user"} + } } - log.Trace().Int("numKeys", len(fileKeys)).Msg("discovered API keys from api keys dynamic config dile") + appConfig.ApiKeys = startupAppConfig.ApiKeys + if appConfig.ApiKeys == nil { + appConfig.ApiKeys = map[string][]string{} + } - appConfig.ApiKeys = append(startupAppConfig.ApiKeys, fileKeys...) + log.Trace().Int("numKeys", len(fileKeys)).Msg("discovered API keys from api keys dynamic config dile") + for key, rawFileEndpoints := range fileKeys { + appConfig.ApiKeys[key] = append(startupAppConfig.ApiKeys[key], rawFileEndpoints...) + } } else { log.Trace().Msg("no API keys discovered from dynamic config file") appConfig.ApiKeys = startupAppConfig.ApiKeys } + + // next, clean and process the ApiKeys for roles, duplicates, and * + // This is registered to run at startup, so will evaluate roles passed in as startupAppConfig + // quick version for now, this can be improved later + for key, endpoints := range appConfig.ApiKeys { + // Check if the starting point is enough to know the final answer + if slices.Contains(endpoints, "*") { + appConfig.ApiKeys[key] = []string{"*"} + continue + } + + for { // We loop around here a second time if we make a change -- this ensures we unroll nested roles + isClean := true + for role, roleEndpoints := range appConfig.Roles { + index := slices.Index(appConfig.ApiKeys[key], role) + if index != -1 { + appConfig.ApiKeys[key] = slices.Replace(appConfig.ApiKeys[key], index, index+1, roleEndpoints...) + isClean = false + } + } + if isClean { + break + } + } + // Check if we have a "*"" yet + if slices.Contains(appConfig.ApiKeys[key], "*") { + appConfig.ApiKeys[key] = []string{"*"} + continue + } + // At this point, Sort+Compact is a simple way to deduplicate the endpoint list, no matter how the roles overlap + slices.Sort(appConfig.ApiKeys[key]) + appConfig.ApiKeys[key] = slices.Compact(appConfig.ApiKeys[key]) + } + log.Trace().Int("numKeys", len(appConfig.ApiKeys)).Msg("total api keys after processing") return nil } @@ -155,6 +212,33 @@ func readApiKeysJson(startupAppConfig config.ApplicationConfig) fileHandler { return handler } +func readRolesJson(startupAppConfig config.ApplicationConfig) fileHandler { + handler := func(fileContent []byte, appConfig *config.ApplicationConfig) error { + log.Debug().Msg("processing roles runtime update") + log.Trace().Int("numRoles", len(startupAppConfig.Roles)).Msg("roles provided at startup") + + if len(fileContent) > 0 { + // Parse JSON content from the file + var fileRoles map[string][]string // Roles is a simple "shortcut" mapping a name to a list of endpoints + err := json.Unmarshal(fileContent, &fileRoles) + if err != nil { + return err + } + + log.Trace().Int("numRoles", len(fileRoles)).Msg("discovered roles from roles dynamic config dile") + + appConfig.Roles = fileRoles + } else { + log.Trace().Msg("no roles discovered from dynamic config file") + appConfig.Roles = startupAppConfig.Roles + } + log.Trace().Int("numRoles", len(appConfig.Roles)).Msg("total roles after processing") + return nil + } + + return handler +} + func readExternalBackendsJson(startupAppConfig config.ApplicationConfig) fileHandler { handler := func(fileContent []byte, appConfig *config.ApplicationConfig) error { log.Debug().Msg("processing external_backends.json")