Skip to content

Commit 7b1c00a

Browse files
committed
Work
Former-commit-id: f43b807
1 parent 48e43ac commit 7b1c00a

File tree

8 files changed

+25
-99
lines changed

8 files changed

+25
-99
lines changed

pkg/bind/binding.go

-2
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,6 @@ func NewCStableDiffusion() (*CStableDiffusionImpl, error) {
5353
libFilename: filename,
5454
}
5555

56-
// purego.RegisterLibFunc(&impl.img2img, libSd, "img2img")
57-
5856
purego.RegisterLibFunc(&impl.newSDContext, libSd, "new_sd_ctx_go")
5957
purego.RegisterLibFunc(&impl.freeSDContext, libSd, "free_sd_ctx")
6058

Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
0daa5df21052eb6fb6e600a58dd5f3a7d979ad87
1+
9894f441c6cec8a914fef907c4ce946c40aa2e12

pkg/ggml/struct.go

+1
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ type Struct struct {
1818
TensorClamp func(tx unsafe.Pointer, min float32, max float32)
1919

2020
TensorGetF32 func(tx unsafe.Pointer, l, k, j, i int) float32
21+
// VectorToGgmlTensorI32 func(workCtx unsafe.Pointer, vector unsafe.Pointer) unsafe.Pointer
2122
}
2223

