Skip to content

Commit fd28505

Browse files
authored
Merge pull request #1402 from xiaomakuaiz/feat-search-simlarity
feat: change search similarity to 0.2
2 parents cb33999 + 70a7b1b commit fd28505

File tree

4 files changed

+19
-9
lines changed

4 files changed

+19
-9
lines changed

backend/store/rag/ct/rag.go

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ func (s *CTRAG) CreateKnowledgeBase(ctx context.Context) (string, error) {
4747
return dataset.ID, nil
4848
}
4949

50-
func (s *CTRAG) QueryRecords(ctx context.Context, datasetIDs []string, query string, groupIds []int, historyMsgs []*schema.Message) ([]*domain.NodeContentChunk, error) {
50+
func (s *CTRAG) QueryRecords(ctx context.Context, datasetIDs []string, query string, groupIds []int, similarityThreshold float64, historyMsgs []*schema.Message) ([]*domain.NodeContentChunk, error) {
5151
var chatMsgs []rag.ChatMessage
5252
for _, msg := range historyMsgs {
5353
switch msg.Role {
@@ -66,14 +66,17 @@ func (s *CTRAG) QueryRecords(ctx context.Context, datasetIDs []string, query str
6666
}
6767
}
6868
s.logger.Debug("retrieving by history msgs", log.Any("history_msgs", historyMsgs), log.Any("chat_msgs", chatMsgs))
69-
chunks, _, rewriteQuery, err := s.client.RetrieveChunks(ctx, rag.RetrievalRequest{
69+
retrieveReq := rag.RetrievalRequest{
7070
DatasetIDs: datasetIDs,
7171
Question: query,
7272
TopK: 10,
7373
UserGroupIDs: groupIds,
7474
ChatMessages: chatMsgs,
75-
// SimilarityThreshold: 0.2,
76-
})
75+
}
76+
if similarityThreshold != 0 {
77+
retrieveReq.SimilarityThreshold = similarityThreshold
78+
}
79+
chunks, _, rewriteQuery, err := s.client.RetrieveChunks(ctx, retrieveReq)
7780
s.logger.Info("retrieve chunks result", log.Int("chunks count", len(chunks)), log.String("query", rewriteQuery))
7881

7982
if err != nil {

backend/store/rag/rag.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ import (
1616
type RAGService interface {
1717
CreateKnowledgeBase(ctx context.Context) (string, error)
1818
UpsertRecords(ctx context.Context, datasetID string, nodeRelease *domain.NodeReleaseWithDirPath, authGroupId []int) (string, error)
19-
QueryRecords(ctx context.Context, datasetIDs []string, query string, groupIDs []int, historyMsgs []*schema.Message) ([]*domain.NodeContentChunk, error)
19+
QueryRecords(ctx context.Context, datasetIDs []string, query string, groupIDs []int, similarityThreshold float64, historyMsgs []*schema.Message) ([]*domain.NodeContentChunk, error)
2020
DeleteRecords(ctx context.Context, datasetID string, docIDs []string) error
2121
DeleteKnowledgeBase(ctx context.Context, datasetID string) error
2222
UpdateDocumentGroupIDs(ctx context.Context, datasetID string, docID string, groupIds []int) error

backend/usecase/chat.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -372,7 +372,7 @@ func (u *ChatUsecase) Search(ctx context.Context, req *domain.ChatSearchReq) (*d
372372
if err != nil {
373373
return nil, err
374374
}
375-
rankedNodes, err := u.llmUsecase.GetRankNodes(ctx, []string{kb.DatasetID}, req.Message, groupIds, nil)
375+
rankedNodes, err := u.llmUsecase.GetRankNodes(ctx, []string{kb.DatasetID}, req.Message, groupIds, 0.2, nil)
376376
if err != nil {
377377
return nil, err
378378
}

backend/usecase/llm.go

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ func (u *LLMUsecase) FormatConversationMessages(
100100
if err != nil {
101101
return nil, nil, fmt.Errorf("get kb failed: %w", err)
102102
}
103-
rankedNodes, err = u.GetRankNodes(ctx, []string{kb.DatasetID}, question, groupIDs, historyMessages[:len(historyMessages)-1])
103+
rankedNodes, err = u.GetRankNodes(ctx, []string{kb.DatasetID}, question, groupIDs, 0, historyMessages[:len(historyMessages)-1])
104104
if err != nil {
105105
return nil, nil, fmt.Errorf("get rank nodes failed: %w", err)
106106
}
@@ -297,10 +297,17 @@ func (u *LLMUsecase) SplitByTokenLimit(text string, maxTokens int) ([]string, er
297297
return result, nil
298298
}
299299

300-
func (u *LLMUsecase) GetRankNodes(ctx context.Context, datasetIDs []string, question string, groupIDs []int, historyMessages []*schema.Message) ([]*domain.RankedNodeChunks, error) {
300+
func (u *LLMUsecase) GetRankNodes(
301+
ctx context.Context,
302+
datasetIDs []string,
303+
question string,
304+
groupIDs []int,
305+
similarityThreshold float64,
306+
historyMessages []*schema.Message,
307+
) ([]*domain.RankedNodeChunks, error) {
301308
var rankedNodes []*domain.RankedNodeChunks
302309
// get related documents from raglite
303-
records, err := u.rag.QueryRecords(ctx, datasetIDs, question, groupIDs, historyMessages)
310+
records, err := u.rag.QueryRecords(ctx, datasetIDs, question, groupIDs, similarityThreshold, historyMessages)
304311
if err != nil {
305312
return nil, fmt.Errorf("get records from raglite failed: %w", err)
306313
}

0 commit comments

Comments
 (0)