Skip to content

Commit

Permalink
feat(startup): fetch model definition remotely (#1654)
Browse files Browse the repository at this point in the history
  • Loading branch information
mudler authored Jan 27, 2024
1 parent f928899 commit 6ac5d81
Show file tree
Hide file tree
Showing 6 changed files with 68 additions and 6 deletions.
2 changes: 1 addition & 1 deletion api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ func Startup(opts ...options.AppOption) (*options.Option, *config.ConfigLoader,
log.Info().Msgf("Starting LocalAI using %d threads, with models path: %s", options.Threads, options.Loader.ModelPath)
log.Info().Msgf("LocalAI version: %s", internal.PrintableVersion())

startup.PreloadModelsConfigurations(options.Loader.ModelPath, options.ModelsURL...)
startup.PreloadModelsConfigurations(options.ModelLibraryURL, options.Loader.ModelPath, options.ModelsURL...)

cl := config.NewConfigLoader()
if err := cl.LoadConfigs(options.Loader.ModelPath); err != nil {
Expand Down
8 changes: 8 additions & 0 deletions api/options/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ type Option struct {
ApiKeys []string
Metrics *metrics.Metrics

ModelLibraryURL string

Galleries []gallery.Gallery

BackendAssets embed.FS
Expand Down Expand Up @@ -78,6 +80,12 @@ func WithCors(b bool) AppOption {
}
}

func WithModelLibraryURL(url string) AppOption {
return func(o *Option) {
o.ModelLibraryURL = url
}
}

var EnableWatchDog = func(o *Option) {
o.WatchDog = true
}
Expand Down
15 changes: 15 additions & 0 deletions embedded/embedded.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ import (
"slices"
"strings"

"github.com/go-skynet/LocalAI/pkg/downloader"

"github.com/go-skynet/LocalAI/pkg/assets"
"gopkg.in/yaml.v3"
)
Expand All @@ -30,6 +32,19 @@ func init() {
yaml.Unmarshal(modelLibrary, &modelShorteners)
}

func GetRemoteLibraryShorteners(url string) (map[string]string, error) {
remoteLibrary := map[string]string{}

err := downloader.GetURI(url, func(_ string, i []byte) error {
return yaml.Unmarshal(i, &remoteLibrary)
})
if err != nil {
return nil, fmt.Errorf("error downloading remote library: %s", err.Error())
}

return remoteLibrary, err
}

// ExistsInModelsLibrary checks if a model exists in the embedded models library
func ExistsInModelsLibrary(s string) bool {
f := fmt.Sprintf("%s.yaml", s)
Expand Down
11 changes: 11 additions & 0 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ import (
"github.com/urfave/cli/v2"
)

const (
remoteLibraryURL = "https://raw.githubusercontent.com/mudler/LocalAI/master/embedded/model_library.yaml"
)

func main() {
log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr})
// clean up process
Expand Down Expand Up @@ -94,6 +98,12 @@ func main() {
Usage: "JSON list of galleries",
EnvVars: []string{"GALLERIES"},
},
&cli.StringFlag{
Name: "remote-library",
Usage: "A LocalAI remote library URL",
EnvVars: []string{"REMOTE_LIBRARY"},
Value: remoteLibraryURL,
},
&cli.StringFlag{
Name: "preload-models",
Usage: "A List of models to apply in JSON at start",
Expand Down Expand Up @@ -219,6 +229,7 @@ For a list of compatible model, check out: https://localai.io/model-compatibilit
options.WithAudioDir(ctx.String("audio-path")),
options.WithF16(ctx.Bool("f16")),
options.WithStringGalleries(ctx.String("galleries")),
options.WithModelLibraryURL(ctx.String("remote-library")),
options.WithDisableMessage(false),
options.WithCors(ctx.Bool("cors")),
options.WithCorsAllowOrigins(ctx.String("cors-allow-origins")),
Expand Down
16 changes: 14 additions & 2 deletions pkg/startup/model_preload.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,22 @@ import (
// PreloadModelsConfigurations will preload models from the given list of URLs
// It will download the model if it is not already present in the model path
// It will also try to resolve if the model is an embedded model YAML configuration
func PreloadModelsConfigurations(modelPath string, models ...string) {
func PreloadModelsConfigurations(modelLibraryURL string, modelPath string, models ...string) {
for _, url := range models {
url = embedded.ModelShortURL(url)

// As a best effort, try to resolve the model from the remote library
// if it's not resolved we try with the other method below
if modelLibraryURL != "" {
lib, err := embedded.GetRemoteLibraryShorteners(modelLibraryURL)
if err == nil {
if lib[url] != "" {
log.Debug().Msgf("[startup] model configuration is defined remotely: %s (%s)", url, lib[url])
url = lib[url]
}
}
}

url = embedded.ModelShortURL(url)
switch {
case embedded.ExistsInModelsLibrary(url):
modelYAML, err := embedded.ResolveContent(url)
Expand Down
22 changes: 19 additions & 3 deletions pkg/startup/model_preload_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,29 @@ import (
var _ = Describe("Preload test", func() {

Context("Preloading from strings", func() {
It("loads from remote url", func() {
tmpdir, err := os.MkdirTemp("", "")
Expect(err).ToNot(HaveOccurred())
libraryURL := "https://raw.githubusercontent.com/mudler/LocalAI/master/embedded/model_library.yaml"
fileName := fmt.Sprintf("%s.yaml", "1701d57f28d47552516c2b6ecc3cc719")

PreloadModelsConfigurations(libraryURL, tmpdir, "phi-2")

resultFile := filepath.Join(tmpdir, fileName)

content, err := os.ReadFile(resultFile)
Expect(err).ToNot(HaveOccurred())

Expect(string(content)).To(ContainSubstring("name: phi-2"))
})

It("loads from embedded full-urls", func() {
tmpdir, err := os.MkdirTemp("", "")
Expect(err).ToNot(HaveOccurred())
url := "https://raw.githubusercontent.com/mudler/LocalAI/master/examples/configurations/phi-2.yaml"
fileName := fmt.Sprintf("%s.yaml", utils.MD5(url))

PreloadModelsConfigurations(tmpdir, url)
PreloadModelsConfigurations("", tmpdir, url)

resultFile := filepath.Join(tmpdir, fileName)

Expand All @@ -35,7 +51,7 @@ var _ = Describe("Preload test", func() {
Expect(err).ToNot(HaveOccurred())
url := "phi-2"

PreloadModelsConfigurations(tmpdir, url)
PreloadModelsConfigurations("", tmpdir, url)

entry, err := os.ReadDir(tmpdir)
Expect(err).ToNot(HaveOccurred())
Expand All @@ -53,7 +69,7 @@ var _ = Describe("Preload test", func() {
url := "mistral-openorca"
fileName := fmt.Sprintf("%s.yaml", utils.MD5(url))

PreloadModelsConfigurations(tmpdir, url)
PreloadModelsConfigurations("", tmpdir, url)

resultFile := filepath.Join(tmpdir, fileName)

Expand Down

0 comments on commit 6ac5d81

Please sign in to comment.