2324
type InitParams struct {

pkg/txt2img/GetLearnedCondition.go

+8-72
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ package txt2img
44

55
import (
66
"math"
7-
"time"
87
"unsafe"
98
)
109

@@ -14,82 +13,19 @@ type ggmlTensor struct {
1413
}
1514

1615
var (
17-
n_threads = 4
18-
diffusion_model = &DiffusionModel{}
16+
nThreads = 4
1917
)
2018

21-
func ggmlTimeMs() int64 {
22-
return time.Now().UnixNano() / int64(time.Millisecond)
23-
}
24-
25-
func vectorToGgmlTensorI32(workCtx unsafe.Pointer, vec []int) *ggmlTensor {
26-
panic("fix me")
27-
return &ggmlTensor{}
28-
}
29-
30-
func ggmlDupTensor(workCtx unsafe.Pointer, tensor *ggmlTensor) *ggmlTensor {
31-
panic("fix me")
32-
return &ggmlTensor{}
33-
}
34-
35-
func ggmlTensorMean(tensor *ggmlTensor) float32 {
36-
panic("fix me")
37-
return 0.0
38-
}
39-
40-
func ggmlTensorGetF32(tensor *ggmlTensor, i0, i1, i2 int) float32 {
41-
panic("fix me")
42-
return 0.0
43-
}
44-
45-
func ggmlTensorSetF32(tensor *ggmlTensor, value float32, i0, i1, i2 int) {
46-
panic("fix me")
47-
}
48-
49-
func ggmlTensorScale(tensor *ggmlTensor, scale float32) {
50-
panic("fix me")
51-
}
52-
53-
func ggmlNelements(tensor *ggmlTensor) int {
54-
panic("fix me")
55-
return 0
56-
}
57-
58-
func vectorToGgmlTensor(workCtx unsafe.Pointer, vec []float32) *ggmlTensor {
59-
panic("fix me")
60-
return &ggmlTensor{}
61-
}
62-
63-
func ggmlReshape2D(workCtx unsafe.Pointer, tensor *ggmlTensor, dim0, dim1 int) *ggmlTensor {
64-
panic("fix me")
65-
return &ggmlTensor{}
66-
}
67-
68-
func ggmlNewTensor1D(workCtx unsafe.Pointer, tensorType int, size int) *ggmlTensor {
69-
panic("fix me")
70-
return &ggmlTensor{}
71-
}
72-
73-
func ggmlView2D(workCtx unsafe.Pointer, tensor *ggmlTensor, dim0, dim1, stride, offset int) *ggmlTensor {
74-
panic("fix me")
75-
return &ggmlTensor{}
76-
}
77-
78-
func ggmlNbytes(tensor *ggmlTensor) int {
79-
panic("fix me")
80-
return 0
81-
}
82-
8319
func setTimestepEmbedding(timesteps []float32, embedView *ggmlTensor, outDim int) {
8420
panic("fix me")
8521
}
8622

87-
func getLearnedCondition(workCtx unsafe.Pointer, text string, clipSkip, width, height int, forceZeroEmbeddings bool) (*ggmlTensor, *ggmlTensor) {
88-
tokens, weights := tokenize(text, true)
89-
return getLearnedConditionCommon(workCtx, tokens, weights, clipSkip, width, height, forceZeroEmbeddings)
23+
func (gen *Generator) GetLearnedCondition(workCtx unsafe.Pointer, prompt string, clipSkip, width, height int, forceZeroEmbeddings bool) (unsafe.Pointer, unsafe.Pointer) {
24+
tokens, weights := gen.Tokenize(prompt, true)
25+
return gen.getLearnedConditionCommon(workCtx, tokens, weights, clipSkip, width, height, forceZeroEmbeddings)
9026
}
9127

92-
func getLearnedConditionCommon(workCtx unsafe.Pointer, tokens []int, weights []float32, clipSkip, width, height int, forceZeroEmbeddings bool) (*ggmlTensor, *ggmlTensor) {
28+
func (gen *Generator) getLearnedConditionCommon(workCtx unsafe.Pointer, tokens []int, weights []float32, clipSkip, width, height int, forceZeroEmbeddings bool) (unsafe.Pointer, unsafe.Pointer) {
9329
cond_stage_model.setClipSkip(clipSkip)
9430
var hiddenStates, chunkHiddenStates, pooled *ggmlTensor
9531
var hiddenStatesVec []float32
@@ -100,7 +36,7 @@ func getLearnedConditionCommon(workCtx unsafe.Pointer, tokens []int, weights []f
10036
chunkTokens := tokens[chunkIdx*chunkLen : (chunkIdx+1)*chunkLen]
10137
chunkWeights := weights[chunkIdx*chunkLen : (chunkIdx+1)*chunkLen]
10238

103-
inputIds := vectorToGgmlTensorI32(workCtx, chunkTokens)
39+
inputIds := gen.GGML.VectorToGgmlTensorI32(workCtx, chunkTokens)
10440
var inputIds2 *ggmlTensor
10541
var maxTokenIdx int
10642
// if version == VERSION_XL {
@@ -119,9 +55,9 @@ func getLearnedConditionCommon(workCtx unsafe.Pointer, tokens []int, weights []f
11955
inputIds2 = vectorToGgmlTensorI32(workCtx, chunkTokens)
12056
// }
12157

122-
cond_stage_model.compute(n_threads, inputIds, inputIds2, maxTokenIdx, false, &chunkHiddenStates, workCtx)
58+
cond_stage_model.compute(nThreads, inputIds, inputIds2, maxTokenIdx, false, &chunkHiddenStates, workCtx)
12359
if chunkIdx == 0 {
124-
cond_stage_model.compute(n_threads, inputIds, inputIds2, maxTokenIdx, true, &pooled, workCtx)
60+
cond_stage_model.compute(nThreads, inputIds, inputIds2, maxTokenIdx, true, &pooled, workCtx)
12561
}
12662

12763
result := ggmlDupTensor(workCtx, chunkHiddenStates)

pkg/txt2img/SampleGo.go

-7
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,6 @@ type DiffusionModel struct {
1919
}
2020
}
2121

22-
const (
23-
UNK_TOKEN_ID int = 49407
24-
BOS_TOKEN_ID int = 49406
25-
EOS_TOKEN_ID int = 49407
26-
PAD_TOKEN_ID int = 49407
27-
)
28-
2922
func (s *DiffusionModel) FreeParamsBuffer() {
3023
// Implement FreeParamsBuffer method
3124
}

pkg/txt2img/generate.go

+1-9
Original file line numberDiff line numberDiff line change
@@ -78,15 +78,7 @@ func (gen *Generator) Generate(in *opts.Options) (filenames []string, err error)
7878
xT := gen.GGML.NewTensor4D(workCtx, 0, W, H, 4, 1)
7979
gen.GGML.TensorSetF32Rand(xT, seed)
8080

81-
/*
82-
var pairCond = gen.GetLearnedCondition(gen.Model.GetCTX(), workCtx, in.Prompt, in.Width, in.Height, in.ClipSkip)
83-
84-
var c = gen.PairGet(pairCond, true)
85-
var cVector = gen.PairGet(pairCond, false)
86-
87-
spew.Dump(c)
88-
spew.Dump(cVector)
89-
*/
81+
// gen.SetLearnedCondition(gen.Model.GetCTX(), workCtx, in.Prompt, in.Width, in.Height, in.ClipSkip)
9082

9183
if in.Debug {
9284
fmt.Printf("[%d/%d] Prep done in %gs\n", i+1, in.BatchCount, time.Now().Sub(timeStart).Seconds())

pkg/txt2img/init.go

+6-7
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,8 @@ type Generator struct {
2121
fileWrite sync.WaitGroup
2222
filenames []string
2323

24-
GetLearnedCondition func(sdCTX, ggmlCTX unsafe.Pointer, prompt string, width, height, clipSkip int) unsafe.Pointer // pair
25-
PairGet func(pair unsafe.Pointer, first bool) unsafe.Pointer // ggml_tensor
26-
GoSample func(sdCTX, ggmlCTX, xT unsafe.Pointer, prompt string, sigmasCnt int, sigmas []float32) unsafe.Pointer // ggml_tensor
24+
// SetLearnedCondition func(sdCTX, ggmlCTX unsafe.Pointer, prompt string, width, height, clipSkip int) // pair
25+
GoSample func(sdCTX, ggmlCTX, xT unsafe.Pointer, prompt string, sigmasCnt int, sigmas []float32) unsafe.Pointer // ggml_tensor
2726

2827
DecodeFirstStage func(sdCTX, ggmlCTX, inputTX, outputTX unsafe.Pointer)
2928

@@ -63,13 +62,13 @@ func New(in *opts.Options) (*Generator, error) {
6362
purego.RegisterLibFunc(&impl.GGML.TensorClamp, libSd, "go_ggml_tensor_clamp")
6463

6564
purego.RegisterLibFunc(&impl.GGML.TensorGetF32, libSd, "go_ggml_tensor_get_f32")
65+
// purego.RegisterLibFunc(&impl.GGML.VectorToGgmlTensorI32, libSd, "go_vector_to_ggml_tensor_i32")
66+
67+
// purego.RegisterLibFunc(&impl.SetLearnedCondition, libSd, "go_set_learned_condition")
68+
purego.RegisterLibFunc(&impl.ApplyLora, libSd, "apply_lora")
6669

67-
purego.RegisterLibFunc(&impl.GetLearnedCondition, libSd, "go_get_learned_condition")
68-
purego.RegisterLibFunc(&impl.PairGet, libSd, "go_pair_get")
6970
purego.RegisterLibFunc(&impl.GoSample, libSd, "go_sample")
7071
purego.RegisterLibFunc(&impl.DecodeFirstStage, libSd, "go_decode_first_stage")
7172

72-
purego.RegisterLibFunc(&impl.ApplyLora, libSd, "apply_lora")
73-
7473
return &impl, err
7574
}

pkg/txt2img/tokenize.go

+8-1
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,14 @@ import (
1010
"strings"
1111
)
1212

13-
func tokenize(text string, padding bool) (tokens []int, weights []float32) {
13+
const (
14+
UNK_TOKEN_ID int = 49407
15+
BOS_TOKEN_ID int = 49406
16+
EOS_TOKEN_ID int = 49407
17+
PAD_TOKEN_ID int = 49407
18+
)
19+
20+
func (gen *Generator) Tokenize(text string, padding bool) (tokens []int, weights []float32) {
1421
tokens = make([]int, 0)
1522
weights = make([]float32, 0)
1623

0 commit comments

Comments
 (0)