Skip to content

Commit a3e93b6

Browse files
committed
feat(inpainting): return DB record to frontend, add compatibility alias and add CUDA/error diagnostics in Python script
Signed-off-by: Greg <[email protected]>
1 parent 8aba078 commit a3e93b6

File tree

6 files changed

+341
-3
lines changed

6 files changed

+341
-3
lines changed

core/backend/image.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,3 +40,7 @@ func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negat
4040

4141
return fn, nil
4242
}
43+
44+
// ImageGenerationFunc is a test-friendly indirection to call image generation logic.
45+
// Tests can override this variable to provide a stub implementation.
46+
var ImageGenerationFunc = ImageGeneration
Lines changed: 181 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,181 @@
1+
package openai
2+
3+
import (
4+
"encoding/base64"
5+
"encoding/json"
6+
"fmt"
7+
"io"
8+
"net/http"
9+
"os"
10+
"path/filepath"
11+
"strconv"
12+
"time"
13+
14+
"github.com/google/uuid"
15+
"github.com/labstack/echo/v4"
16+
"github.com/rs/zerolog/log"
17+
18+
"github.com/mudler/LocalAI/core/config"
19+
"github.com/mudler/LocalAI/core/http/middleware"
20+
"github.com/mudler/LocalAI/core/schema"
21+
"github.com/mudler/LocalAI/core/backend"
22+
model "github.com/mudler/LocalAI/pkg/model"
23+
)
24+
25+
// InpaintingEndpoint handles POST /v1/images/inpainting
26+
//
27+
// Swagger / OpenAPI docstring (swaggo):
28+
// @Summary Image inpainting
29+
// @Description Perform image inpainting. Accepts multipart/form-data with `image` and `mask` files.
30+
// @Tags images
31+
// @Accept multipart/form-data
32+
// @Produce application/json
33+
// @Param model formData string true "Model identifier"
34+
// @Param prompt formData string true "Text prompt guiding the generation"
35+
// @Param steps formData int false "Number of inference steps (default 25)"
36+
// @Param image formData file true "Original image file"
37+
// @Param mask formData file true "Mask image file (white = area to inpaint)"
38+
// @Success 200 {object} schema.OpenAIResponse
39+
// @Failure 400 {object} map[string]string
40+
// @Failure 500 {object} map[string]string
41+
// @Router /v1/images/inpainting [post]
42+
func InpaintingEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc {
43+
return func(c echo.Context) error {
44+
// Parse basic form values
45+
modelName := c.FormValue("model")
46+
prompt := c.FormValue("prompt")
47+
stepsStr := c.FormValue("steps")
48+
49+
if modelName == "" || prompt == "" {
50+
log.Error().Msg("Inpainting Endpoint - missing model or prompt")
51+
return echo.ErrBadRequest
52+
}
53+
54+
// steps default
55+
steps := 25
56+
if stepsStr != "" {
57+
if v, err := strconv.Atoi(stepsStr); err == nil {
58+
steps = v
59+
}
60+
}
61+
62+
// Get uploaded files
63+
imageFile, err := c.FormFile("image")
64+
if err != nil {
65+
log.Error().Err(err).Msg("Inpainting Endpoint - missing image file")
66+
return echo.NewHTTPError(http.StatusBadRequest, "missing image file")
67+
}
68+
maskFile, err := c.FormFile("mask")
69+
if err != nil {
70+
log.Error().Err(err).Msg("Inpainting Endpoint - missing mask file")
71+
return echo.NewHTTPError(http.StatusBadRequest, "missing mask file")
72+
}
73+
74+
// Read files into memory (small files expected)
75+
imgSrc, err := imageFile.Open()
76+
if err != nil {
77+
return err
78+
}
79+
defer imgSrc.Close()
80+
imgBytes, err := io.ReadAll(imgSrc)
81+
if err != nil {
82+
return err
83+
}
84+
85+
maskSrc, err := maskFile.Open()
86+
if err != nil {
87+
return err
88+
}
89+
defer maskSrc.Close()
90+
maskBytes, err := io.ReadAll(maskSrc)
91+
if err != nil {
92+
return err
93+
}
94+
95+
// Create JSON with base64 fields expected by backend
96+
b64Image := base64.StdEncoding.EncodeToString(imgBytes)
97+
b64Mask := base64.StdEncoding.EncodeToString(maskBytes)
98+
99+
// get model config from context (middleware set it)
100+
cfg, ok := c.Get("MODEL_CONFIG").(*config.ModelConfig)
101+
if !ok || cfg == nil {
102+
log.Error().Msg("Inpainting Endpoint - model config not found in context")
103+
return echo.ErrBadRequest
104+
}
105+
106+
tmpDir := appConfig.GeneratedContentDir
107+
id := uuid.New().String()
108+
jsonName := fmt.Sprintf("inpaint_%s.json", id)
109+
jsonPath := filepath.Join(tmpDir, jsonName)
110+
jsonFile := map[string]string{
111+
"image": b64Image,
112+
"mask_image": b64Mask,
113+
}
114+
jf, err := os.CreateTemp(tmpDir, "inpaint_")
115+
if err != nil {
116+
return err
117+
}
118+
// write JSON
119+
enc := json.NewEncoder(jf)
120+
if err := enc.Encode(jsonFile); err != nil {
121+
jf.Close()
122+
os.Remove(jf.Name())
123+
return err
124+
}
125+
jf.Close()
126+
// rename to desired name
127+
if err := os.Rename(jf.Name(), jsonPath); err != nil {
128+
os.Remove(jf.Name())
129+
return err
130+
}
131+
// prepare dst
132+
outTmp, err := os.CreateTemp(tmpDir, "out_")
133+
if err != nil {
134+
os.Remove(jsonPath)
135+
return err
136+
}
137+
outTmp.Close()
138+
dst := outTmp.Name() + ".png"
139+
if err := os.Rename(outTmp.Name(), dst); err != nil {
140+
os.Remove(jsonPath)
141+
return err
142+
}
143+
144+
// Determine width/height default
145+
width := 512
146+
height := 512
147+
148+
// Call backend image generation via indirection so tests can stub it
149+
// Note: ImageGenerationFunc will call into the loaded model's GenerateImage which expects src JSON
150+
fn, err := backend.ImageGenerationFunc(height, width, 0, steps, 0, prompt, "", jsonPath, dst, ml, *cfg, appConfig, nil)
151+
if err != nil {
152+
os.Remove(jsonPath)
153+
return err
154+
}
155+
156+
// Execute generation function (blocking)
157+
if err := fn(); err != nil {
158+
os.Remove(jsonPath)
159+
os.Remove(dst)
160+
return err
161+
}
162+
163+
// On success, build response URL using BaseURL middleware helper
164+
baseURL := middleware.BaseURL(c)
165+
166+
// Return response
167+
created := int(time.Now().Unix())
168+
resp := &schema.OpenAIResponse{
169+
ID: id,
170+
Created: created,
171+
Data: []schema.Item{{
172+
URL: fmt.Sprintf("%sgenerated-images/%s", baseURL, filepath.Base(dst)),
173+
}},
174+
}
175+
176+
// cleanup json
177+
defer os.Remove(jsonPath)
178+
179+
return c.JSON(http.StatusOK, resp)
180+
}
181+
}
Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
package openai
2+
3+
import (
4+
"bytes"
5+
"mime/multipart"
6+
"net/http"
7+
"net/http/httptest"
8+
"os"
9+
"path/filepath"
10+
"testing"
11+
12+
"github.com/labstack/echo/v4"
13+
"github.com/mudler/LocalAI/core/config"
14+
"github.com/mudler/LocalAI/core/backend"
15+
model "github.com/mudler/LocalAI/pkg/model"
16+
"github.com/stretchr/testify/require"
17+
)
18+
19+
func makeMultipartRequest(t *testing.T, fields map[string]string, files map[string][]byte) (*http.Request, string) {
20+
b := &bytes.Buffer{}
21+
w := multipart.NewWriter(b)
22+
for k, v := range fields {
23+
_ = w.WriteField(k, v)
24+
}
25+
for fname, content := range files {
26+
fw, err := w.CreateFormFile(fname, fname+".png")
27+
require.NoError(t, err)
28+
_, err = fw.Write(content)
29+
require.NoError(t, err)
30+
}
31+
require.NoError(t, w.Close())
32+
req := httptest.NewRequest(http.MethodPost, "/v1/images/inpainting", b)
33+
req.Header.Set("Content-Type", w.FormDataContentType())
34+
return req, w.FormDataContentType()
35+
}
36+
37+
func TestInpainting_MissingFiles(t *testing.T) {
38+
e := echo.New()
39+
// handler requires cl, ml, appConfig but this test verifies missing files early
40+
h := InpaintingEndpoint(nil, nil, config.NewApplicationConfig())
41+
42+
req := httptest.NewRequest(http.MethodPost, "/v1/images/inpainting", nil)
43+
rec := httptest.NewRecorder()
44+
c := e.NewContext(req, rec)
45+
46+
err := h(c)
47+
require.Error(t, err)
48+
}
49+
50+
func TestInpainting_HappyPath(t *testing.T) {
51+
// Setup temp generated content dir
52+
tmpDir, err := os.MkdirTemp("", "gencontent")
53+
require.NoError(t, err)
54+
defer os.RemoveAll(tmpDir)
55+
56+
appConf := config.NewApplicationConfig(config.WithGeneratedContentDir(tmpDir))
57+
58+
// stub the backend.ImageGenerationFunc
59+
orig := backend.ImageGenerationFunc
60+
backend.ImageGenerationFunc = func(height, width, mode, step, seed int, positive_prompt, negative_prompt, src, dst string, loader *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig, refImages []string) (func() error, error) {
61+
fn := func() error {
62+
// write a fake png file to dst
63+
return os.WriteFile(dst, []byte("PNGDATA"), 0644)
64+
}
65+
return fn, nil
66+
}
67+
defer func() { backend.ImageGenerationFunc = orig }()
68+
69+
// prepare multipart request with image and mask
70+
fields := map[string]string{"model": "dreamshaper-8-inpainting", "prompt": "A test"}
71+
files := map[string][]byte{"image": []byte("IMAGEDATA"), "mask": []byte("MASKDATA")}
72+
reqBuf, _ := makeMultipartRequest(t, fields, files)
73+
74+
rec := httptest.NewRecorder()
75+
e := echo.New()
76+
c := e.NewContext(reqBuf, rec)
77+
78+
// set a minimal model config in context as handler expects
79+
c.Set("MODEL_CONFIG", &config.ModelConfig{Backend: "diffusers"})
80+
81+
h := InpaintingEndpoint(nil, nil, appConf)
82+
83+
// call handler
84+
err = h(c)
85+
require.NoError(t, err)
86+
require.Equal(t, http.StatusOK, rec.Code)
87+
88+
// verify response body contains generated-images path
89+
body := rec.Body.String()
90+
require.Contains(t, body, "generated-images")
91+
92+
// confirm the file was created in tmpDir
93+
// parse out filename from response (naive search)
94+
// find "generated-images/" and extract until closing quote or brace
95+
idx := bytes.Index(rec.Body.Bytes(), []byte("generated-images/"))
96+
require.True(t, idx >= 0)
97+
rest := rec.Body.Bytes()[idx:]
98+
end := bytes.IndexAny(rest, "\",}\n")
99+
if end == -1 {
100+
end = len(rest)
101+
}
102+
fname := string(rest[len("generated-images/"):end])
103+
// ensure file exists
104+
_, err = os.Stat(filepath.Join(tmpDir, fname))
105+
require.NoError(t, err)
106+
}

