perf: push data to training queue
c121914yu committed Sep 10, 2024
1 parent 3255cf0 commit ec1ccfc
Showing 1 changed file with 60 additions and 48 deletions.
108 changes: 60 additions & 48 deletions packages/service/core/dataset/training/controller.ts
@@ -10,6 +10,7 @@ import { ClientSession } from '../../../common/mongo';
import { getLLMModel, getVectorModel } from '../../ai/model';
import { addLog } from '../../../common/system/log';
import { getCollectionWithDataset } from '../controller';
import { mongoSessionRun } from '../../../common/mongo/sessionRun';

export const lockTrainingDataByTeamId = async (teamId: string): Promise<any> => {
try {
@@ -64,7 +65,7 @@ export async function pushDataListToTrainingQueue({
vectorModel: string;
session?: ClientSession;
} & PushDatasetDataProps): Promise<PushDatasetDataResponse> {
const checkModelValid = async () => {
const { model, maxToken, weight } = await (async () => {
const agentModelData = getLLMModel(agentModel);
if (!agentModelData) {
return Promise.reject(`File model ${agentModel} is inValid`);
@@ -91,9 +92,16 @@ }
}

return Promise.reject(`Training mode "${trainingMode}" is inValid`);
};
})();

const { model, maxToken, weight } = await checkModelValid();
// filter repeat or equal content
const set = new Set();
const filterResult: Record<string, PushDatasetDataChunkProps[]> = {
success: [],
overToken: [],
repeat: [],
error: []
};

// format q and a, remove empty char
data.forEach((item) => {
@@ -108,19 +116,8 @@ };
};
})
.filter(Boolean);
});

// filter repeat or equal content
const set = new Set();
const filterResult: Record<string, PushDatasetDataChunkProps[]> = {
success: [],
overToken: [],
repeat: [],
error: []
};

// filter repeat content
data.forEach((item) => {
// filter repeat content
if (!item.q) {
filterResult.error.push(item);
return;
@@ -150,40 +147,55 @@ const failedDocuments: PushDatasetDataChunkProps[] = [];
const failedDocuments: PushDatasetDataChunkProps[] = [];

// Batch insert using insertMany
try {
await MongoDatasetTraining.insertMany(
filterResult.success.map((item) => ({
teamId,
tmbId,
datasetId,
collectionId,
billId,
mode: trainingMode,
prompt,
model,
q: item.q,
a: item.a,
chunkIndex: item.chunkIndex ?? 0,
weight: weight ?? 0,
indexes: item.indexes
})),
{
session,
ordered: false
}
);
} catch (error: any) {
addLog.error(`Insert error`, error);
// If there are errors, add the failed documents to the failure list
error.writeErrors?.forEach((writeError: any) => {
failedDocuments.push(data[writeError.index]);
});
console.log('failed', failedDocuments);
}
const batchSize = 200;
const insertData = async (startIndex: number, session: ClientSession) => {
const list = filterResult.success.slice(startIndex, startIndex + batchSize);

if (list.length === 0) return;

try {
await MongoDatasetTraining.insertMany(
list.map((item) => ({
teamId,
tmbId,
datasetId,
collectionId,
billId,
mode: trainingMode,
prompt,
model,
q: item.q,
a: item.a,
chunkIndex: item.chunkIndex ?? 0,
weight: weight ?? 0,
indexes: item.indexes
})),
{
session,
ordered: true
}
);
} catch (error: any) {
addLog.error(`Insert error`, error);
// If there are errors, add the failed documents to the failure list
error.writeErrors?.forEach((writeError: any) => {
failedDocuments.push(data[writeError.index]);
});
console.log('failed', failedDocuments);
}
console.log(startIndex, '===');
// For failed documents, try inserting them individually
await MongoDatasetTraining.create(failedDocuments, { session });

// For failed documents, try inserting them individually
for await (const item of failedDocuments) {
await MongoDatasetTraining.create(item);
return insertData(startIndex + batchSize, session);
};

if (session) {
await insertData(0, session);
} else {
await mongoSessionRun(async (session) => {
await insertData(0, session);
});
}

delete filterResult.success;
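Note on the new transaction fallback: when the caller does not pass a session, insertData now runs inside mongoSessionRun, so every 200-row batch (for example, five sequential insertMany calls for 1,000 chunks) commits or rolls back as one unit. The helper's implementation is not shown in this diff; below is a minimal sketch of what such a wrapper typically looks like with mongoose, assuming the standard ClientSession transaction API — the actual sessionRun.ts may differ.

import mongoose, { ClientSession } from 'mongoose';

// Hypothetical stand-in for mongoSessionRun (not the actual FastGPT implementation):
// open a session, run the callback inside a transaction, and always clean up.
export const runInMongoSession = async <T>(
  fn: (session: ClientSession) => Promise<T>
): Promise<T> => {
  const session = await mongoose.startSession();
  try {
    session.startTransaction();
    const result = await fn(session); // e.g. (session) => insertData(0, session)
    await session.commitTransaction();
    return result;
  } catch (error) {
    await session.abortTransaction();
    throw error;
  } finally {
    await session.endSession();
  }
};

Used as await runInMongoSession((session) => insertData(0, session)), this mirrors the mongoSessionRun call in the diff and shows why insertData can assume a live session on every recursive step.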
