Skip to content

Commit

Permalink
Db curd (#30)
Browse files Browse the repository at this point in the history
* implement rag related functions

Signed-off-by: cbh778899 <[email protected]>

* add types

Signed-off-by: cbh778899 <[email protected]>

* rewrite index

Signed-off-by: cbh778899 <[email protected]>

---------

Signed-off-by: cbh778899 <[email protected]>
  • Loading branch information
cbh778899 committed Aug 5, 2024
1 parent 286b63d commit d49634c
Show file tree
Hide file tree
Showing 3 changed files with 111 additions and 35 deletions.
66 changes: 31 additions & 35 deletions database/index.js
Original file line number Diff line number Diff line change
@@ -1,42 +1,38 @@
import * as lancedb from "@lancedb/lancedb";
import { get, post } from "../tools/request.js"
import { Schema, Field, FixedSizeList, Int16, Float16, Utf8 } from "apache-arrow";
import { connect } from "@lancedb/lancedb";
import {
Schema, Field, FixedSizeList,
Float32, Utf8,
// eslint-disable-next-line
Table
} from "apache-arrow";
import { DATASET_TABLE, SYSTEM_TABLE } from "./types";

const uri = "/tmp/lancedb/";
const db = await lancedb.connect(uri);
const db = await connect(uri);

const table = await db.createEmptyTable("rag_data", new Schema([
new Field("id", new Int16()),
new Field("vector", new FixedSizeList(384, new Field("item", new Float16(), true)), false),
new Field("question", new Utf8()),
new Field("answer", new Utf8())
]), {
// mode: "overwrite",
existOk: true
})

export async function loadDataset(dataset_link) {
const {rows, http_error} = await get('', {}, { URL: dataset_link })
if(http_error) {
return false;
}
await table.add(rows.map(({ row_id, row })=>{
const { question, answer, question_embedding } = row;
return { id: row_id, question, answer, vector: question_embedding }
}))
return true;
export async function initDB(force = false) {
const open_options = force ? { mode: "overwrite" } : { existOk: true }
// create or re-open system table to store long-lasting data
await db.createEmptyTable(SYSTEM_TABLE, new Schema([
new Field("title", new Utf8()),
new Field("value", new Utf8())
]), open_options)
// create or re-open dataset table
await db.createEmptyTable(DATASET_TABLE, new Schema([
new Field("vector", new FixedSizeList(384, new Field("item", new Float32(), true)), false),
new Field("dataset_name", new Utf8()),
new Field("question", new Utf8()),
new Field("answer", new Utf8())
]), open_options)
}

export async function searchByEmbedding(vector) {
const record = await table.search(vector).limit(1).toArray();
if(!record.length) return null;
const { question, answer } = record[0];
return { question, answer };
}
initDB();

export async function searchByMessage(msg) {
const { embedding } = await post('embedding', {body: {
content: msg
}}, { eng: "embedding" });
return await searchByEmbedding(embedding);
/**
* Open a table with table name
* @param {String} table_name table name to be opened
* @returns {Promise<Table>} Promise containes the table object.
*/
export async function getTable(table_name) {
return await db.openTable(table_name)
}
78 changes: 78 additions & 0 deletions database/rag-inference.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
import { get, post } from "../tools/request.js";
import { getTable } from "./index.js";
import { DATASET_TABLE, SYSTEM_TABLE } from "./types.js";

async function loadDatasetFromURL(dataset_name, dataset_url, system_table) {
system_table = system_table || await getTable(SYSTEM_TABLE);
const { rows, http_error } = await get('', {}, {URL: dataset_url});
if(http_error) return false;

await system_table.add([{ title: "loaded_dataset_name", value: dataset_name }]);

await (await getTable(DATASET_TABLE)).add(rows.map(({row})=>{
const { question, answer, question_embedding } = row;
return { question, answer, vector: question_embedding, dataset_name }
}))
return true;
}

/**
* Load a dataset from given url.
* * This will first check whether the dataset is loaded in database, if `force` not provided and it's loaded already, it won't load again.
* * The dataset format should be an array of object contains at least `question`, `answer` and `question_embedding` properties
* @param {String} dataset_name The dataset name to load
* @param {String} dataset_url The url of dataset to load
* @param {Boolean} force Specify whether to force load the dataset, default `false`.
* @returns {Promise<Boolean>} If cannot get the dataset, return `false`, otherwise return `true`
*/
export async function loadDataset(dataset_name, dataset_url, force = false) {
const system_table = await getTable(SYSTEM_TABLE)
if(!force) {
const loaded_dataset = await system_table.query()
.where(`title="loaded_dataset_name" AND value="${dataset_name}"`).toArray();
// check if the given dataset loaded, if not, load the dataset
return !!(loaded_dataset.length || await loadDatasetFromURL(dataset_name, dataset_url, system_table))
} else {
return await loadDatasetFromURL(dataset_name, dataset_url, system_table)
}
}

/**
* @typedef EmbeddingSearchResult
* @property {String} question The question from dataset
* @property {String} answer The answer from dataset
*/

/**
* Search in given dataset using provided embedding value to get Q/A pair
* @param {String} dataset_name The dataset name to be query from
* @param {Array<Float>} vector The embedding result to be searched
* @returns {Promise<EmbeddingSearchResult|null>} If there's no result, returns null, otherwise returns the result
*/
export async function searchByEmbedding(dataset_name, vector) {
const embedding_result = (await (
await getTable(DATASET_TABLE)
).search(vector).where(`dataset_name = "${dataset_name}"`)
.limit(1).toArray()).pop();

if(embedding_result) {
const { question, answer, _distance } = embedding_result;
return { question, answer, _distance }
}
return null;
}

/**
* Search in given dataset using provided message to get Q/A pair.
* This will firstly embedding the message and query use {@link searchByEmbedding}
* @param {String} dataset_name The dataset name to be query from
* @param {String} message The message to be searched
* @returns {Promise<EmbeddingSearchResult|null>} If there's no result, returns null, otherwise returns the result
*/
export async function searchByMessage(dataset_name, message) {
const { embedding, http_error } = await post('embedding', {body: {
content: message
}}, { eng: "embedding" });

return http_error ? null : await searchByEmbedding(dataset_name, embedding);
}
2 changes: 2 additions & 0 deletions database/types.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
export const SYSTEM_TABLE = 'system';
export const DATASET_TABLE = 'dataset';

0 comments on commit d49634c

Please sign in to comment.