@@ -4,7 +4,6 @@ package txt2img
44
55import (
66 "math"
7- "time"
87 "unsafe"
98)
109
@@ -14,82 +13,19 @@ type ggmlTensor struct {
1413}
1514
1615var (
17- n_threads = 4
18- diffusion_model = & DiffusionModel {}
16+ nThreads = 4
1917)
2018
21- func ggmlTimeMs () int64 {
22- return time .Now ().UnixNano () / int64 (time .Millisecond )
23- }
24-
25- func vectorToGgmlTensorI32 (workCtx unsafe.Pointer , vec []int ) * ggmlTensor {
26- panic ("fix me" )
27- return & ggmlTensor {}
28- }
29-
30- func ggmlDupTensor (workCtx unsafe.Pointer , tensor * ggmlTensor ) * ggmlTensor {
31- panic ("fix me" )
32- return & ggmlTensor {}
33- }
34-
35- func ggmlTensorMean (tensor * ggmlTensor ) float32 {
36- panic ("fix me" )
37- return 0.0
38- }
39-
40- func ggmlTensorGetF32 (tensor * ggmlTensor , i0 , i1 , i2 int ) float32 {
41- panic ("fix me" )
42- return 0.0
43- }
44-
45- func ggmlTensorSetF32 (tensor * ggmlTensor , value float32 , i0 , i1 , i2 int ) {
46- panic ("fix me" )
47- }
48-
49- func ggmlTensorScale (tensor * ggmlTensor , scale float32 ) {
50- panic ("fix me" )
51- }
52-
53- func ggmlNelements (tensor * ggmlTensor ) int {
54- panic ("fix me" )
55- return 0
56- }
57-
58- func vectorToGgmlTensor (workCtx unsafe.Pointer , vec []float32 ) * ggmlTensor {
59- panic ("fix me" )
60- return & ggmlTensor {}
61- }
62-
63- func ggmlReshape2D (workCtx unsafe.Pointer , tensor * ggmlTensor , dim0 , dim1 int ) * ggmlTensor {
64- panic ("fix me" )
65- return & ggmlTensor {}
66- }
67-
68- func ggmlNewTensor1D (workCtx unsafe.Pointer , tensorType int , size int ) * ggmlTensor {
69- panic ("fix me" )
70- return & ggmlTensor {}
71- }
72-
73- func ggmlView2D (workCtx unsafe.Pointer , tensor * ggmlTensor , dim0 , dim1 , stride , offset int ) * ggmlTensor {
74- panic ("fix me" )
75- return & ggmlTensor {}
76- }
77-
78- func ggmlNbytes (tensor * ggmlTensor ) int {
79- panic ("fix me" )
80- return 0
81- }
82-
8319func setTimestepEmbedding (timesteps []float32 , embedView * ggmlTensor , outDim int ) {
8420 panic ("fix me" )
8521}
8622
87- func getLearnedCondition ( workCtx unsafe.Pointer , text string , clipSkip , width , height int , forceZeroEmbeddings bool ) (* ggmlTensor , * ggmlTensor ) {
88- tokens , weights := tokenize ( text , true )
89- return getLearnedConditionCommon (workCtx , tokens , weights , clipSkip , width , height , forceZeroEmbeddings )
23+ func ( gen * Generator ) GetLearnedCondition ( workCtx unsafe.Pointer , prompt string , clipSkip , width , height int , forceZeroEmbeddings bool ) (unsafe. Pointer , unsafe. Pointer ) {
24+ tokens , weights := gen . Tokenize ( prompt , true )
25+ return gen . getLearnedConditionCommon (workCtx , tokens , weights , clipSkip , width , height , forceZeroEmbeddings )
9026}
9127
92- func getLearnedConditionCommon (workCtx unsafe.Pointer , tokens []int , weights []float32 , clipSkip , width , height int , forceZeroEmbeddings bool ) (* ggmlTensor , * ggmlTensor ) {
28+ func ( gen * Generator ) getLearnedConditionCommon (workCtx unsafe.Pointer , tokens []int , weights []float32 , clipSkip , width , height int , forceZeroEmbeddings bool ) (unsafe. Pointer , unsafe. Pointer ) {
9329 cond_stage_model .setClipSkip (clipSkip )
9430 var hiddenStates , chunkHiddenStates , pooled * ggmlTensor
9531 var hiddenStatesVec []float32
@@ -100,7 +36,7 @@ func getLearnedConditionCommon(workCtx unsafe.Pointer, tokens []int, weights []f
10036 chunkTokens := tokens [chunkIdx * chunkLen : (chunkIdx + 1 )* chunkLen ]
10137 chunkWeights := weights [chunkIdx * chunkLen : (chunkIdx + 1 )* chunkLen ]
10238
103- inputIds := vectorToGgmlTensorI32 (workCtx , chunkTokens )
39+ inputIds := gen . GGML . VectorToGgmlTensorI32 (workCtx , chunkTokens )
10440 var inputIds2 * ggmlTensor
10541 var maxTokenIdx int
10642 // if version == VERSION_XL {
@@ -119,9 +55,9 @@ func getLearnedConditionCommon(workCtx unsafe.Pointer, tokens []int, weights []f
11955 inputIds2 = vectorToGgmlTensorI32 (workCtx , chunkTokens )
12056 // }
12157
122- cond_stage_model .compute (n_threads , inputIds , inputIds2 , maxTokenIdx , false , & chunkHiddenStates , workCtx )
58+ cond_stage_model .compute (nThreads , inputIds , inputIds2 , maxTokenIdx , false , & chunkHiddenStates , workCtx )
12359 if chunkIdx == 0 {
124- cond_stage_model .compute (n_threads , inputIds , inputIds2 , maxTokenIdx , true , & pooled , workCtx )
60+ cond_stage_model .compute (nThreads , inputIds , inputIds2 , maxTokenIdx , true , & pooled , workCtx )
12561 }
12662
12763 result := ggmlDupTensor (workCtx , chunkHiddenStates )
0 commit comments