Skip to content

Commit 5c37c51

Browse files
committed
Work
1 parent c3f7bcc commit 5c37c51

File tree

5 files changed

+69
-78
lines changed

5 files changed

+69
-78
lines changed

README.md

+3
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,6 @@
44
Based on
55
- [seasonjs/stable-diffusion](https://github.com/seasonjs/stable-diffusion)
66
- [seasonjs/stable-diffusion.cpp-build](https://github.com/seasonjs/stable-diffusion.cpp-build)
7+
- [leejet/stable-diffusion.cpp](https://github.com/leejet/stable-diffusion.cpp)
8+
9+

main.go

+15-13
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,15 @@ package main
33
import (
44
"os"
55

6+
"github.com/davecgh/go-spew/spew"
7+
68
"github.com/ring-c/go-web-diff/pkg/opts"
79
"github.com/ring-c/go-web-diff/pkg/sd"
810
)
911

1012
func main() {
13+
spew.Config.Indent = "\t"
14+
1115
options := sd.DefaultOptions
1216
options.GpuEnable = true
1317
options.Wtype = opts.F16
@@ -23,28 +27,28 @@ func main() {
2327
_ = model.Close()
2428
}()
2529

26-
err = model.LoadFromFile("/media/ed/files/sd/models/Stable-diffusion/dreamshaperXL_v21TurboDPMSDE.safetensors")
30+
// println(model.GetSystemInfo())
31+
32+
// err = generate(model)
2733
if err != nil {
2834
println(err.Error())
2935
return
3036
}
3137

32-
// println(model.GetSystemInfo())
33-
34-
err = generate(model)
38+
err = upscale(model)
3539
if err != nil {
3640
println(err.Error())
3741
return
3842
}
39-
40-
// err = upscale(model)
41-
// if err != nil {
42-
// println(err.Error())
43-
// return
44-
// }
4543
}
4644

4745
func generate(model *sd.Model) (err error) {
46+
err = model.LoadFromFile("/media/ed/files/sd/models/Stable-diffusion/dreamshaperXL_v21TurboDPMSDE.safetensors")
47+
if err != nil {
48+
println(err.Error())
49+
return
50+
}
51+
4852
var file *os.File
4953
file, err = os.Create("./output/0.png")
5054
if err != nil {
@@ -60,7 +64,7 @@ func generate(model *sd.Model) (err error) {
6064
params.Width = 1024
6165
params.Height = 1024
6266
params.CfgScale = 2
63-
params.SampleSteps = 4
67+
params.SampleSteps = 32
6468
params.SampleMethod = opts.EULER_A
6569
params.Seed = 4242
6670

@@ -72,7 +76,6 @@ func generate(model *sd.Model) (err error) {
7276
return
7377
}
7478

75-
/*
7679
func upscale(model *sd.Model) (err error) {
7780
fileRead, err := os.Open("./output/0.png")
7881
if err != nil {
@@ -99,4 +102,3 @@ func upscale(model *sd.Model) (err error) {
99102

100103
return
101104
}
102-
*/

pkg/bind/binding.go

+35-52
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
package bind
22

33
import (
4+
"errors"
5+
"image"
6+
"io"
47
"os"
58
"unsafe"
69

@@ -23,14 +26,7 @@ type cImage struct {
2326
width uint32
2427
height uint32
2528
channel uint32
26-
data unsafe.Pointer
27-
}
28-
29-
type cDarwinImage struct {
30-
width uint32
31-
height uint32
32-
channel uint32
33-
data *byte
29+
data uintptr
3430
}
3531

3632
type Image struct {
@@ -44,17 +40,17 @@ type CStableDiffusionImpl struct {
4440
libSd uintptr
4541
libFilename string
4642

47-
txt2img func(ctx unsafe.Pointer, prompt string, negativePrompt string, clipSkip int, cfgScale float32, width int, height int, sampleMethod int, sampleSteps int, seed int64, batchCount int) unsafe.Pointer
48-
sdGetSystemInfo func() unsafe.Pointer
43+
txt2img func(ctx unsafe.Pointer, prompt string, negativePrompt string, clipSkip int, cfgScale float32, width int, height int, sampleMethod int, sampleSteps int, seed int64, batchCount int) unsafe.Pointer
44+
sdGetSystemInfo func() unsafe.Pointer
45+
sdSetLogCallback func(callback func(level int, text unsafe.Pointer, data unsafe.Pointer) unsafe.Pointer, data unsafe.Pointer)
4946

50-
newSdCtx func(modelPath string, vaePath string, taesdPath string, loraModelDir string, vaeDecodeOnly bool, vaeTiling bool, freeParamsImmediately bool, nThreads int, wType int, rngType int, schedule int) unsafe.Pointer
47+
img2img func(ctx unsafe.Pointer, img uintptr, prompt string, negativePrompt string, clipSkip int, cfgScale float32, width int, height int, sampleMethod int, sampleSteps int, strength float32, seed int64, batchCount int) uintptr
48+
upscale func(ctx unsafe.Pointer, img unsafe.Pointer, upscaleFactor uint32) uintptr
5149

52-
sdSetLogCallback func(callback func(level int, text unsafe.Pointer, data unsafe.Pointer) unsafe.Pointer, data unsafe.Pointer)
53-
img2img func(ctx unsafe.Pointer, img uintptr, prompt string, negativePrompt string, clipSkip int, cfgScale float32, width int, height int, sampleMethod int, sampleSteps int, strength float32, seed int64, batchCount int) uintptr
54-
freeSdCtx func(ctx unsafe.Pointer)
55-
newUpscalerCtx func(esrganPath string, nThreads int, wtype int) unsafe.Pointer
56-
freeUpscalerCtx func(ctx unsafe.Pointer)
57-
upscale func(ctx unsafe.Pointer, img unsafe.Pointer, upscaleFactor uint32) uintptr
50+
newSdCtx func(modelPath string, vaePath string, taesdPath string, loraModelDir string, vaeDecodeOnly bool, vaeTiling bool, freeParamsImmediately bool, nThreads int, wType int, rngType int, schedule int) unsafe.Pointer
51+
freeSdCtx func(ctx unsafe.Pointer)
52+
newUpscalerCtx func(esrganPath string, nThreads int, wtype int) unsafe.Pointer
53+
freeUpscalerCtx func(ctx unsafe.Pointer)
5854
}
5955

6056
func NewCStableDiffusion() (*CStableDiffusionImpl, error) {
@@ -69,15 +65,16 @@ func NewCStableDiffusion() (*CStableDiffusionImpl, error) {
6965
}
7066

7167
purego.RegisterLibFunc(&impl.txt2img, libSd, "txt2img")
72-
7368
purego.RegisterLibFunc(&impl.sdGetSystemInfo, libSd, "sd_get_system_info")
74-
purego.RegisterLibFunc(&impl.newSdCtx, libSd, "new_sd_ctx")
7569
purego.RegisterLibFunc(&impl.sdSetLogCallback, libSd, "sd_set_log_callback")
70+
7671
purego.RegisterLibFunc(&impl.img2img, libSd, "img2img")
72+
purego.RegisterLibFunc(&impl.upscale, libSd, "upscale")
73+
74+
purego.RegisterLibFunc(&impl.newSdCtx, libSd, "new_sd_ctx")
7775
purego.RegisterLibFunc(&impl.freeSdCtx, libSd, "free_sd_ctx")
7876
purego.RegisterLibFunc(&impl.newUpscalerCtx, libSd, "new_upscaler_ctx")
7977
purego.RegisterLibFunc(&impl.freeUpscalerCtx, libSd, "free_upscaler_ctx")
80-
purego.RegisterLibFunc(&impl.upscale, libSd, "upscale")
8178

8279
return &impl, err
8380
}
@@ -129,52 +126,38 @@ func (c *CStableDiffusionImpl) GetSystemInfo() string {
129126
return goString(c.sdGetSystemInfo())
130127
}
131128

132-
/*
133-
134-
func (c *CStableDiffusionImpl) NewUpscalerCtx(esrganPath string, nThreads int, wType opts.WType) *CUpScalerCtx {
135-
ctx := c.newUpscalerCtx(esrganPath, nThreads, int(wType))
136-
137-
return &CUpScalerCtx{ctx: ctx}
138-
}
139-
140-
func (c *CStableDiffusionImpl) FreeUpscalerCtx(ctx *CUpScalerCtx) {
141-
ptr := *(*unsafe.Pointer)(unsafe.Pointer(&ctx.ctx))
142-
if ptr != nil {
143-
c.freeUpscalerCtx(ctx.ctx)
129+
func (c *CStableDiffusionImpl) UpscaleImage(ctx *CUpScalerCtx, reader io.Reader, upscaleFactor uint32) (result Image, err error) {
130+
decode, _, err := image.Decode(reader)
131+
if err != nil {
132+
return
144133
}
145-
ctx = nil
146-
runtime.GC()
147-
}
148-
149-
func (c *CStableDiffusionImpl) UpscaleImage(ctx *CUpScalerCtx, image image.Image, upscaleFactor uint32) Image {
150-
// img := imageToBytes(image)
151-
//
152-
// var ci = cImage{
153-
// width: img.Width,
154-
// height: img.Height,
155-
// channel: img.Channel,
156-
// data: unsafe.Pointer(&img.Data[0]),
157-
// }
158134

159-
println("TEAPOT 1")
135+
var img = imageToBytes(decode)
160136

161-
uPtr := c.upscale(ctx.ctx, nil, upscaleFactor)
137+
var ci = &cImage{
138+
width: img.Width,
139+
height: img.Height,
140+
channel: img.Channel,
141+
data: uintptr(unsafe.Pointer(&img.Data[0])),
142+
}
162143

163-
println("TEAPOT 2")
144+
uPtr := c.upscale(ctx.ctx, unsafe.Pointer(&ci), upscaleFactor)
164145

165146
ptr := *(*unsafe.Pointer)(unsafe.Pointer(&uPtr))
166147
if ptr == nil {
167-
return Image{}
148+
err = errors.New("nil pointer")
149+
return
168150
}
169-
println("TEAPOT 3")
170151

171152
cimg := (*cImage)(ptr)
172153
dataPtr := *(*unsafe.Pointer)(unsafe.Pointer(&cimg.data))
173-
return Image{
154+
155+
result = Image{
174156
Width: cimg.width,
175157
Height: cimg.height,
176158
Channel: cimg.channel,
177159
Data: unsafe.Slice((*byte)(dataPtr), cimg.channel*cimg.width*cimg.height),
178160
}
161+
162+
return
179163
}
180-
*/

pkg/bind/context.go

+15
Original file line numberDiff line numberDiff line change
@@ -22,3 +22,18 @@ func (c *CStableDiffusionImpl) FreeCtx(ctx *CStableDiffusionCtx) {
2222
ctx = nil
2323
runtime.GC()
2424
}
25+
26+
func (c *CStableDiffusionImpl) NewUpscalerCtx(esrganPath string, nThreads int, wType opts.WType) *CUpScalerCtx {
27+
ctx := c.newUpscalerCtx(esrganPath, nThreads, int(wType))
28+
29+
return &CUpScalerCtx{ctx: ctx}
30+
}
31+
32+
func (c *CStableDiffusionImpl) FreeUpscalerCtx(ctx *CUpScalerCtx) {
33+
ptr := *(*unsafe.Pointer)(unsafe.Pointer(&ctx.ctx))
34+
if ptr != nil {
35+
c.freeUpscalerCtx(ctx.ctx)
36+
}
37+
ctx = nil
38+
runtime.GC()
39+
}

pkg/sd/model.go

+1-13
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,6 @@ func imageToWriter(image image.Image, imageType opts.OutputsImageType, writer io
186186
return
187187
}
188188

189-
/*
190189
func (sd *Model) UpscaleImage(reader io.Reader, esrganPath string, upscaleFactor uint32, writer io.Writer) (err error) {
191190
if sd.upscalerCtx == nil {
192191
sd.esrganPath = esrganPath
@@ -200,28 +199,17 @@ func (sd *Model) UpscaleImage(reader io.Reader, esrganPath string, upscaleFactor
200199
sd.upscalerCtx = sd.cSD.NewUpscalerCtx(esrganPath, sd.options.Threads, sd.options.Wtype)
201200
}
202201

203-
decode, _, err := image.Decode(reader)
202+
img, err := sd.cSD.UpscaleImage(sd.upscalerCtx, reader, upscaleFactor)
204203
if err != nil {
205204
return
206205
}
207206

208-
println("UPSCALE")
209-
210-
img := sd.cSD.UpscaleImage(sd.upscalerCtx, decode, upscaleFactor)
211-
212-
spew.Dump(img)
213-
214-
println("BYTES")
215-
216207
outputsImage := bytesToImage(img.Data, int(img.Width), int(img.Height))
217208

218-
println("WRITE")
219-
220209
err = imageToWriter(outputsImage, opts.PNG, writer)
221210
if err != nil {
222211
return
223212
}
224213

225214
return
226215
}
227-
*/

0 commit comments

Comments
 (0)