Skip to content

Commit

Permalink
add max_distance to limit the result provided by vector db
Browse files Browse the repository at this point in the history
Signed-off-by: cbh778899 <[email protected]>
  • Loading branch information
cbh778899 committed Aug 6, 2024
1 parent 624f9b0 commit 5cb0f29
Showing 1 changed file with 8 additions and 3 deletions.
11 changes: 8 additions & 3 deletions database/rag-inference.js
Original file line number Diff line number Diff line change
Expand Up @@ -62,16 +62,19 @@ export async function loadDataset(dataset_name, dataset_url, force = false) {
* Search in given dataset using provided embedding value to get Q/A pair
* @param {String} dataset_name The dataset name to be query from
* @param {Array<Float>} vector The embedding result to be searched
* @param {Number} max_distance If the calculated distance is over given max_distance, then the result will be excluded.
* Default to `1`.
* @returns {Promise<EmbeddingSearchResult|null>} If there's no result, returns null, otherwise returns the result
*/
export async function searchByEmbedding(dataset_name, vector) {
export async function searchByEmbedding(dataset_name, vector, max_distance = 1) {
const embedding_result = (await (
await getTable(DATASET_TABLE)
).search(vector).where(`dataset_name = "${dataset_name}"`)
.limit(1).toArray()).pop();

if(embedding_result) {
const { question, answer, _distance } = embedding_result;
if(_distance >= max_distance) return null;
return { question, answer, _distance }
}
return null;
Expand All @@ -82,12 +85,14 @@ export async function searchByEmbedding(dataset_name, vector) {
* This will firstly embedding the message and query use {@link searchByEmbedding}
* @param {String} dataset_name The dataset name to be query from
* @param {String} message The message to be searched
* @param {Number} max_distance If the calculated distance is over given max_distance, then the result will be excluded.
* Default to `1`.
* @returns {Promise<EmbeddingSearchResult|null>} If there's no result, returns null, otherwise returns the result
*/
export async function searchByMessage(dataset_name, message) {
export async function searchByMessage(dataset_name, message, max_distance = 1) {
const { embedding, http_error } = await post('embedding', {body: {
content: message
}}, { eng: "embedding" });

return http_error ? null : await searchByEmbedding(dataset_name, embedding);
return http_error ? null : await searchByEmbedding(dataset_name, embedding, max_distance);
}

0 comments on commit 5cb0f29

Please sign in to comment.