feat: use prompt template
linonetwo committed Apr 13, 2024
1 parent 47e0535 commit 08f7cb4
Showing 3 changed files with 16 additions and 11 deletions.
src/services/languageModel/index.ts: 3 additions & 2 deletions
@@ -135,7 +135,7 @@ export class LanguageModel implements ILanguageModelService {
 
   public runLanguageModel$(runner: LanguageModelRunner.llamaCpp, options: IRunLLAmaOptions): Observable<ILanguageModelAPIResponse>;
   public runLanguageModel$(runner: LanguageModelRunner, options: IRunLLAmaOptions): Observable<ILanguageModelAPIResponse> {
-    const { id: conversationID, completionOptions, loadConfig: config } = options;
+    const { id: conversationID, loadConfig: config } = options;
     this.updateModelLoaded({ [runner]: null });
     return new Observable<ILanguageModelAPIResponse>((subscriber) => {
       const runLanguageModelObserverIIFE = async () => {
@@ -158,9 +158,10 @@ export class LanguageModel implements ILanguageModelService {
         } else {
           // load and run model
           const texts = { timeout: i18n.t('LanguageModel.GenerationTimeout'), disposed: i18n.t('LanguageModel.ModelDisposed') };
+          logger.info(options.sessionOptions?.systemPrompt ?? '', { tag: 'options.sessionOptions?.systemPrompt' });
           switch (runner) {
             case LanguageModelRunner.llamaCpp: {
-              observable = worker.runLLama({ completionOptions, loadConfig: { ...config, modelPath }, conversationID }, texts);
+              observable = worker.runLLama({ ...options, loadConfig: { ...config, modelPath }, conversationID }, texts);
               break;
             }
           }
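The destructuring change above stops discarding the rest of the caller's options: the worker now receives the whole options object, so the sessionOptions and templates fields introduced in interface.ts below ride along to the llama.cpp worker. A minimal sketch of a call site (hedged: the method, enum, and field names come from this diff; the concrete values and the subscribe body are illustrative):

  languageModelService.runLanguageModel$(LanguageModelRunner.llamaCpp, {
    id: 'conversation-1',                               // forwarded to the worker as conversationID
    completionOptions: { prompt: 'Hello!' },            // prompt text; other completion options pass through
    loadConfig: {},                                     // the service merges in the resolved modelPath itself
    sessionOptions: { systemPrompt: 'You are a helpful assistant.' },
    templates: { systemRoleName: 'System' },
  }).subscribe((response) => {
    console.log(response);                              // stream of ILanguageModelAPIResponse
  });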
src/services/languageModel/interface.ts: 3 additions & 1 deletion
@@ -1,6 +1,6 @@
 import { LanguageModelChannel } from '@/constants/channels';
 import { ProxyPropertyType } from 'electron-ipc-cat/common';
-import type { LLamaChatPromptOptions, LlamaModelOptions } from 'node-llama-cpp';
+import type { LLamaChatPromptOptions, LlamaChatSessionOptions, LlamaModelOptions, JinjaTemplateChatWrapperOptions } from 'node-llama-cpp';
 import type { Observable } from 'rxjs';
 
 export enum LanguageModelRunner {
@@ -60,6 +60,8 @@ export interface IRunLLAmaOptions extends ILLMResultBase {
    * Without generating text.
    */
   loadModelOnly?: boolean;
+  sessionOptions?: Pick<LlamaChatSessionOptions, 'systemPrompt'>;
+  templates?: Partial<Pick<JinjaTemplateChatWrapperOptions, 'template' | 'systemRoleName'>>;
 }
 
 /**
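Expanded for readability, the two new fields accept roughly these shapes (a sketch assuming node-llama-cpp's LlamaChatSessionOptions and JinjaTemplateChatWrapperOptions type declarations; the interface names here are illustrative, not part of the commit):

  // sessionOptions: Pick<LlamaChatSessionOptions, 'systemPrompt'> amounts to:
  interface SessionOptionsSubset {
    systemPrompt?: string; // becomes the system message of the chat session
  }

  // templates: Partial<Pick<JinjaTemplateChatWrapperOptions, 'template' | 'systemRoleName'>> amounts to:
  interface TemplatesSubset {
    template?: string;       // a Jinja chat-template string
    systemRoleName?: string; // role name the template uses for system messages
  }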
src/services/languageModel/llmWorker/llamaCpp.ts: 10 additions & 8 deletions
@@ -1,7 +1,7 @@
 import { LLAMA_PREBUILT_BINS_DIRECTORY } from './preload';
 
 import debounce from 'lodash/debounce';
-import { getLlama, Llama, LlamaChatSession, LlamaContext, LlamaContextSequence, LlamaModel, LlamaModelOptions } from 'node-llama-cpp';
+import { getLlama, Llama, LlamaChatSession, LlamaContext, LlamaContextSequence, LlamaModel, LlamaModelOptions, JinjaTemplateChatWrapper } from 'node-llama-cpp';
 import { Observable, Subscriber } from 'rxjs';
 import { ILanguageModelWorkerResponse, IRunLLAmaOptions } from '../interface';
 import { DEFAULT_TIMEOUT_DURATION } from './constants';
@@ -90,14 +90,10 @@ export async function unloadLLama() {
 }
 const runnerAbortControllers = new Map<string, AbortController>();
 export function runLLama(
-  options: {
-    completionOptions: IRunLLAmaOptions['completionOptions'];
-    conversationID: IRunLLAmaOptions['id'];
-    loadConfig: IRunLLAmaOptions['loadConfig'] & Pick<LlamaModelOptions, 'modelPath'>;
-  },
+  options: IRunLLAmaOptions & { conversationID: IRunLLAmaOptions['id']; loadConfig: IRunLLAmaOptions['loadConfig'] & Pick<LlamaModelOptions, 'modelPath'> },
   texts: { disposed: string; timeout: string },
 ): Observable<ILanguageModelWorkerResponse> {
-  const { conversationID, completionOptions, loadConfig } = options;
+  const { conversationID, completionOptions, sessionOptions, loadConfig, templates } = options;
 
   const loggerCommonMeta = { level: 'debug' as const, meta: { function: 'llmWorker.runLLama' }, id: conversationID };
   return new Observable<ILanguageModelWorkerResponse>((subscriber) => {
@@ -141,9 +137,15 @@ export function runLLama(
       if (contextSequenceInstance === undefined) {
         contextSequenceInstance = contextInstance.getSequence();
       }
+      const chatWrapper = new JinjaTemplateChatWrapper({
+        template: "{{ 'System: ' + systemPrompt if systemPrompt else '' }}{{ 'User: ' + userInput if userInput else '' }}",
+        ...templates,
+      });
       const session = new LlamaChatSession({
         contextSequence: contextSequenceInstance,
         autoDisposeSequence: false,
+        systemPrompt: sessionOptions?.systemPrompt,
+        chatWrapper,
       });
       await session.prompt(completionOptions.prompt, {
         ...completionOptions,
@@ -167,7 +169,7 @@
       subscriber.next({ message: 'createCompletion completed', ...loggerCommonMeta });
     } catch (error) {
       if ((error as Error).message.includes('aborted')) {
-        console.info('abortLLama', conversationID);
+        console.info(`abortLLama ${(error as Error).message}`, conversationID);
       } else {
         runnerAbortControllers.delete(conversationID);
         subscriber.error(error);
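Taken together, the worker now builds its chat session roughly like the condensed sketch below (a sketch only: it assumes the node-llama-cpp v3 API used elsewhere in this file, omits the Observable plumbing, timeout, and abort handling, and the ChatML-style template is an illustrative override in the Jinja `messages` convention, not the commit's default template):

  import { getLlama, JinjaTemplateChatWrapper, LlamaChatSession } from 'node-llama-cpp';

  async function promptOnce(modelPath: string, systemPrompt: string, userText: string): Promise<string> {
    const llama = await getLlama();
    const model = await llama.loadModel({ modelPath });
    const context = await model.createContext();
    const session = new LlamaChatSession({
      contextSequence: context.getSequence(),
      autoDisposeSequence: false, // keep the sequence alive for reuse, as the worker does
      systemPrompt,               // wired from options.sessionOptions?.systemPrompt above
      chatWrapper: new JinjaTemplateChatWrapper({
        // Callers would override the commit's default template via options.templates.
        template: "{% for message in messages %}<|im_start|>{{ message.role }}\n{{ message.content }}<|im_end|>\n{% endfor %}<|im_start|>assistant\n",
      }),
    });
    return await session.prompt(userText);
  }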
