Skip to content

Commit 16e5689

Browse files
Copilotmudler
andauthored
feat(importers): Add diffuser backend importer with ginkgo tests and UI support (#7316)
* Initial plan * Add diffuser backend importer with ginkgo tests Co-authored-by: mudler <[email protected]> * Finalize diffuser backend importer implementation Co-authored-by: mudler <[email protected]> * Add diffuser preferences to model-editor import section Co-authored-by: mudler <[email protected]> * Use gopkg.in/yaml.v3 for consistency in diffuser importer Co-authored-by: mudler <[email protected]> --------- Co-authored-by: copilot-swe-agent[bot] <[email protected]> Co-authored-by: mudler <[email protected]>
1 parent 2dd4229 commit 16e5689

File tree

7 files changed

+452
-13
lines changed

7 files changed

+452
-13
lines changed

core/gallery/importers/diffuser.go

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
package importers
2+
3+
import (
4+
"encoding/json"
5+
"path/filepath"
6+
"strings"
7+
8+
"github.com/mudler/LocalAI/core/config"
9+
"github.com/mudler/LocalAI/core/gallery"
10+
"github.com/mudler/LocalAI/core/schema"
11+
"gopkg.in/yaml.v3"
12+
)
13+
14+
var _ Importer = &DiffuserImporter{}
15+
16+
type DiffuserImporter struct{}
17+
18+
func (i *DiffuserImporter) Match(details Details) bool {
19+
preferences, err := details.Preferences.MarshalJSON()
20+
if err != nil {
21+
return false
22+
}
23+
preferencesMap := make(map[string]any)
24+
err = json.Unmarshal(preferences, &preferencesMap)
25+
if err != nil {
26+
return false
27+
}
28+
29+
b, ok := preferencesMap["backend"].(string)
30+
if ok && b == "diffusers" {
31+
return true
32+
}
33+
34+
if details.HuggingFace != nil {
35+
for _, file := range details.HuggingFace.Files {
36+
if strings.Contains(file.Path, "model_index.json") ||
37+
strings.Contains(file.Path, "scheduler/scheduler_config.json") {
38+
return true
39+
}
40+
}
41+
}
42+
43+
return false
44+
}
45+
46+
func (i *DiffuserImporter) Import(details Details) (gallery.ModelConfig, error) {
47+
preferences, err := details.Preferences.MarshalJSON()
48+
if err != nil {
49+
return gallery.ModelConfig{}, err
50+
}
51+
preferencesMap := make(map[string]any)
52+
err = json.Unmarshal(preferences, &preferencesMap)
53+
if err != nil {
54+
return gallery.ModelConfig{}, err
55+
}
56+
57+
name, ok := preferencesMap["name"].(string)
58+
if !ok {
59+
name = filepath.Base(details.URI)
60+
}
61+
62+
description, ok := preferencesMap["description"].(string)
63+
if !ok {
64+
description = "Imported from " + details.URI
65+
}
66+
67+
backend := "diffusers"
68+
b, ok := preferencesMap["backend"].(string)
69+
if ok {
70+
backend = b
71+
}
72+
73+
pipelineType, ok := preferencesMap["pipeline_type"].(string)
74+
if !ok {
75+
pipelineType = "StableDiffusionPipeline"
76+
}
77+
78+
schedulerType, ok := preferencesMap["scheduler_type"].(string)
79+
if !ok {
80+
schedulerType = ""
81+
}
82+
83+
enableParameters, ok := preferencesMap["enable_parameters"].(string)
84+
if !ok {
85+
enableParameters = "negative_prompt,num_inference_steps"
86+
}
87+
88+
cuda := false
89+
if cudaVal, ok := preferencesMap["cuda"].(bool); ok {
90+
cuda = cudaVal
91+
}
92+
93+
modelConfig := config.ModelConfig{
94+
Name: name,
95+
Description: description,
96+
KnownUsecaseStrings: []string{"image"},
97+
Backend: backend,
98+
PredictionOptions: schema.PredictionOptions{
99+
BasicModelRequest: schema.BasicModelRequest{
100+
Model: details.URI,
101+
},
102+
},
103+
Diffusers: config.Diffusers{
104+
PipelineType: pipelineType,
105+
SchedulerType: schedulerType,
106+
EnableParameters: enableParameters,
107+
CUDA: cuda,
108+
},
109+
}
110+
111+
data, err := yaml.Marshal(modelConfig)
112+
if err != nil {
113+
return gallery.ModelConfig{}, err
114+
}
115+
116+
return gallery.ModelConfig{
117+
Name: name,
118+
Description: description,
119+
ConfigFile: string(data),
120+
}, nil
121+
}
Lines changed: 246 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,246 @@
1+
package importers_test
2+
3+
import (
4+
"encoding/json"
5+
6+
"github.com/mudler/LocalAI/core/gallery/importers"
7+
. "github.com/mudler/LocalAI/core/gallery/importers"
8+
hfapi "github.com/mudler/LocalAI/pkg/huggingface-api"
9+
. "github.com/onsi/ginkgo/v2"
10+
. "github.com/onsi/gomega"
11+
)
12+
13+
var _ = Describe("DiffuserImporter", func() {
14+
var importer *DiffuserImporter
15+
16+
BeforeEach(func() {
17+
importer = &DiffuserImporter{}
18+
})
19+
20+
Context("Match", func() {
21+
It("should match when backend preference is diffusers", func() {
22+
preferences := json.RawMessage(`{"backend": "diffusers"}`)
23+
details := Details{
24+
URI: "https://example.com/model",
25+
Preferences: preferences,
26+
}
27+
28+
result := importer.Match(details)
29+
Expect(result).To(BeTrue())
30+
})
31+
32+
It("should match when HuggingFace details contain model_index.json", func() {
33+
hfDetails := &hfapi.ModelDetails{
34+
Files: []hfapi.ModelFile{
35+
{Path: "model_index.json"},
36+
},
37+
}
38+
details := Details{
39+
URI: "https://huggingface.co/test/model",
40+
HuggingFace: hfDetails,
41+
}
42+
43+
result := importer.Match(details)
44+
Expect(result).To(BeTrue())
45+
})
46+
47+
It("should match when HuggingFace details contain scheduler config", func() {
48+
hfDetails := &hfapi.ModelDetails{
49+
Files: []hfapi.ModelFile{
50+
{Path: "scheduler/scheduler_config.json"},
51+
},
52+
}
53+
details := Details{
54+
URI: "https://huggingface.co/test/model",
55+
HuggingFace: hfDetails,
56+
}
57+
58+
result := importer.Match(details)
59+
Expect(result).To(BeTrue())
60+
})
61+
62+
It("should not match when URI has no diffuser files and no backend preference", func() {
63+
details := Details{
64+
URI: "https://example.com/model.bin",
65+
}
66+
67+
result := importer.Match(details)
68+
Expect(result).To(BeFalse())
69+
})
70+
71+
It("should not match when backend preference is different", func() {
72+
preferences := json.RawMessage(`{"backend": "llama-cpp"}`)
73+
details := Details{
74+
URI: "https://example.com/model",
75+
Preferences: preferences,
76+
}
77+
78+
result := importer.Match(details)
79+
Expect(result).To(BeFalse())
80+
})
81+
82+
It("should return false when JSON preferences are invalid", func() {
83+
preferences := json.RawMessage(`invalid json`)
84+
details := Details{
85+
URI: "https://example.com/model",
86+
Preferences: preferences,
87+
}
88+
89+
result := importer.Match(details)
90+
Expect(result).To(BeFalse())
91+
})
92+
})
93+
94+
Context("Import", func() {
95+
It("should import model config with default name and description", func() {
96+
details := Details{
97+
URI: "https://huggingface.co/test/my-diffuser-model",
98+
}
99+
100+
modelConfig, err := importer.Import(details)
101+
102+
Expect(err).ToNot(HaveOccurred())
103+
Expect(modelConfig.Name).To(Equal("my-diffuser-model"))
104+
Expect(modelConfig.Description).To(Equal("Imported from https://huggingface.co/test/my-diffuser-model"))
105+
Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: diffusers"))
106+
Expect(modelConfig.ConfigFile).To(ContainSubstring("model: https://huggingface.co/test/my-diffuser-model"))
107+
Expect(modelConfig.ConfigFile).To(ContainSubstring("pipeline_type: StableDiffusionPipeline"))
108+
Expect(modelConfig.ConfigFile).To(ContainSubstring("enable_parameters: negative_prompt,num_inference_steps"))
109+
})
110+
111+
It("should import model config with custom name and description from preferences", func() {
112+
preferences := json.RawMessage(`{"name": "custom-diffuser", "description": "Custom diffuser model"}`)
113+
details := Details{
114+
URI: "https://huggingface.co/test/my-model",
115+
Preferences: preferences,
116+
}
117+
118+
modelConfig, err := importer.Import(details)
119+
120+
Expect(err).ToNot(HaveOccurred())
121+
Expect(modelConfig.Name).To(Equal("custom-diffuser"))
122+
Expect(modelConfig.Description).To(Equal("Custom diffuser model"))
123+
})
124+
125+
It("should use custom pipeline_type from preferences", func() {
126+
preferences := json.RawMessage(`{"pipeline_type": "StableDiffusion3Pipeline"}`)
127+
details := Details{
128+
URI: "https://huggingface.co/test/my-model",
129+
Preferences: preferences,
130+
}
131+
132+
modelConfig, err := importer.Import(details)
133+
134+
Expect(err).ToNot(HaveOccurred())
135+
Expect(modelConfig.ConfigFile).To(ContainSubstring("pipeline_type: StableDiffusion3Pipeline"))
136+
})
137+
138+
It("should use default pipeline_type when not specified", func() {
139+
details := Details{
140+
URI: "https://huggingface.co/test/my-model",
141+
}
142+
143+
modelConfig, err := importer.Import(details)
144+
145+
Expect(err).ToNot(HaveOccurred())
146+
Expect(modelConfig.ConfigFile).To(ContainSubstring("pipeline_type: StableDiffusionPipeline"))
147+
})
148+
149+
It("should use custom scheduler_type from preferences", func() {
150+
preferences := json.RawMessage(`{"scheduler_type": "k_dpmpp_2m"}`)
151+
details := Details{
152+
URI: "https://huggingface.co/test/my-model",
153+
Preferences: preferences,
154+
}
155+
156+
modelConfig, err := importer.Import(details)
157+
158+
Expect(err).ToNot(HaveOccurred())
159+
Expect(modelConfig.ConfigFile).To(ContainSubstring("scheduler_type: k_dpmpp_2m"))
160+
})
161+
162+
It("should use cuda setting from preferences", func() {
163+
preferences := json.RawMessage(`{"cuda": true}`)
164+
details := Details{
165+
URI: "https://huggingface.co/test/my-model",
166+
Preferences: preferences,
167+
}
168+
169+
modelConfig, err := importer.Import(details)
170+
171+
Expect(err).ToNot(HaveOccurred())
172+
Expect(modelConfig.ConfigFile).To(ContainSubstring("cuda: true"))
173+
})
174+
175+
It("should use custom enable_parameters from preferences", func() {
176+
preferences := json.RawMessage(`{"enable_parameters": "num_inference_steps,guidance_scale"}`)
177+
details := Details{
178+
URI: "https://huggingface.co/test/my-model",
179+
Preferences: preferences,
180+
}
181+
182+
modelConfig, err := importer.Import(details)
183+
184+
Expect(err).ToNot(HaveOccurred())
185+
Expect(modelConfig.ConfigFile).To(ContainSubstring("enable_parameters: num_inference_steps,guidance_scale"))
186+
})
187+
188+
It("should use custom backend from preferences", func() {
189+
preferences := json.RawMessage(`{"backend": "diffusers"}`)
190+
details := Details{
191+
URI: "https://huggingface.co/test/my-model",
192+
Preferences: preferences,
193+
}
194+
195+
modelConfig, err := importer.Import(details)
196+
197+
Expect(err).ToNot(HaveOccurred())
198+
Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: diffusers"))
199+
})
200+
201+
It("should handle invalid JSON preferences", func() {
202+
preferences := json.RawMessage(`invalid json`)
203+
details := Details{
204+
URI: "https://huggingface.co/test/my-model",
205+
Preferences: preferences,
206+
}
207+
208+
_, err := importer.Import(details)
209+
Expect(err).To(HaveOccurred())
210+
})
211+
212+
It("should extract filename correctly from URI with path", func() {
213+
details := importers.Details{
214+
URI: "https://huggingface.co/test/path/to/model",
215+
}
216+
217+
modelConfig, err := importer.Import(details)
218+
219+
Expect(err).ToNot(HaveOccurred())
220+
Expect(modelConfig.Name).To(Equal("model"))
221+
})
222+
223+
It("should include known_usecases as image in config", func() {
224+
details := Details{
225+
URI: "https://huggingface.co/test/my-model",
226+
}
227+
228+
modelConfig, err := importer.Import(details)
229+
230+
Expect(err).ToNot(HaveOccurred())
231+
Expect(modelConfig.ConfigFile).To(ContainSubstring("known_usecases:"))
232+
Expect(modelConfig.ConfigFile).To(ContainSubstring("- image"))
233+
})
234+
235+
It("should include diffusers configuration in config", func() {
236+
details := Details{
237+
URI: "https://huggingface.co/test/my-model",
238+
}
239+
240+
modelConfig, err := importer.Import(details)
241+
242+
Expect(err).ToNot(HaveOccurred())
243+
Expect(modelConfig.ConfigFile).To(ContainSubstring("diffusers:"))
244+
})
245+
})
246+
})

core/gallery/importers/importers.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ var defaultImporters = []Importer{
2020
&MLXImporter{},
2121
&VLLMImporter{},
2222
&TransformersImporter{},
23+
&DiffuserImporter{},
2324
}
2425

2526
type Details struct {

0 commit comments

Comments
 (0)