From ec1ccfc71e7b84dc944399c17d74d22bb4eb368e Mon Sep 17 00:00:00 2001 From: archer <545436317@qq.com> Date: Tue, 10 Sep 2024 22:48:29 +0800 Subject: [PATCH] perf: push data to training queue --- .../core/dataset/training/controller.ts | 108 ++++++++++-------- 1 file changed, 60 insertions(+), 48 deletions(-) diff --git a/packages/service/core/dataset/training/controller.ts b/packages/service/core/dataset/training/controller.ts index 984e4a3c05cd..e293a8183e7c 100644 --- a/packages/service/core/dataset/training/controller.ts +++ b/packages/service/core/dataset/training/controller.ts @@ -10,6 +10,7 @@ import { ClientSession } from '../../../common/mongo'; import { getLLMModel, getVectorModel } from '../../ai/model'; import { addLog } from '../../../common/system/log'; import { getCollectionWithDataset } from '../controller'; +import { mongoSessionRun } from '../../../common/mongo/sessionRun'; export const lockTrainingDataByTeamId = async (teamId: string): Promise => { try { @@ -64,7 +65,7 @@ export async function pushDataListToTrainingQueue({ vectorModel: string; session?: ClientSession; } & PushDatasetDataProps): Promise { - const checkModelValid = async () => { + const { model, maxToken, weight } = await (async () => { const agentModelData = getLLMModel(agentModel); if (!agentModelData) { return Promise.reject(`File model ${agentModel} is inValid`); @@ -91,9 +92,16 @@ export async function pushDataListToTrainingQueue({ } return Promise.reject(`Training mode "${trainingMode}" is inValid`); - }; + })(); - const { model, maxToken, weight } = await checkModelValid(); + // filter repeat or equal content + const set = new Set(); + const filterResult: Record = { + success: [], + overToken: [], + repeat: [], + error: [] + }; // format q and a, remove empty char data.forEach((item) => { @@ -108,19 +116,8 @@ export async function pushDataListToTrainingQueue({ }; }) .filter(Boolean); - }); - - // filter repeat or equal content - const set = new Set(); - const filterResult: Record = { - success: [], - overToken: [], - repeat: [], - error: [] - }; - // filter repeat content - data.forEach((item) => { + // filter repeat content if (!item.q) { filterResult.error.push(item); return; @@ -150,40 +147,55 @@ export async function pushDataListToTrainingQueue({ const failedDocuments: PushDatasetDataChunkProps[] = []; // 使用 insertMany 批量插入 - try { - await MongoDatasetTraining.insertMany( - filterResult.success.map((item) => ({ - teamId, - tmbId, - datasetId, - collectionId, - billId, - mode: trainingMode, - prompt, - model, - q: item.q, - a: item.a, - chunkIndex: item.chunkIndex ?? 0, - weight: weight ?? 0, - indexes: item.indexes - })), - { - session, - ordered: false - } - ); - } catch (error: any) { - addLog.error(`Insert error`, error); - // 如果有错误,将失败的文档添加到失败列表中 - error.writeErrors?.forEach((writeError: any) => { - failedDocuments.push(data[writeError.index]); - }); - console.log('failed', failedDocuments); - } + const batchSize = 200; + const insertData = async (startIndex: number, session: ClientSession) => { + const list = filterResult.success.slice(startIndex, startIndex + batchSize); + + if (list.length === 0) return; + + try { + await MongoDatasetTraining.insertMany( + list.map((item) => ({ + teamId, + tmbId, + datasetId, + collectionId, + billId, + mode: trainingMode, + prompt, + model, + q: item.q, + a: item.a, + chunkIndex: item.chunkIndex ?? 0, + weight: weight ?? 0, + indexes: item.indexes + })), + { + session, + ordered: true + } + ); + } catch (error: any) { + addLog.error(`Insert error`, error); + // 如果有错误,将失败的文档添加到失败列表中 + error.writeErrors?.forEach((writeError: any) => { + failedDocuments.push(data[writeError.index]); + }); + console.log('failed', failedDocuments); + } + console.log(startIndex, '==='); + // 对于失败的文档,尝试单独插入 + await MongoDatasetTraining.create(failedDocuments, { session }); - // 对于失败的文档,尝试单独插入 - for await (const item of failedDocuments) { - await MongoDatasetTraining.create(item); + return insertData(startIndex + batchSize, session); + }; + + if (session) { + await insertData(0, session); + } else { + await mongoSessionRun(async (session) => { + await insertData(0, session); + }); } delete filterResult.success;