Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore(refactor): drop duplicated shutdown logics #3589

Merged
merged 3 commits into from
Sep 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion core/http/routes/localai.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,6 @@ func RegisterLocalAIRoutes(app *fiber.App,
}{Version: internal.PrintableVersion()})
})

app.Get("/system", auth, localai.SystemInformations(ml, appConfig))
app.Get("/system", localai.SystemInformations(ml, appConfig))

}
17 changes: 17 additions & 0 deletions pkg/model/filters.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package model

import (
process "github.com/mudler/go-processmanager"
)

type GRPCProcessFilter = func(id string, p *process.Process) bool

func all(_ string, _ *process.Process) bool {
return true
}

func allExcept(s string) GRPCProcessFilter {
return func(id string, p *process.Process) bool {
return id != s
}
}
16 changes: 6 additions & 10 deletions pkg/model/initializers.go
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,7 @@ func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string, string
} else {
grpcProcess := backendPath(o.assetDir, backend)
if err := utils.VerifyPath(grpcProcess, o.assetDir); err != nil {
return nil, fmt.Errorf("grpc process not found in assetdir: %s", err.Error())
return nil, fmt.Errorf("refering to a backend not in asset dir: %s", err.Error())
}

if autoDetect {
Expand All @@ -332,7 +332,7 @@ func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string, string

// Check if the file exists
if _, err := os.Stat(grpcProcess); os.IsNotExist(err) {
return nil, fmt.Errorf("grpc process not found: %s. some backends(stablediffusion, tts) require LocalAI compiled with GO_TAGS", grpcProcess)
return nil, fmt.Errorf("backend not found: %s", grpcProcess)
}

serverAddress, err := getFreeAddress()
Expand All @@ -355,6 +355,8 @@ func (ml *ModelLoader) grpcModel(backend string, o *Options) func(string, string
client = NewModel(serverAddress)
}

log.Debug().Msgf("Wait for the service to start up")

// Wait for the service to start up
ready := false
for i := 0; i < o.grpcAttempts; i++ {
Expand Down Expand Up @@ -413,10 +415,8 @@ func (ml *ModelLoader) BackendLoader(opts ...Option) (client grpc.Backend, err e
}

if o.singleActiveBackend {
ml.mu.Lock()
log.Debug().Msgf("Stopping all backends except '%s'", o.model)
err := ml.StopAllExcept(o.model)
ml.mu.Unlock()
err := ml.StopGRPC(allExcept(o.model))
if err != nil {
log.Error().Err(err).Str("keptModel", o.model).Msg("error while shutting down all backends except for the keptModel")
return nil, err
Expand Down Expand Up @@ -444,26 +444,22 @@ func (ml *ModelLoader) BackendLoader(opts ...Option) (client grpc.Backend, err e
func (ml *ModelLoader) GreedyLoader(opts ...Option) (grpc.Backend, error) {
o := NewOptions(opts...)

ml.mu.Lock()

// Return earlier if we have a model already loaded
// (avoid looping through all the backends)
if m := ml.CheckIsLoaded(o.model); m != nil {
log.Debug().Msgf("Model '%s' already loaded", o.model)
ml.mu.Unlock()

return m.GRPC(o.parallelRequests, ml.wd), nil
}

// If we can have only one backend active, kill all the others (except external backends)
if o.singleActiveBackend {
log.Debug().Msgf("Stopping all backends except '%s'", o.model)
err := ml.StopAllExcept(o.model)
err := ml.StopGRPC(allExcept(o.model))
if err != nil {
log.Error().Err(err).Str("keptModel", o.model).Msg("error while shutting down all backends except for the keptModel - greedyloader continuing")
}
}
ml.mu.Unlock()

var err error

Expand Down
7 changes: 4 additions & 3 deletions pkg/model/loader.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,9 +118,6 @@ func (ml *ModelLoader) ListModels() []*Model {
}

func (ml *ModelLoader) LoadModel(modelName string, loader func(string, string) (*Model, error)) (*Model, error) {
ml.mu.Lock()
defer ml.mu.Unlock()

// Check if we already have a loaded model
if model := ml.CheckIsLoaded(modelName); model != nil {
return model, nil
Expand All @@ -139,6 +136,8 @@ func (ml *ModelLoader) LoadModel(modelName string, loader func(string, string) (
return nil, fmt.Errorf("loader didn't return a model")
}

ml.mu.Lock()
defer ml.mu.Unlock()
ml.models[modelName] = model

return model, nil
Expand Down Expand Up @@ -168,6 +167,8 @@ func (ml *ModelLoader) ShutdownModel(modelName string) error {
}

func (ml *ModelLoader) CheckIsLoaded(s string) *Model {
ml.mu.Lock()
defer ml.mu.Unlock()
m, ok := ml.models[s]
if !ok {
return nil
Expand Down
28 changes: 4 additions & 24 deletions pkg/model/process.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,28 +9,12 @@ import (
"strconv"
"strings"
"syscall"
"time"

"github.com/hpcloud/tail"
process "github.com/mudler/go-processmanager"
"github.com/rs/zerolog/log"
)

func (ml *ModelLoader) StopAllExcept(s string) error {
return ml.StopGRPC(func(id string, p *process.Process) bool {
if id == s {
return false
}

for ml.models[id].GRPC(false, ml.wd).IsBusy() {
log.Debug().Msgf("%s busy. Waiting.", id)
time.Sleep(2 * time.Second)
}
log.Debug().Msgf("[single-backend] Stopping %s", id)
return true
})
}

func (ml *ModelLoader) deleteProcess(s string) error {
if _, exists := ml.grpcProcesses[s]; exists {
if err := ml.grpcProcesses[s].Stop(); err != nil {
Expand All @@ -42,28 +26,24 @@ func (ml *ModelLoader) deleteProcess(s string) error {
return nil
}

type GRPCProcessFilter = func(id string, p *process.Process) bool

func includeAllProcesses(_ string, _ *process.Process) bool {
return true
}

func (ml *ModelLoader) StopGRPC(filter GRPCProcessFilter) error {
var err error = nil
for k, p := range ml.grpcProcesses {
if filter(k, p) {
e := ml.deleteProcess(k)
e := ml.ShutdownModel(k)
err = errors.Join(err, e)
}
}
return err
}

func (ml *ModelLoader) StopAllGRPC() error {
return ml.StopGRPC(includeAllProcesses)
return ml.StopGRPC(all)
}

func (ml *ModelLoader) GetGRPCPID(id string) (int, error) {
ml.mu.Lock()
defer ml.mu.Unlock()
p, exists := ml.grpcProcesses[id]
if !exists {
return -1, fmt.Errorf("no grpc backend found for %s", id)
Expand Down
Loading