Skip to content

Commit

Permalink
chore(model-loader): increase test coverage of model loader (#3433)
Browse files Browse the repository at this point in the history
chore(model-loader): increase coverage of model loader

Signed-off-by: Ettore Di Giacinto <[email protected]>
  • Loading branch information
mudler committed Aug 30, 2024
1 parent 69a3b22 commit 607fd06
Show file tree
Hide file tree
Showing 3 changed files with 138 additions and 5 deletions.
33 changes: 31 additions & 2 deletions pkg/model/loader.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"path/filepath"
"strings"
"sync"
"time"

"github.com/mudler/LocalAI/pkg/templates"

Expand Down Expand Up @@ -102,6 +103,18 @@ FILE:
return models, nil
}

func (ml *ModelLoader) ListModels() []*Model {
ml.mu.Lock()
defer ml.mu.Unlock()

models := []*Model{}
for _, model := range ml.models {
models = append(models, model)
}

return models
}

func (ml *ModelLoader) LoadModel(modelName string, loader func(string, string) (*Model, error)) (*Model, error) {
ml.mu.Lock()
defer ml.mu.Unlock()
Expand All @@ -120,7 +133,12 @@ func (ml *ModelLoader) LoadModel(modelName string, loader func(string, string) (
return nil, err
}

if model == nil {
return nil, fmt.Errorf("loader didn't return a model")
}

ml.models[modelName] = model

return model, nil
}

Expand All @@ -146,11 +164,22 @@ func (ml *ModelLoader) CheckIsLoaded(s string) *Model {
}

log.Debug().Msgf("Model already loaded in memory: %s", s)
alive, err := m.GRPC(false, ml.wd).HealthCheck(context.Background())
client := m.GRPC(false, ml.wd)

log.Debug().Msgf("Checking model availability (%s)", s)
cTimeout, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel()

alive, err := client.HealthCheck(cTimeout)
if !alive {
log.Warn().Msgf("GRPC Model not responding: %s", err.Error())
log.Warn().Msgf("Deleting the process in order to recreate it")
if !ml.grpcProcesses[s].IsAlive() {
process, exists := ml.grpcProcesses[s]
if !exists {
log.Error().Msgf("Process not found for '%s' and the model is not responding anymore !", s)
return m
}
if !process.IsAlive() {
log.Debug().Msgf("GRPC Process is not responding: %s", s)
// stop and delete the process, this forces to re-load the model and re-create again the service
err := ml.deleteProcess(s)
Expand Down
105 changes: 105 additions & 0 deletions pkg/model/loader_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
package model_test

import (
"errors"
"os"
"path/filepath"

"github.com/mudler/LocalAI/pkg/model"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)

var _ = Describe("ModelLoader", func() {
var (
modelLoader *model.ModelLoader
modelPath string
mockModel *model.Model
)

BeforeEach(func() {
// Setup the model loader with a test directory
modelPath = "/tmp/test_model_path"
os.Mkdir(modelPath, 0755)
modelLoader = model.NewModelLoader(modelPath)
})

AfterEach(func() {
// Cleanup test directory
os.RemoveAll(modelPath)
})

Context("NewModelLoader", func() {
It("should create a new ModelLoader with an empty model map", func() {
Expect(modelLoader).ToNot(BeNil())
Expect(modelLoader.ModelPath).To(Equal(modelPath))
Expect(modelLoader.ListModels()).To(BeEmpty())
})
})

Context("ExistsInModelPath", func() {
It("should return true if a file exists in the model path", func() {
testFile := filepath.Join(modelPath, "test.model")
os.Create(testFile)
Expect(modelLoader.ExistsInModelPath("test.model")).To(BeTrue())
})

It("should return false if a file does not exist in the model path", func() {
Expect(modelLoader.ExistsInModelPath("nonexistent.model")).To(BeFalse())
})
})

Context("ListFilesInModelPath", func() {
It("should list all valid model files in the model path", func() {
os.Create(filepath.Join(modelPath, "test.model"))
os.Create(filepath.Join(modelPath, "README.md"))

files, err := modelLoader.ListFilesInModelPath()
Expect(err).To(BeNil())
Expect(files).To(ContainElement("test.model"))
Expect(files).ToNot(ContainElement("README.md"))
})
})

Context("LoadModel", func() {
It("should load a model and keep it in memory", func() {
mockModel = model.NewModel("test.model")

mockLoader := func(modelName, modelFile string) (*model.Model, error) {
return mockModel, nil
}

model, err := modelLoader.LoadModel("test.model", mockLoader)
Expect(err).To(BeNil())
Expect(model).To(Equal(mockModel))
Expect(modelLoader.CheckIsLoaded("test.model")).To(Equal(mockModel))
})

It("should return an error if loading the model fails", func() {
mockLoader := func(modelName, modelFile string) (*model.Model, error) {
return nil, errors.New("failed to load model")
}

model, err := modelLoader.LoadModel("test.model", mockLoader)
Expect(err).To(HaveOccurred())
Expect(model).To(BeNil())
})
})

Context("ShutdownModel", func() {
It("should shutdown a loaded model", func() {
mockModel = model.NewModel("test.model")

mockLoader := func(modelName, modelFile string) (*model.Model, error) {
return mockModel, nil
}

_, err := modelLoader.LoadModel("test.model", mockLoader)
Expect(err).To(BeNil())

err = modelLoader.ShutdownModel("test.model")
Expect(err).To(BeNil())
Expect(modelLoader.CheckIsLoaded("test.model")).To(BeNil())
})
})
})
5 changes: 2 additions & 3 deletions pkg/model/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ func (m *Model) GRPC(parallel bool, wd *WatchDog) grpc.Backend {
enableWD = true
}

client := grpc.NewClient(m.address, parallel, wd, enableWD)
m.client = client
return client
m.client = grpc.NewClient(m.address, parallel, wd, enableWD)
return m.client
}

0 comments on commit 607fd06

Please sign in to comment.