core/http/endpoints/openai/mcp.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -108,11 +108,11 @@ func MCPCompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader,
108108
log.Debug().Msgf("[model agent] [model: %s] Reasoning: %s", config.Name, s)
109109
}),
110110
cogito.WithToolCallBack(func(t *cogito.ToolChoice) bool {
111-
log.Debug().Msgf("[model agent] [model: %s] Tool call: %s, reasoning: %s, arguments: %+v", t.Name, t.Reasoning, t.Arguments)
111+
log.Debug().Msgf("[model agent] [model: %s] Tool call: %s, reasoning: %s, arguments: %+v", config.Name, t.Name, t.Reasoning, t.Arguments)
112112
return true
113113
}),
114114
cogito.WithToolCallResultCallback(func(t cogito.ToolStatus) {
115-
log.Debug().Msgf("[model agent] [model: %s] Tool call result: %s, tool arguments: %+v", t.Name, t.Result, t.ToolArguments)
115+
log.Debug().Msgf("[model agent] [model: %s] Tool call result: %s, tool arguments: %+v", config.Name, t.Result, t.ToolArguments)
116116
}),
117117
)
118118

core/http/routes/openai.go

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,8 @@ func RegisterOpenAIRoutes(app *echo.Echo,
140140
// images
141141
imageHandler := openai.ImageEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig())
142142
imageMiddleware := []echo.MiddlewareFunc{
143-
re.BuildConstantDefaultModelNameMiddleware("stablediffusion"),
143+
// Par défaut, utiliser le modèle d'inpainting souhaité pour les endpoints images/inpainting
144+
re.BuildConstantDefaultModelNameMiddleware("dreamshaper-8-inpainting"),
144145
re.SetModelAndConfig(func() schema.LocalAIRequest { return new(schema.OpenAIRequest) }),
145146
func(next echo.HandlerFunc) echo.HandlerFunc {
146147
return func(c echo.Context) error {
@@ -155,6 +156,11 @@ func RegisterOpenAIRoutes(app *echo.Echo,
155156
app.POST("/v1/images/generations", imageHandler, imageMiddleware...)
156157
app.POST("/images/generations", imageHandler, imageMiddleware...)
157158

159+
// inpainting endpoint (image + mask) - reuse same middleware config as images
160+
inpaintingHandler := openai.InpaintingEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig())
161+
app.POST("/v1/images/inpainting", inpaintingHandler, imageMiddleware...)
162+
app.POST("/images/inpainting", inpaintingHandler, imageMiddleware...)
163+
158164
// videos (OpenAI-compatible endpoints mapped to LocalAI video handler)
159165
videoHandler := openai.VideoEndpoint(application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig())
160166
videoMiddleware := []echo.MiddlewareFunc{

swagger/swagger.yaml

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1197,6 +1197,47 @@ paths:
11971197
schema:
11981198
$ref: '#/definitions/schema.OpenAIResponse'
11991199
summary: Creates an image given a prompt.
1200+
/v1/images/inpainting:
1201+
post:
1202+
consumes:
1203+
- multipart/form-data
1204+
- application/json
1205+
parameters:
1206+
- in: formData
1207+
name: model
1208+
type: string
1209+
description: Model name (eg. dreamshaper-8-inpainting)
1210+
required: true
1211+
- in: formData
1212+
name: prompt
1213+
type: string
1214+
description: Positive prompt text
1215+
required: true
1216+
- in: formData
1217+
name: image
1218+
type: file
1219+
description: Source image (PNG/JPEG)
1220+
required: true
1221+
- in: formData
1222+
name: mask
1223+
type: file
1224+
description: Mask image (PNG). White=keep, Black=replace (or as backend expects)
1225+
required: true
1226+
- in: formData
1227+
name: steps
1228+
type: integer
1229+
description: Number of inference steps
1230+
- in: body
1231+
name: request
1232+
description: "Alternative JSON payload with base64 fields: { image: '<b64>', mask: '<b64>', model, prompt }"
1233+
schema:
1234+
$ref: '#/definitions/schema.OpenAIRequest'
1235+
responses:
1236+
"200":
1237+
description: Successful inpainting
1238+
schema:
1239+
$ref: '#/definitions/schema.OpenAIResponse'
1240+
summary: Creates an inpainted image given an image + mask + prompt.
12001241
/v1/mcp/chat/completions:
12011242
post:
12021243
parameters:

0 commit comments

Comments
 (0)