diff --git a/packages/orama/src/methods/answer-session.ts b/packages/orama/src/methods/answer-session.ts index 1b42bbfde..8a2a27263 100644 --- a/packages/orama/src/methods/answer-session.ts +++ b/packages/orama/src/methods/answer-session.ts @@ -3,21 +3,21 @@ import type { AnyDocument, AnyOrama, Nullable, OramaPluginSync, SearchParams, Re import { createError } from "../errors.js" import { search } from "./search.js" -type GenericContext = +export type GenericContext = | string | object -type MessageRole = +export type MessageRole = | 'system' | 'user' | 'assistant' -type Message = { +export type Message = { role: MessageRole content: string } -type Interaction = { +export type Interaction = { interactionId: string query: string response: string @@ -29,11 +29,11 @@ type Interaction = { errorMessage: Nullable } -type AnswerSessionEvents = { +export type AnswerSessionEvents = { onStateChange?: (state: Interaction[]) => void } -type IAnswerSessionConfig = { +export type IAnswerSessionConfig = { conversationID?: string systemPrompt?: string userContext?: GenericContext @@ -41,9 +41,9 @@ type IAnswerSessionConfig = { events?: AnswerSessionEvents } -type AskParams = SearchParams +export type AskParams = SearchParams -type RegenerateLastParams = { +export type RegenerateLastParams = { stream: boolean } @@ -60,8 +60,8 @@ export class AnswerSession { private conversationID: string private messages: Message[] = [] private events: AnswerSessionEvents - private state: Interaction[] = [] private initPromise?: Promise + public state: Interaction[] = [] constructor(db: AnyOrama, config: IAnswerSessionConfig) { this.db = db @@ -215,7 +215,7 @@ export class AnswerSession { throw createError('PLUGIN_SECURE_PROXY_NOT_FOUND') } - const pluginExtras = plugin.extra as { proxy: OramaProxy, pluginParams: { models: { chat: ChatModel } } } + const pluginExtras = plugin.extra as { proxy: OramaProxy, pluginParams: { chat: { model: ChatModel } } } this.proxy = pluginExtras.proxy @@ -223,8 +223,8 @@ export class AnswerSession { this.messages.push({ role: 'system', content: this.config.systemPrompt }) } - if (pluginExtras?.pluginParams?.models?.chat) { - this.chatModel = pluginExtras.pluginParams.models.chat + if (pluginExtras?.pluginParams?.chat?.model) { + this.chatModel = pluginExtras.pluginParams.chat.model } else { throw createError('PLUGIN_SECURE_PROXY_MISSING_CHAT_MODEL') } diff --git a/packages/orama/src/types.ts b/packages/orama/src/types.ts index 2af85fde4..4e47a7a9d 100644 --- a/packages/orama/src/types.ts +++ b/packages/orama/src/types.ts @@ -7,6 +7,18 @@ import { Sorter } from './components/sorter.js' import { Language } from './components/tokenizer/languages.js' import { Point } from './trees/bkd.js' +export type { + IAnswerSessionConfig, + AnswerSession, + AnswerSessionEvents, + AskParams, + GenericContext, + Interaction, + Message, + MessageRole, + RegenerateLastParams +} from './methods/answer-session.js' + export { MODE_FULLTEXT_SEARCH, MODE_HYBRID_SEARCH, MODE_VECTOR_SEARCH } from './constants.js' export type { DefaultTokenizer } from './components/tokenizer/index.js' diff --git a/packages/switch/package.json b/packages/switch/package.json index ff61143ee..4edf08806 100644 --- a/packages/switch/package.json +++ b/packages/switch/package.json @@ -16,7 +16,7 @@ "files": ["dist"], "scripts": { "build": "tsup", - "test": "tsx --test tests/search.test.ts" + "test": "tsx --test tests/answer-session.test.ts && tsx --test tests/search.test.ts" }, "keywords": ["orama", "orama cloud"], "type": "module", @@ -31,6 +31,7 @@ }, "license": "Apache-2.0", "devDependencies": { + "@orama/plugin-secure-proxy": "workspace:*", "tsup": "^7.2.0", "tsx": "^4.19.0" } diff --git a/packages/switch/src/index.ts b/packages/switch/src/index.ts index 95779aba4..254af6000 100644 --- a/packages/switch/src/index.ts +++ b/packages/switch/src/index.ts @@ -1,6 +1,6 @@ -import type { AnyOrama, Results, SearchParams, Nullable } from '@orama/orama' -import { search } from '@orama/orama' -import { OramaClient, ClientSearchParams } from '@oramacloud/client' +import type { AnyOrama, Results, SearchParams, Nullable, IAnswerSessionConfig as OSSAnswerSessionConfig } from '@orama/orama' +import { search, AnswerSession as OSSAnswerSession } from '@orama/orama' +import { OramaClient, ClientSearchParams, AnswerSessionParams as CloudAnswerSessionConfig, AnswerSession as CloudAnswerSession } from '@oramacloud/client' export type OramaSwitchClient = AnyOrama | OramaClient @@ -13,6 +13,7 @@ export type SearchConfig = { } export class Switch { + private invalidClientError = 'Invalid client. Expected either an OramaClient or an Orama OSS database.' client: OramaSwitchClient clientType: ClientType isCloud: boolean = false @@ -28,7 +29,7 @@ export class Switch { this.clientType = 'oss' this.isOSS = true } else { - throw new Error('Invalid client. Expected either an OramaClient or an Orama OSS database.') + throw new Error(this.invalidClientError) } } @@ -42,4 +43,24 @@ export class Switch { return search(this.client as AnyOrama, params as SearchParams) as Promise>> } } -} + + createAnswerSession(params: T extends OramaClient ? CloudAnswerSessionConfig : OSSAnswerSessionConfig): T extends OramaClient ? CloudAnswerSession : OSSAnswerSession { + if (this.isCloud) { + const p = params as CloudAnswerSessionConfig + return (this.client as OramaClient).createAnswerSession(p) as unknown as T extends OramaClient ? CloudAnswerSession : OSSAnswerSession + } + + if (this.isOSS) { + const p = params as OSSAnswerSessionConfig + return new OSSAnswerSession(this.client as AnyOrama, { + conversationID: p.conversationID, + initialMessages: p.initialMessages, + events: p.events, + userContext: p.userContext, + systemPrompt: p.systemPrompt, + }) as unknown as T extends OramaClient ? CloudAnswerSession : OSSAnswerSession + } + + throw new Error(this.invalidClientError) + } +} \ No newline at end of file diff --git a/packages/switch/tests/answer-session.test.ts b/packages/switch/tests/answer-session.test.ts new file mode 100644 index 000000000..4fd00a8b6 --- /dev/null +++ b/packages/switch/tests/answer-session.test.ts @@ -0,0 +1,71 @@ +import { test } from 'node:test' +import assert from 'node:assert/strict' +import { OramaClient } from '@oramacloud/client' +import { pluginSecureProxy } from '@orama/plugin-secure-proxy' +import { create, insertMultiple } from '@orama/orama' +import { Switch } from '../src/index.js' + +const CLOUD_URL = process.env.ORAMA_CLOUD_E2E_URL +const CLOUD_API_KEY = process.env.ORAMA_CLOUD_E2E_API_KEY +const SECURE_PROXY_API_KEY = process.env.ORAMA_SECURE_PROXY_API_KEY + +if (!CLOUD_URL || !CLOUD_API_KEY) { + console.log( + 'Skipping Orama Switch remote client test since ORAMA_CLOUD_E2E_URL and ORAMA_CLOUD_E2E_API_KEY are not set' + ) + process.exit(0) +} + +test('local client', async () => { + const db = await create({ + schema: { + name: 'string', + age: 'number', + embeddings: 'vector[384]' + } as const, + plugins: [ + pluginSecureProxy({ + apiKey: SECURE_PROXY_API_KEY!, + chat: { + model: 'openai/gpt-4' + }, + embeddings: { + defaultProperty: 'embeddings', + model: 'openai/text-embedding-3-small', + }, + }) + ] + }) + + await insertMultiple(db, [ + { name: 'Alice', age: 30 }, + { name: 'Bob', age: 40 }, + { name: 'Charlie', age: 50 } + ]) + + const answerSession = new Switch(db).createAnswerSession({ + systemPrompt: `You're an AI agent used to greet people. You will receive a name and will have to proceed generating a greeting message for that name. Use your fantasy.`, + }) + + await answerSession.ask({ term: 'Bob', where: { age: { eq: 40 } } }) + + const state = answerSession.state + + assert(state.length === 1) + assert(state[0].query === 'Bob') +}) + +test('remote client', async () => { + const cloudClient = new OramaClient({ + api_key: CLOUD_API_KEY, + endpoint: CLOUD_URL + }) + + const answerSession = new Switch(cloudClient).createAnswerSession({}) + await answerSession.ask({ term: 'What is Orama?' }) + + const state = answerSession.state + + assert(state.length === 1) + assert(state[0].query === 'What is Orama?') +}) \ No newline at end of file diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 4f654ce40..d9c9743d3 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -820,6 +820,9 @@ importers: specifier: 1.3.15 version: 1.3.15(encoding@0.1.13)(typescript@5.6.2)(zod@3.23.8) devDependencies: + '@orama/plugin-secure-proxy': + specifier: workspace:* + version: link:../plugin-secure-proxy tsup: specifier: ^7.2.0 version: 7.2.0(@swc/core@1.3.27)(postcss@8.4.47)(ts-node@10.9.1(@swc/core@1.3.27)(@types/node@20.11.19)(typescript@5.6.2))(typescript@5.6.2)