From d49634c2dc36881bd85d2525e8ce231ca84c1f97 Mon Sep 17 00:00:00 2001 From: Bohan Cheng <47214785+cbh778899@users.noreply.github.com> Date: Mon, 5 Aug 2024 15:40:30 +1000 Subject: [PATCH] Db curd (#30) * implement rag related functions Signed-off-by: cbh778899 * add types Signed-off-by: cbh778899 * rewrite index Signed-off-by: cbh778899 --------- Signed-off-by: cbh778899 --- database/index.js | 66 ++++++++++++++++----------------- database/rag-inference.js | 78 +++++++++++++++++++++++++++++++++++++++ database/types.js | 2 + 3 files changed, 111 insertions(+), 35 deletions(-) create mode 100644 database/rag-inference.js create mode 100644 database/types.js diff --git a/database/index.js b/database/index.js index 9b029af..dc55811 100644 --- a/database/index.js +++ b/database/index.js @@ -1,42 +1,38 @@ -import * as lancedb from "@lancedb/lancedb"; -import { get, post } from "../tools/request.js" -import { Schema, Field, FixedSizeList, Int16, Float16, Utf8 } from "apache-arrow"; +import { connect } from "@lancedb/lancedb"; +import { + Schema, Field, FixedSizeList, + Float32, Utf8, + // eslint-disable-next-line + Table +} from "apache-arrow"; +import { DATASET_TABLE, SYSTEM_TABLE } from "./types"; const uri = "/tmp/lancedb/"; -const db = await lancedb.connect(uri); +const db = await connect(uri); -const table = await db.createEmptyTable("rag_data", new Schema([ - new Field("id", new Int16()), - new Field("vector", new FixedSizeList(384, new Field("item", new Float16(), true)), false), - new Field("question", new Utf8()), - new Field("answer", new Utf8()) -]), { - // mode: "overwrite", - existOk: true -}) - -export async function loadDataset(dataset_link) { - const {rows, http_error} = await get('', {}, { URL: dataset_link }) - if(http_error) { - return false; - } - await table.add(rows.map(({ row_id, row })=>{ - const { question, answer, question_embedding } = row; - return { id: row_id, question, answer, vector: question_embedding } - })) - return true; +export async 
function initDB(force = false) { + const open_options = force ? { mode: "overwrite" } : { existOk: true } + // create or re-open system table to store long-lasting data + await db.createEmptyTable(SYSTEM_TABLE, new Schema([ + new Field("title", new Utf8()), + new Field("value", new Utf8()) + ]), open_options) + // create or re-open dataset table + await db.createEmptyTable(DATASET_TABLE, new Schema([ + new Field("vector", new FixedSizeList(384, new Field("item", new Float32(), true)), false), + new Field("dataset_name", new Utf8()), + new Field("question", new Utf8()), + new Field("answer", new Utf8()) + ]), open_options) } -export async function searchByEmbedding(vector) { - const record = await table.search(vector).limit(1).toArray(); - if(!record.length) return null; - const { question, answer } = record[0]; - return { question, answer }; -} +initDB(); -export async function searchByMessage(msg) { - const { embedding } = await post('embedding', {body: { - content: msg - }}, { eng: "embedding" }); - return await searchByEmbedding(embedding); +/** + * Open a table by its name + * @param {String} table_name Name of the table to open + * @returns {Promise} A promise that resolves to the table object. 
+ */ +export async function getTable(table_name) { + return await db.openTable(table_name) } \ No newline at end of file diff --git a/database/rag-inference.js b/database/rag-inference.js new file mode 100644 index 0000000..3e0d350 --- /dev/null +++ b/database/rag-inference.js @@ -0,0 +1,78 @@ +import { get, post } from "../tools/request.js"; +import { getTable } from "./index.js"; +import { DATASET_TABLE, SYSTEM_TABLE } from "./types.js"; + +async function loadDatasetFromURL(dataset_name, dataset_url, system_table) { + system_table = system_table || await getTable(SYSTEM_TABLE); + const { rows, http_error } = await get('', {}, {URL: dataset_url}); + if(http_error) return false; + + await system_table.add([{ title: "loaded_dataset_name", value: dataset_name }]); + + await (await getTable(DATASET_TABLE)).add(rows.map(({row})=>{ + const { question, answer, question_embedding } = row; + return { question, answer, vector: question_embedding, dataset_name } + })) + return true; +} + +/** + * Load a dataset from the given URL. + * * This first checks whether the dataset is already loaded in the database; if `force` is not set and it is already loaded, it is not loaded again. + * * The dataset should be an array of objects containing at least the `question`, `answer` and `question_embedding` properties + * @param {String} dataset_name The name of the dataset to load + * @param {String} dataset_url The URL of the dataset to load + * @param {Boolean} force Whether to force loading the dataset, default `false`. 
+ * @returns {Promise} Returns `false` if the dataset cannot be fetched, otherwise returns `true` + */ +export async function loadDataset(dataset_name, dataset_url, force = false) { + const system_table = await getTable(SYSTEM_TABLE) + if(!force) { + const loaded_dataset = await system_table.query() + .where(`title="loaded_dataset_name" AND value="${dataset_name}"`).toArray(); + // check whether the given dataset is already loaded; if not, load it + return !!(loaded_dataset.length || await loadDatasetFromURL(dataset_name, dataset_url, system_table)) + } else { + return await loadDatasetFromURL(dataset_name, dataset_url, system_table) + } +} + +/** + * @typedef EmbeddingSearchResult + * @property {String} question The question from the dataset + * @property {String} answer The answer from the dataset + */ + +/** + * Search the given dataset using the provided embedding to get a Q/A pair + * @param {String} dataset_name The name of the dataset to query + * @param {Array} vector The embedding to search with + * @returns {Promise} If there's no result, returns null, otherwise returns the result + */ +export async function searchByEmbedding(dataset_name, vector) { + const embedding_result = (await ( + await getTable(DATASET_TABLE) + ).search(vector).where(`dataset_name = "${dataset_name}"`) + .limit(1).toArray()).pop(); + + if(embedding_result) { + const { question, answer, _distance } = embedding_result; + return { question, answer, _distance } + } + return null; +} + +/** + * Search the given dataset using the provided message to get a Q/A pair. 
+ * This first embeds the message and then queries using {@link searchByEmbedding} + * @param {String} dataset_name The name of the dataset to query + * @param {String} message The message to search for + * @returns {Promise} If the embedding request fails or there's no result, returns null, otherwise returns the result + */ +export async function searchByMessage(dataset_name, message) { + const { embedding, http_error } = await post('embedding', {body: { + content: message + }}, { eng: "embedding" }); + + return http_error ? null : await searchByEmbedding(dataset_name, embedding); +} \ No newline at end of file diff --git a/database/types.js b/database/types.js new file mode 100644 index 0000000..54a49be --- /dev/null +++ b/database/types.js @@ -0,0 +1,2 @@ +export const SYSTEM_TABLE = 'system'; +export const DATASET_TABLE = 'dataset'; \ No newline at end of file