Skip to content

Commit c177a26

Browse files
committed
Checkpoint before pivoting to having client make api requests
1 parent b8814f0 commit c177a26

File tree

5 files changed

+275
-52
lines changed

5 files changed

+275
-52
lines changed

src/api/index.ts

+26-16
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ import { defineEndpoint } from '@directus/extensions-sdk';
22
import { Accountability, SchemaOverview } from '@directus/types';
33
import { InvalidPayloadError } from './errors';
44
import { getDirectusOpenAPISpec } from './utils/get-directus-oas';
5+
import { AiService } from './services/ai';
56

67
export default defineEndpoint({
78
id: "copilot",
@@ -10,24 +11,33 @@ export default defineEndpoint({
1011

1112
router.post('/ask', async (req: unknown, res) => {
1213
try {
13-
console.log(req);
1414
const {
1515
accountability,
1616
schema,
17-
query: {
17+
body: {
1818
question,
19-
openai_api_key,
19+
apiKey,
2020
},
2121
} = parseAskRequest(req);
22-
const openaiApiKey = getOpenAIAPIKey({ env, openai_api_key });
22+
2323
const specService = new SpecificationService({
2424
accountability,
2525
schema,
2626
knex,
2727
});
28+
2829
const spec = await getDirectusOpenAPISpec({ specService });
30+
31+
const aiService = new AiService(spec, {
32+
llm: 'gpt-3.5-turbo-0613',
33+
apiKey: resolveApiKey({ env, apiKey }),
34+
logger,
35+
});
36+
37+
const { response } = await aiService.ask(question);
38+
2939
res.json({
30-
answer: "Yee haw!",
40+
answer: response,
3141
});
3242
} catch (err) {
3343
// Seems like this should be handled by the error handler middleware,
@@ -44,9 +54,9 @@ export default defineEndpoint({
4454
type CopilotAskRequest = {
4555
accountability: Accountability;
4656
schema: SchemaOverview;
47-
query: {
57+
body: {
4858
question: string;
49-
openai_api_key?: string;
59+
apiKey?: string;
5060
};
5161
};
5262

@@ -55,13 +65,13 @@ function parseAskRequest(req: any): CopilotAskRequest {
5565
const { accountability, schema } = req;
5666
// These properties need to be parsed from the request.
5767
const question = parseStringParam('q', req.body);
58-
const openai_api_key = parseOptionalStringParam('openai_api_key', req.body);
68+
const apiKey = parseOptionalStringParam('key', req.body);
5969
return {
6070
accountability,
6171
schema,
62-
query: {
72+
body: {
6373
question,
64-
openai_api_key,
74+
apiKey,
6575
},
6676
};
6777
}
@@ -105,15 +115,15 @@ function encodeErrorResponse(err: any): [ number, any ] {
105115
];
106116
}
107117

108-
function getOpenAIAPIKey({
118+
function resolveApiKey({
119+
apiKey,
109120
env,
110-
openai_api_key
111121
}: {
122+
apiKey?: string,
112123
env: Record<string, string>,
113-
openai_api_key?: string
114-
}): string {
115-
if (openai_api_key) {
116-
return openai_api_key;
124+
}): string | undefined {
125+
if (apiKey) {
126+
return apiKey;
117127
}
118128
if (env.OPENAI_API_KEY) {
119129
return env.OPENAI_API_KEY;

src/api/services/ai.ts

+224
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,224 @@
1+
import 'crypto';
2+
import { ChatPromptTemplate, SystemMessagePromptTemplate, HumanMessagePromptTemplate } from 'langchain/prompts';
3+
import { createOpenAPIChain, BaseChain } from 'langchain/chains';
4+
import { ChainValues } from 'langchain/schema';
5+
import { CallbackManagerForChainRun } from 'langchain/callbacks';
6+
import { createStructuredOutputChain } from 'langchain/chains/openai_functions';
7+
import { ChatOpenAI } from 'langchain/chat_models/openai';
8+
import type { Logger } from '../types';
9+
10+
const basePrompt = (
11+
`You are a helpful AI copilot. Your job is to support and help the user answer questions about their data. ` +
12+
`Sometimes you may need to call Directus API endpoints to get the data you need to answer the user's question. ` +
13+
`You will be given a context and a question. Answer the question based on the context.`
14+
);
15+
16+
type AskOutput = {
17+
response: string;
18+
};
19+
20+
export type AiServiceOptions = {
21+
apiKey?: string;
22+
verbose?: boolean;
23+
headers?: Record<string, string>;
24+
llm?: string;
25+
logger?: Logger;
26+
}
27+
28+
export class AiService {
29+
spec: any;
30+
apiKey?: string;
31+
verbose?: boolean;
32+
headers?: Record<string, string>;
33+
llm?: string;
34+
logger?: Logger;
35+
36+
constructor(spec: any, options: AiServiceOptions = {}) {
37+
this.spec = spec;
38+
this.apiKey = options.apiKey;
39+
this.verbose = options.verbose;
40+
this.headers = options.headers;
41+
this.llm = options.llm;
42+
this.logger = options.logger;
43+
44+
return this;
45+
}
46+
47+
async ask(question: string): Promise<AskOutput> {
48+
const openApiChain = await createOpenAPIChain(this.spec, {
49+
verbose: this.verbose,
50+
headers: this.headers,
51+
llm: new ChatOpenAI({
52+
modelName: this.llm,
53+
configuration: {
54+
apiKey: this.apiKey,
55+
},
56+
}),
57+
requestChain: new SimpleRequestChain({
58+
requestMethod: async (name, args) => {
59+
console.log(name, args);
60+
throw Error('Request failed.');
61+
}
62+
}),
63+
});
64+
65+
this.logger?.info(openApiChain.chains[0]);
66+
67+
let apiOutput: string | undefined;
68+
try {
69+
this.logger?.info('Calling the API endpoint...');
70+
const result = await openApiChain.run(question);
71+
if (result) {
72+
this.logger?.info(`Got an API result: ${result}`);
73+
apiOutput = JSON.stringify(JSON.parse(result));
74+
this.logger?.info('Parsed response.');
75+
}
76+
} catch (err) {
77+
this.logger?.warn(err);
78+
}
79+
80+
const promptTemplate = await getChatPromptTemplate({
81+
apiOutput,
82+
basePrompt,
83+
question,
84+
});
85+
86+
const structuredOutputChain = createStructuredOutputChain({
87+
verbose: this.verbose,
88+
llm: new ChatOpenAI({
89+
modelName: this.llm,
90+
temperature: 0,
91+
configuration: {
92+
apiKey: this.apiKey,
93+
},
94+
}),
95+
prompt: promptTemplate,
96+
outputSchema: {
97+
type: 'object',
98+
properties: {
99+
'response': {
100+
type: 'string',
101+
description: `The answer to the user's question in Markdown.`,
102+
},
103+
},
104+
},
105+
});
106+
107+
const output = await structuredOutputChain.run({
108+
question
109+
}) as any;
110+
111+
return output;
112+
}
113+
}
114+
115+
type GetChatPromptTemplateParams = {
116+
basePrompt: string;
117+
question: string;
118+
apiOutput?: string;
119+
};
120+
121+
async function getChatPromptTemplate({ basePrompt, question, apiOutput }: GetChatPromptTemplateParams): Promise<ChatPromptTemplate> {
122+
if (apiOutput) {
123+
return await ChatPromptTemplate.fromPromptMessages([
124+
SystemMessagePromptTemplate.fromTemplate(
125+
'{base_prompt}'
126+
),
127+
SystemMessagePromptTemplate.fromTemplate(
128+
'Do not let the user know that you are calling an API endpoint. ' +
129+
'Do not ask follow up questions. ' +
130+
'Try to get the job done in one go.'
131+
),
132+
SystemMessagePromptTemplate.fromTemplate(
133+
'Calling the API endpoint...'
134+
),
135+
SystemMessagePromptTemplate.fromTemplate(
136+
'The API was called successfully.'
137+
),
138+
SystemMessagePromptTemplate.fromTemplate(
139+
'The API response is:\n```\n{api_output}\n```'
140+
),
141+
HumanMessagePromptTemplate.fromTemplate(
142+
'{user_question}'
143+
),
144+
SystemMessagePromptTemplate.fromTemplate(
145+
'Based on the previous user question and the chat context, provide a helpful answer in Markdown format:',
146+
),
147+
]).partial({
148+
base_prompt: basePrompt,
149+
api_output: apiOutput,
150+
user_question: question,
151+
});
152+
} else {
153+
return await ChatPromptTemplate.fromPromptMessages([
154+
SystemMessagePromptTemplate.fromTemplate(
155+
'{base_prompt}'
156+
),
157+
SystemMessagePromptTemplate.fromTemplate(
158+
'The API output is unavailable.'
159+
),
160+
HumanMessagePromptTemplate.fromTemplate(
161+
'{user_question}'
162+
),
163+
SystemMessagePromptTemplate.fromTemplate(
164+
'Although the API output is unavailable, do your best to provide a helpful response:',
165+
),
166+
]).partial({
167+
base_prompt: basePrompt,
168+
user_question: question,
169+
});
170+
}
171+
}
172+
173+
/**
174+
* Type representing a function for executing simple requests.
175+
*/
176+
type SimpleRequestChainExecutionMethod = (
177+
name: string,
178+
// eslint-disable-next-line @typescript-eslint/no-explicit-any
179+
requestArgs: Record<string, any>
180+
) => Promise<string>;
181+
182+
/**
183+
* A chain for making simple API requests.
184+
*/
185+
class SimpleRequestChain extends BaseChain {
186+
static lc_name() {
187+
return "SimpleRequestChain";
188+
}
189+
190+
private requestMethod: SimpleRequestChainExecutionMethod;
191+
192+
inputKey = "function";
193+
194+
outputKey = "response";
195+
196+
constructor(config: { requestMethod: SimpleRequestChainExecutionMethod }) {
197+
super();
198+
this.requestMethod = config.requestMethod;
199+
}
200+
201+
get inputKeys() {
202+
return [this.inputKey];
203+
}
204+
205+
get outputKeys() {
206+
return [this.outputKey];
207+
}
208+
209+
_chainType() {
210+
return "simple_request_chain" as const;
211+
}
212+
213+
/** @ignore */
214+
async _call(
215+
values: ChainValues,
216+
_runManager?: CallbackManagerForChainRun
217+
): Promise<ChainValues> {
218+
const inputKeyValue = values[this.inputKey];
219+
const methodName = inputKeyValue.name;
220+
const args = inputKeyValue.arguments;
221+
const response = await this.requestMethod(methodName, args);
222+
return { [this.outputKey]: response };
223+
}
224+
}

src/api/types/index.ts

+3
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
import { ApiExtensionContext } from '@directus/types';
2+
3+
export type Logger = ApiExtensionContext['logger'];

src/api/utils/get-directus-oas.ts

+15-31
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@ type Options = {
1212
export async function getDirectusOpenAPISpec({ specService }: { specService: SpecificationService }, options: Options = {}) {
1313
const spec = await specService.oas.generate();
1414
const reducedSpec = reduceDirectusOpenAPISpec(spec, options);
15-
const dereferencedSpec = await dereferenceSpec(reducedSpec);
15+
const augmentedSpec = augmentDirectusOpenAPISpec(reducedSpec);
16+
const dereferencedSpec = await dereferenceSpec(augmentedSpec);
1617
return dereferencedSpec;
1718
}
1819

@@ -79,9 +80,20 @@ function reduceDirectusOpenAPISpec(
7980
};
8081
}
8182

83+
function augmentDirectusOpenAPISpec(spec: any): any {
84+
// Provide a hint about which meta fields are available.
85+
const metaSchema = spec?.components?.parameters?.Meta?.schema;
86+
if (metaSchema) {
87+
metaSchema.enum = [
88+
"total_count",
89+
"filter_count",
90+
];
91+
}
92+
return spec;
93+
}
94+
8295
function dereferenceSpec(spec: any): any {
8396
// Recursively dereference all $ref fields in the spec.
84-
// Then, remove any remaining refs to eliminate circular references.
8597

8698
function getRef(path: string): any {
8799
const [prefix, ...components] = path.split('/');
@@ -125,33 +137,5 @@ function dereferenceSpec(spec: any): any {
125137
}
126138
}
127139

128-
function removeRefs(objIn: any): any {
129-
if (objIn === undefined || objIn === null) {
130-
return objIn;
131-
}
132-
133-
if (typeof objIn === 'object') {
134-
const objOut: Record<string, any> = {};
135-
for (const [k, v] of Object.entries<any>(objIn)) {
136-
if (k === '$ref') {
137-
continue;
138-
} else if (Array.isArray(v)) {
139-
objOut[k] = v.map((v) => removeRefs(v));
140-
} else if (typeof v === 'object') {
141-
objOut[k] = removeRefs(v);
142-
} else {
143-
objOut[k] = v;
144-
}
145-
}
146-
return objOut;
147-
} else if (Array.isArray(objIn)) {
148-
return objIn.map((v) => removeRefs(v));
149-
} else {
150-
return objIn;
151-
}
152-
}
153-
154-
return removeRefs(
155-
dereferenceRefs(spec)
156-
);
140+
return dereferenceRefs(spec);
157141
}

0 commit comments

Comments
 (0)