1
1
package bind
2
2
3
3
import (
4
+ "errors"
5
+ "image"
6
+ "io"
4
7
"os"
5
8
"unsafe"
6
9
@@ -23,14 +26,7 @@ type cImage struct {
23
26
width uint32
24
27
height uint32
25
28
channel uint32
26
- data unsafe.Pointer
27
- }
28
-
29
- type cDarwinImage struct {
30
- width uint32
31
- height uint32
32
- channel uint32
33
- data * byte
29
+ data uintptr
34
30
}
35
31
36
32
type Image struct {
@@ -44,17 +40,17 @@ type CStableDiffusionImpl struct {
44
40
libSd uintptr
45
41
libFilename string
46
42
47
- txt2img func (ctx unsafe.Pointer , prompt string , negativePrompt string , clipSkip int , cfgScale float32 , width int , height int , sampleMethod int , sampleSteps int , seed int64 , batchCount int ) unsafe.Pointer
48
- sdGetSystemInfo func () unsafe.Pointer
43
+ txt2img func (ctx unsafe.Pointer , prompt string , negativePrompt string , clipSkip int , cfgScale float32 , width int , height int , sampleMethod int , sampleSteps int , seed int64 , batchCount int ) unsafe.Pointer
44
+ sdGetSystemInfo func () unsafe.Pointer
45
+ sdSetLogCallback func (callback func (level int , text unsafe.Pointer , data unsafe.Pointer ) unsafe.Pointer , data unsafe.Pointer )
49
46
50
- newSdCtx func (modelPath string , vaePath string , taesdPath string , loraModelDir string , vaeDecodeOnly bool , vaeTiling bool , freeParamsImmediately bool , nThreads int , wType int , rngType int , schedule int ) unsafe.Pointer
47
+ img2img func (ctx unsafe.Pointer , img uintptr , prompt string , negativePrompt string , clipSkip int , cfgScale float32 , width int , height int , sampleMethod int , sampleSteps int , strength float32 , seed int64 , batchCount int ) uintptr
48
+ upscale func (ctx unsafe.Pointer , img unsafe.Pointer , upscaleFactor uint32 ) uintptr
51
49
52
- sdSetLogCallback func (callback func (level int , text unsafe.Pointer , data unsafe.Pointer ) unsafe.Pointer , data unsafe.Pointer )
53
- img2img func (ctx unsafe.Pointer , img uintptr , prompt string , negativePrompt string , clipSkip int , cfgScale float32 , width int , height int , sampleMethod int , sampleSteps int , strength float32 , seed int64 , batchCount int ) uintptr
54
- freeSdCtx func (ctx unsafe.Pointer )
55
- newUpscalerCtx func (esrganPath string , nThreads int , wtype int ) unsafe.Pointer
56
- freeUpscalerCtx func (ctx unsafe.Pointer )
57
- upscale func (ctx unsafe.Pointer , img unsafe.Pointer , upscaleFactor uint32 ) uintptr
50
+ newSdCtx func (modelPath string , vaePath string , taesdPath string , loraModelDir string , vaeDecodeOnly bool , vaeTiling bool , freeParamsImmediately bool , nThreads int , wType int , rngType int , schedule int ) unsafe.Pointer
51
+ freeSdCtx func (ctx unsafe.Pointer )
52
+ newUpscalerCtx func (esrganPath string , nThreads int , wtype int ) unsafe.Pointer
53
+ freeUpscalerCtx func (ctx unsafe.Pointer )
58
54
}
59
55
60
56
func NewCStableDiffusion () (* CStableDiffusionImpl , error ) {
@@ -69,15 +65,16 @@ func NewCStableDiffusion() (*CStableDiffusionImpl, error) {
69
65
}
70
66
71
67
purego .RegisterLibFunc (& impl .txt2img , libSd , "txt2img" )
72
-
73
68
purego .RegisterLibFunc (& impl .sdGetSystemInfo , libSd , "sd_get_system_info" )
74
- purego .RegisterLibFunc (& impl .newSdCtx , libSd , "new_sd_ctx" )
75
69
purego .RegisterLibFunc (& impl .sdSetLogCallback , libSd , "sd_set_log_callback" )
70
+
76
71
purego .RegisterLibFunc (& impl .img2img , libSd , "img2img" )
72
+ purego .RegisterLibFunc (& impl .upscale , libSd , "upscale" )
73
+
74
+ purego .RegisterLibFunc (& impl .newSdCtx , libSd , "new_sd_ctx" )
77
75
purego .RegisterLibFunc (& impl .freeSdCtx , libSd , "free_sd_ctx" )
78
76
purego .RegisterLibFunc (& impl .newUpscalerCtx , libSd , "new_upscaler_ctx" )
79
77
purego .RegisterLibFunc (& impl .freeUpscalerCtx , libSd , "free_upscaler_ctx" )
80
- purego .RegisterLibFunc (& impl .upscale , libSd , "upscale" )
81
78
82
79
return & impl , err
83
80
}
@@ -129,52 +126,38 @@ func (c *CStableDiffusionImpl) GetSystemInfo() string {
129
126
return goString (c .sdGetSystemInfo ())
130
127
}
131
128
132
- /*
133
-
134
- func (c *CStableDiffusionImpl) NewUpscalerCtx(esrganPath string, nThreads int, wType opts.WType) *CUpScalerCtx {
135
- ctx := c.newUpscalerCtx(esrganPath, nThreads, int(wType))
136
-
137
- return &CUpScalerCtx{ctx: ctx}
138
- }
139
-
140
- func (c *CStableDiffusionImpl) FreeUpscalerCtx(ctx *CUpScalerCtx) {
141
- ptr := *(*unsafe.Pointer)(unsafe.Pointer(&ctx.ctx))
142
- if ptr != nil {
143
- c.freeUpscalerCtx(ctx.ctx)
129
+ func (c * CStableDiffusionImpl ) UpscaleImage (ctx * CUpScalerCtx , reader io.Reader , upscaleFactor uint32 ) (result Image , err error ) {
130
+ decode , _ , err := image .Decode (reader )
131
+ if err != nil {
132
+ return
144
133
}
145
- ctx = nil
146
- runtime.GC()
147
- }
148
-
149
- func (c *CStableDiffusionImpl) UpscaleImage(ctx *CUpScalerCtx, image image.Image, upscaleFactor uint32) Image {
150
- // img := imageToBytes(image)
151
- //
152
- // var ci = cImage{
153
- // width: img.Width,
154
- // height: img.Height,
155
- // channel: img.Channel,
156
- // data: unsafe.Pointer(&img.Data[0]),
157
- // }
158
134
159
- println("TEAPOT 1" )
135
+ var img = imageToBytes ( decode )
160
136
161
- uPtr := c.upscale(ctx.ctx, nil, upscaleFactor)
137
+ var ci = & cImage {
138
+ width : img .Width ,
139
+ height : img .Height ,
140
+ channel : img .Channel ,
141
+ data : uintptr (unsafe .Pointer (& img .Data [0 ])),
142
+ }
162
143
163
- println("TEAPOT 2" )
144
+ uPtr := c . upscale ( ctx . ctx , unsafe . Pointer ( & ci ), upscaleFactor )
164
145
165
146
ptr := * (* unsafe .Pointer )(unsafe .Pointer (& uPtr ))
166
147
if ptr == nil {
167
- return Image{}
148
+ err = errors .New ("nil pointer" )
149
+ return
168
150
}
169
- println("TEAPOT 3")
170
151
171
152
cimg := (* cImage )(ptr )
172
153
dataPtr := * (* unsafe .Pointer )(unsafe .Pointer (& cimg .data ))
173
- return Image{
154
+
155
+ result = Image {
174
156
Width : cimg .width ,
175
157
Height : cimg .height ,
176
158
Channel : cimg .channel ,
177
159
Data : unsafe .Slice ((* byte )(dataPtr ), cimg .channel * cimg .width * cimg .height ),
178
160
}
161
+
162
+ return
179
163
}
180
- */
0 commit comments