diff --git a/backend/go/stores/main.go b/backend/go/stores/main.go index 9a113d7951d..87feebe7a91 100644 --- a/backend/go/stores/main.go +++ b/backend/go/stores/main.go @@ -19,8 +19,12 @@ func main() { log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr}) flag.Parse() + s, err := NewStore() + if err != nil { + panic(err) + } - if err := grpc.StartServer(*addr, NewStore()); err != nil { + if err := grpc.StartServer(*addr, s); err != nil { panic(err) } } diff --git a/backend/go/stores/store.go b/backend/go/stores/store.go index 9be31df8de3..313d770a912 100644 --- a/backend/go/stores/store.go +++ b/backend/go/stores/store.go @@ -3,505 +3,90 @@ package main // This is a wrapper to statisfy the GRPC service interface // It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc) import ( - "container/heap" + "context" "fmt" - "math" - "slices" + "strconv" "github.com/go-skynet/LocalAI/pkg/grpc/base" pb "github.com/go-skynet/LocalAI/pkg/grpc/proto" - "github.com/rs/zerolog/log" + "github.com/philippgille/chromem-go" ) type Store struct { base.SingleThread - // The sorted keys - keys [][]float32 - // The sorted values - values [][]byte - - // If for every K it holds that ||k||^2 = 1, then we can use the normalized distance functions - // TODO: Should we normalize incoming keys if they are not instead? - keysAreNormalized bool - // The first key decides the length of the keys - keyLen int -} - -// TODO: Only used for sorting using Go's builtin implementation. The interfaces are columnar because -// that's theoretically best for memory layout and cache locality, but this isn't optimized yet. -type Pair struct { - Key []float32 - Value []byte + maxId int + db *chromem.DB + c *chromem.Collection } -func NewStore() *Store { - return &Store{ - keys: make([][]float32, 0), - values: make([][]byte, 0), - keysAreNormalized: true, - keyLen: -1, +func NewStore() (*Store, error) { + db := chromem.NewDB() + c, err := db.CreateCollection("default", nil, nil) + if err != nil { + return nil, err } -} -func compareSlices(k1, k2 []float32) int { - assert(len(k1) == len(k2), fmt.Sprintf("compareSlices: len(k1) = %d, len(k2) = %d", len(k1), len(k2))) - - return slices.Compare(k1, k2) -} - -func hasKey(unsortedSlice [][]float32, target []float32) bool { - return slices.ContainsFunc(unsortedSlice, func(k []float32) bool { - return compareSlices(k, target) == 0 - }) -} - -func findInSortedSlice(sortedSlice [][]float32, target []float32) (int, bool) { - return slices.BinarySearchFunc(sortedSlice, target, func(k, t []float32) int { - return compareSlices(k, t) - }) -} - -func isSortedPairs(kvs []Pair) bool { - for i := 1; i < len(kvs); i++ { - if compareSlices(kvs[i-1].Key, kvs[i].Key) > 0 { - return false - } - } - - return true -} - -func isSortedKeys(keys [][]float32) bool { - for i := 1; i < len(keys); i++ { - if compareSlices(keys[i-1], keys[i]) > 0 { - return false - } - } - - return true -} - -func sortIntoKeySlicese(keys []*pb.StoresKey) [][]float32 { - ks := make([][]float32, len(keys)) - - for i, k := range keys { - ks[i] = k.Floats - } - - slices.SortFunc(ks, compareSlices) - - assert(len(ks) == len(keys), fmt.Sprintf("len(ks) = %d, len(keys) = %d", len(ks), len(keys))) - assert(isSortedKeys(ks), "keys are not sorted") - - return ks + return &Store{ + db: db, + c: c, + }, nil } func (s *Store) Load(opts *pb.ModelOptions) error { return nil } -// Sort the incoming kvs and merge them with the existing sorted kvs func (s *Store) StoresSet(opts *pb.StoresSetOptions) error { - if len(opts.Keys) == 0 { - return fmt.Errorf("no keys to add") - } - - if len(opts.Keys) != len(opts.Values) { - return fmt.Errorf("len(keys) = %d, len(values) = %d", len(opts.Keys), len(opts.Values)) - } + ids := make([]string, len(opts.Keys)) - if s.keyLen == -1 { - s.keyLen = len(opts.Keys[0].Floats) - } else { - if len(opts.Keys[0].Floats) != s.keyLen { - return fmt.Errorf("Try to add key with length %d when existing length is %d", len(opts.Keys[0].Floats), s.keyLen) - } + for i, _ := range(ids) { + ids[i] = strconv.Itoa(i) } - kvs := make([]Pair, len(opts.Keys)) - - for i, k := range opts.Keys { - if s.keysAreNormalized && !isNormalized(k.Floats) { - s.keysAreNormalized = false - var sample []float32 - if len(s.keys) > 5 { - sample = k.Floats[:5] - } else { - sample = k.Floats - } - log.Debug().Msgf("Key is not normalized: %v", sample) - } + embeddings := make([][]float32, len(opts.Keys)) - kvs[i] = Pair{ - Key: k.Floats, - Value: opts.Values[i].Bytes, - } + for i, key := range opts.Keys { + embeddings[i] = key.Floats } - slices.SortFunc(kvs, func(a, b Pair) int { - return compareSlices(a.Key, b.Key) - }) - - assert(len(kvs) == len(opts.Keys), fmt.Sprintf("len(kvs) = %d, len(opts.Keys) = %d", len(kvs), len(opts.Keys))) - assert(isSortedPairs(kvs), "keys are not sorted") - - l := len(kvs) + len(s.keys) - merge_ks := make([][]float32, 0, l) - merge_vs := make([][]byte, 0, l) + contents := make([]string, len(opts.Values)) - i, j := 0, 0 - for { - if i+j >= l { - break - } - - if i >= len(kvs) { - merge_ks = append(merge_ks, s.keys[j]) - merge_vs = append(merge_vs, s.values[j]) - j++ - continue - } - - if j >= len(s.keys) { - merge_ks = append(merge_ks, kvs[i].Key) - merge_vs = append(merge_vs, kvs[i].Value) - i++ - continue - } - - c := compareSlices(kvs[i].Key, s.keys[j]) - if c < 0 { - merge_ks = append(merge_ks, kvs[i].Key) - merge_vs = append(merge_vs, kvs[i].Value) - i++ - } else if c > 0 { - merge_ks = append(merge_ks, s.keys[j]) - merge_vs = append(merge_vs, s.values[j]) - j++ - } else { - merge_ks = append(merge_ks, kvs[i].Key) - merge_vs = append(merge_vs, kvs[i].Value) - i++ - j++ - } + for i, value := range opts.Values { + contents[i] = string(value.Bytes) } - assert(len(merge_ks) == l, fmt.Sprintf("len(merge_ks) = %d, l = %d", len(merge_ks), l)) - assert(isSortedKeys(merge_ks), "merge keys are not sorted") - - s.keys = merge_ks - s.values = merge_vs - - return nil + return s.c.Add(context.Background(), ids, embeddings, nil, contents) } func (s *Store) StoresDelete(opts *pb.StoresDeleteOptions) error { - if len(opts.Keys) == 0 { - return fmt.Errorf("no keys to delete") - } - - if len(opts.Keys) == 0 { - return fmt.Errorf("no keys to add") - } - - if s.keyLen == -1 { - s.keyLen = len(opts.Keys[0].Floats) - } else { - if len(opts.Keys[0].Floats) != s.keyLen { - return fmt.Errorf("Trying to delete key with length %d when existing length is %d", len(opts.Keys[0].Floats), s.keyLen) - } - } - - ks := sortIntoKeySlicese(opts.Keys) - - l := len(s.keys) - len(ks) - merge_ks := make([][]float32, 0, l) - merge_vs := make([][]byte, 0, l) - - tail_ks := s.keys - tail_vs := s.values - for _, k := range ks { - j, found := findInSortedSlice(tail_ks, k) - - if found { - merge_ks = append(merge_ks, tail_ks[:j]...) - merge_vs = append(merge_vs, tail_vs[:j]...) - tail_ks = tail_ks[j+1:] - tail_vs = tail_vs[j+1:] - } else { - assert(!hasKey(s.keys, k), fmt.Sprintf("Key exists, but was not found: t=%d, %v", len(tail_ks), k)) - } - - log.Debug().Msgf("Delete: found = %v, t = %d, j = %d, len(merge_ks) = %d, len(merge_vs) = %d", found, len(tail_ks), j, len(merge_ks), len(merge_vs)) - } - - merge_ks = append(merge_ks, tail_ks...) - merge_vs = append(merge_vs, tail_vs...) - - assert(len(merge_ks) <= len(s.keys), fmt.Sprintf("len(merge_ks) = %d, len(s.keys) = %d", len(merge_ks), len(s.keys))) - - s.keys = merge_ks - s.values = merge_vs - - assert(len(s.keys) >= l, fmt.Sprintf("len(s.keys) = %d, l = %d", len(s.keys), l)) - assert(isSortedKeys(s.keys), "keys are not sorted") - assert(func() bool { - for _, k := range ks { - if _, found := findInSortedSlice(s.keys, k); found { - return false - } - } - return true - }(), "Keys to delete still present") - - if len(s.keys) != l { - log.Debug().Msgf("Delete: Some keys not found: len(s.keys) = %d, l = %d", len(s.keys), l) - } - - return nil + return fmt.Errorf("Per document delete not implemented in chromem") } func (s *Store) StoresGet(opts *pb.StoresGetOptions) (pb.StoresGetResult, error) { - pbKeys := make([]*pb.StoresKey, 0, len(opts.Keys)) - pbValues := make([]*pb.StoresValue, 0, len(opts.Keys)) - ks := sortIntoKeySlicese(opts.Keys) - - if len(s.keys) == 0 { - log.Debug().Msgf("Get: No keys in store") - } - - if s.keyLen == -1 { - s.keyLen = len(opts.Keys[0].Floats) - } else { - if len(opts.Keys[0].Floats) != s.keyLen { - return pb.StoresGetResult{}, fmt.Errorf("Try to get a key with length %d when existing length is %d", len(opts.Keys[0].Floats), s.keyLen) - } - } - - tail_k := s.keys - tail_v := s.values - for i, k := range ks { - j, found := findInSortedSlice(tail_k, k) - - if found { - pbKeys = append(pbKeys, &pb.StoresKey{ - Floats: k, - }) - pbValues = append(pbValues, &pb.StoresValue{ - Bytes: tail_v[j], - }) - - tail_k = tail_k[j+1:] - tail_v = tail_v[j+1:] - } else { - assert(!hasKey(s.keys, k), fmt.Sprintf("Key exists, but was not found: i=%d, %v", i, k)) - } - } - - if len(pbKeys) != len(opts.Keys) { - log.Debug().Msgf("Get: Some keys not found: len(pbKeys) = %d, len(opts.Keys) = %d, len(s.Keys) = %d", len(pbKeys), len(opts.Keys), len(s.keys)) - } - - return pb.StoresGetResult{ - Keys: pbKeys, - Values: pbValues, - }, nil -} - -func isNormalized(k []float32) bool { - var sum float32 - for _, v := range k { - sum += v - } - - return sum == 1.0 -} - -// TODO: This we could replace with handwritten SIMD code -func normalizedCosineSimilarity(k1, k2 []float32) float32 { - assert(len(k1) == len(k2), fmt.Sprintf("normalizedCosineSimilarity: len(k1) = %d, len(k2) = %d", len(k1), len(k2))) - - var dot float32 - for i := 0; i < len(k1); i++ { - dot += k1[i] * k2[i] - } - - assert(dot >= -1 && dot <= 1, fmt.Sprintf("dot = %f", dot)) - - // 2.0 * (1.0 - dot) would be the Euclidean distance - return dot -} - -type PriorityItem struct { - Similarity float32 - Key []float32 - Value []byte + return pb.StoresGetResult{}, fmt.Errorf("Get not really implemented in chromem, although query may work") } -type PriorityQueue []*PriorityItem - -func (pq PriorityQueue) Len() int { return len(pq) } - -func (pq PriorityQueue) Less(i, j int) bool { - // Inverted because the most similar should be at the top - return pq[i].Similarity < pq[j].Similarity -} - -func (pq PriorityQueue) Swap(i, j int) { - pq[i], pq[j] = pq[j], pq[i] -} - -func (pq *PriorityQueue) Push(x any) { - item := x.(*PriorityItem) - *pq = append(*pq, item) -} - -func (pq *PriorityQueue) Pop() any { - old := *pq - n := len(old) - item := old[n-1] - *pq = old[0 : n-1] - return item -} - -func (s *Store) StoresFindNormalized(opts *pb.StoresFindOptions) (pb.StoresFindResult, error) { - tk := opts.Key.Floats - top_ks := make(PriorityQueue, 0, int(opts.TopK)) - heap.Init(&top_ks) - - for i, k := range s.keys { - sim := normalizedCosineSimilarity(tk, k) - heap.Push(&top_ks, &PriorityItem{ - Similarity: sim, - Key: k, - Value: s.values[i], - }) - - if top_ks.Len() > int(opts.TopK) { - heap.Pop(&top_ks) - } - } - - similarities := make([]float32, top_ks.Len()) - pbKeys := make([]*pb.StoresKey, top_ks.Len()) - pbValues := make([]*pb.StoresValue, top_ks.Len()) - - for i := top_ks.Len() - 1; i >= 0; i-- { - item := heap.Pop(&top_ks).(*PriorityItem) - - similarities[i] = item.Similarity - pbKeys[i] = &pb.StoresKey{ - Floats: item.Key, - } - pbValues[i] = &pb.StoresValue{ - Bytes: item.Value, - } - } - - return pb.StoresFindResult{ - Keys: pbKeys, - Values: pbValues, - Similarities: similarities, - }, nil -} - -func cosineSimilarity(k1, k2 []float32, mag1 float64) float32 { - assert(len(k1) == len(k2), fmt.Sprintf("cosineSimilarity: len(k1) = %d, len(k2) = %d", len(k1), len(k2))) - - var dot, mag2 float64 - for i := 0; i < len(k1); i++ { - dot += float64(k1[i] * k2[i]) - mag2 += float64(k2[i] * k2[i]) - } - - sim := float32(dot / (mag1 * math.Sqrt(mag2))) - - assert(sim >= -1 && sim <= 1, fmt.Sprintf("sim = %f", sim)) - - return sim -} - -func (s *Store) StoresFindFallback(opts *pb.StoresFindOptions) (pb.StoresFindResult, error) { - tk := opts.Key.Floats - top_ks := make(PriorityQueue, 0, int(opts.TopK)) - heap.Init(&top_ks) - - var mag1 float64 - for _, v := range tk { - mag1 += float64(v * v) - } - mag1 = math.Sqrt(mag1) - - for i, k := range s.keys { - dist := cosineSimilarity(tk, k, mag1) - heap.Push(&top_ks, &PriorityItem{ - Similarity: dist, - Key: k, - Value: s.values[i], - }) - - if top_ks.Len() > int(opts.TopK) { - heap.Pop(&top_ks) - } +func (s *Store) StoresFind(opts *pb.StoresFindOptions) (pb.StoresFindResult, error) { + res, err := s.c.QueryEmbedding(context.Background(), opts.Key.Floats, int(opts.TopK), nil, nil) + if err != nil { + return pb.StoresFindResult{}, err } - similarities := make([]float32, top_ks.Len()) - pbKeys := make([]*pb.StoresKey, top_ks.Len()) - pbValues := make([]*pb.StoresValue, top_ks.Len()) - - for i := top_ks.Len() - 1; i >= 0; i-- { - item := heap.Pop(&top_ks).(*PriorityItem) + keys := make([]*pb.StoresKey, len(res)) + values := make([]*pb.StoresValue, len(res)) + similarities := make([]float32, len(res)) - similarities[i] = item.Similarity - pbKeys[i] = &pb.StoresKey{ - Floats: item.Key, - } - pbValues[i] = &pb.StoresValue{ - Bytes: item.Value, - } + for i, r := range(res) { + keys[i] = &pb.StoresKey{Floats: r.Embedding} + similarities[i] = r.Similarity + values[i] = &pb.StoresValue{Bytes: []byte(r.Content)} } return pb.StoresFindResult{ - Keys: pbKeys, - Values: pbValues, + Keys: keys, + Values: values, Similarities: similarities, }, nil } - -func (s *Store) StoresFind(opts *pb.StoresFindOptions) (pb.StoresFindResult, error) { - tk := opts.Key.Floats - - if len(tk) != s.keyLen { - return pb.StoresFindResult{}, fmt.Errorf("Try to find key with length %d when existing length is %d", len(tk), s.keyLen) - } - - if opts.TopK < 1 { - return pb.StoresFindResult{}, fmt.Errorf("opts.TopK = %d, must be >= 1", opts.TopK) - } - - if s.keyLen == -1 { - s.keyLen = len(opts.Key.Floats) - } else { - if len(opts.Key.Floats) != s.keyLen { - return pb.StoresFindResult{}, fmt.Errorf("Try to add key with length %d when existing length is %d", len(opts.Key.Floats), s.keyLen) - } - } - - if s.keysAreNormalized && isNormalized(tk) { - return s.StoresFindNormalized(opts) - } else { - if s.keysAreNormalized { - var sample []float32 - if len(s.keys) > 5 { - sample = tk[:5] - } else { - sample = tk - } - log.Debug().Msgf("Trying to compare non-normalized key with normalized keys: %v", sample) - } - - return s.StoresFindFallback(opts) - } -} diff --git a/go.mod b/go.mod index 99af8ce7957..627c1c8ac19 100644 --- a/go.mod +++ b/go.mod @@ -110,6 +110,7 @@ require ( github.com/opencontainers/go-digest v1.0.0 // indirect github.com/opencontainers/image-spec v1.0.2 // indirect github.com/opencontainers/runc v1.1.12 // indirect + github.com/philippgille/chromem-go v0.5.0 // indirect github.com/pierrec/lz4/v4 v4.1.2 // indirect github.com/pkg/errors v0.9.1 // indirect github.com/pkoukk/tiktoken-go v0.1.2 // indirect diff --git a/go.sum b/go.sum index a421e79c685..9a36a2cd7da 100644 --- a/go.sum +++ b/go.sum @@ -254,6 +254,8 @@ github.com/otiai10/openaigo v1.6.0 h1:YTQEbtDSvawETOB/Kmb/6JvuHdHH/eIpSQfHVufiwY github.com/otiai10/openaigo v1.6.0/go.mod h1:kIaXc3V+Xy5JLplcBxehVyGYDtufHp3PFPy04jOwOAI= github.com/phayes/freeport v0.0.0-20220201140144-74d24b5ae9f5 h1:Ii+DKncOVM8Cu1Hc+ETb5K+23HdAMvESYE3ZJ5b5cMI= github.com/phayes/freeport v0.0.0-20220201140144-74d24b5ae9f5/go.mod h1:iIss55rKnNBTvrwdmkUpLnDpZoAHvWaiq5+iMmen4AE= +github.com/philippgille/chromem-go v0.5.0 h1:bryX0F3N6jnN/21iBd8i2/k9EzPTZn3nyiqAti19si8= +github.com/philippgille/chromem-go v0.5.0/go.mod h1:hTd+wGEm/fFPQl7ilfCwQXkgEUxceYh86iIdoKMolPo= github.com/pierrec/lz4/v4 v4.1.2 h1:qvY3YFXRQE/XB8MlLzJH7mSzBs74eA2gg52YTk6jUPM= github.com/pierrec/lz4/v4 v4.1.2/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=