Skip to content

Commit

Permalink
refactored ai backend code for composability
Browse files Browse the repository at this point in the history
  • Loading branch information
Dhravya committed Apr 3, 2024
1 parent a00bb30 commit 4380ea8
Show file tree
Hide file tree
Showing 6 changed files with 192 additions and 141 deletions.
7 changes: 7 additions & 0 deletions apps/cf-ai-backend/src/env.d.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
// Ambient bindings available to the Worker (supplied by wrangler configuration / secrets).
interface Env {
  VECTORIZE_INDEX: VectorizeIndex; // Vectorize index holding page embeddings (queried in /query)
  AI: Fetcher; // Workers AI service binding — NOTE(review): no visible use in this commit; confirm it is still needed
  SECURITY_KEY: string; // shared secret compared against the X-Custom-Auth-Key request header
  OPENAI_API_KEY: string; // used by OpenAIEmbeddings (text-embedding-3-small)
  GOOGLE_AI_API_KEY: string; // used by GoogleGenerativeAI (gemini-pro)
}
154 changes: 13 additions & 141 deletions apps/cf-ai-backend/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,27 +9,18 @@ import {
} from "@langchain/cloudflare";
import { OpenAIEmbeddings } from "./OpenAIEmbedder";
import { GoogleGenerativeAI } from "@google/generative-ai";

export interface Env {
VECTORIZE_INDEX: VectorizeIndex;
AI: Fetcher;
SECURITY_KEY: string;
OPENAI_API_KEY: string;
GOOGLE_AI_API_KEY: string;
}

import routeMap from "./routes";

// True when the caller presented the shared secret in the X-Custom-Auth-Key header.
// NOTE(review): plain string comparison is not constant-time — acceptable here, but worth confirming.
function isAuthorized(request: Request, env: Env): boolean {
	const presentedKey = request.headers.get('X-Custom-Auth-Key');
	return presentedKey === env.SECURITY_KEY;
}

export default {
async fetch(request: Request, env: Env) {
async fetch(request: Request, env: Env, ctx: ExecutionContext) {
if (!isAuthorized(request, env)) {
return new Response('Unauthorized', { status: 401 });
}

const pathname = new URL(request.url).pathname;
const embeddings = new OpenAIEmbeddings({
apiKey: env.OPENAI_API_KEY,
modelName: 'text-embedding-3-small',
Expand All @@ -40,143 +31,24 @@ export default {
});

const genAI = new GoogleGenerativeAI(env.GOOGLE_AI_API_KEY);
const model = genAI.getGenerativeModel({ model: "gemini-pro" });

// TODO: Add /chat endpoint to chat with the AI in a conversational manner
if (pathname === "/add" && request.method === "POST") {

const body = await request.json() as {
pageContent: string,
title?: string,
description?: string,
url: string,
user: string
};


if (!body.pageContent || !body.url) {
return new Response(JSON.stringify({ message: "Invalid Page Content" }), { status: 400 });
}
const newPageContent = `Title: ${body.title}\nDescription: ${body.description}\nURL: ${body.url}\nContent: ${body.pageContent}`


await store.addDocuments([
{
pageContent: newPageContent,
metadata: {
title: body.title ?? "",
description: body.description ?? "",
url: body.url,
user: body.user,
},
},
], {
ids: [`${body.url}`]
})

return new Response(JSON.stringify({ message: "Document Added" }), { status: 200 });
}

else if (pathname === "/query" && request.method === "GET") {
const queryparams = new URL(request.url).searchParams;
const query = queryparams.get("q");
const topK = parseInt(queryparams.get("topK") ?? "5");
const user = queryparams.get("user")

const sourcesOnly = (queryparams.get("sourcesOnly") ?? "false")

if (!user) {
return new Response(JSON.stringify({ message: "Invalid User" }), { status: 400 });
}

if (!query) {
return new Response(JSON.stringify({ message: "Invalid Query" }), { status: 400 });
}

const filter: VectorizeVectorMetadataFilter = {
user: {
$eq: user
}
}

const queryAsVector = await embeddings.embedQuery(query);

const resp = await env.VECTORIZE_INDEX.query(queryAsVector, {
topK,
filter
});

if (resp.count === 0) {
return new Response(JSON.stringify({ message: "No Results Found" }), { status: 400 });
}

const highScoreIds = resp.matches.filter(({ score }) => score > 0.3).map(({ id }) => id)

if (sourcesOnly === "true") {
return new Response(JSON.stringify({ ids: highScoreIds }), { status: 200 });
}

const vec = await env.VECTORIZE_INDEX.getByIds(highScoreIds)

if (vec.length === 0 || !vec[0].metadata) {
return new Response(JSON.stringify({ message: "No Results Found" }), { status: 400 });
}
const model = genAI.getGenerativeModel({ model: "gemini-pro" });

const preparedContext = vec.slice(0, 3).map(({ metadata }) => `Website title: ${metadata!.title}\nDescription: ${metadata!.description}\nURL: ${metadata!.url}\nContent: ${metadata!.text}`).join("\n\n");
const url = new URL(request.url);
const path = url.pathname;
const method = request.method.toUpperCase();

const prompt = `You are an agent that summarizes a page based on the query. Be direct and concise, don't say 'based on the context'.\n\n Context:\n${preparedContext} \nAnswer this question based on the context. Question: ${query}\nAnswer:`
const output = await model.generateContentStream(prompt);
const routeHandlers = routeMap.get(path);

const response = new Response(
new ReadableStream({
async start(controller) {
const converter = new TextEncoder();
for await (const chunk of output.stream) {
const chunkText = await chunk.text();
const encodedChunk = converter.encode("data: " + JSON.stringify({ "response": chunkText }) + "\n\n");
controller.enqueue(encodedChunk);
}
const doneChunk = converter.encode("data: [DONE]");
controller.enqueue(doneChunk);
controller.close();
}
})
);
return response;
if (!routeHandlers) {
return new Response('Not Found', { status: 404 });
}

else if (pathname === "/ask" && request.method === "POST") {
const body = await request.json() as {
query: string
};

if (!body.query) {
return new Response(JSON.stringify({ message: "Invalid Page Content" }), { status: 400 });
}
const handler = routeHandlers[method];

const prompt = `You are an agent that answers a question based on the query. Be direct and concise, don't say 'based on the context'.\n\n Context:\n${body.query} \nAnswer this question based on the context. Question: ${body.query}\nAnswer:`
const output = await model.generateContentStream(prompt);

const response = new Response(
new ReadableStream({
async start(controller) {
const converter = new TextEncoder();
for await (const chunk of output.stream) {
const chunkText = await chunk.text();
console.log(chunkText);
const encodedChunk = converter.encode("data: " + JSON.stringify({ "response": chunkText }) + "\n\n");
controller.enqueue(encodedChunk);
}
const doneChunk = converter.encode("data: [DONE]");
controller.enqueue(doneChunk);
controller.close();
}
})
);
return response;
if (!handler) {
return new Response('Method Not Allowed', { status: 405 });
}

return new Response(JSON.stringify({ message: "Invalid Request" }), { status: 400 });

return await handler(request, store, embeddings, model, env, ctx);
},
};
29 changes: 29 additions & 0 deletions apps/cf-ai-backend/src/routes.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import { CloudflareVectorizeStore } from '@langchain/cloudflare';
import * as apiAdd from './routes/add';
import * as apiQuery from "./routes/query"
import * as apiAsk from "./routes/ask"
import { OpenAIEmbeddings } from './OpenAIEmbedder';
import { GenerativeModel } from '@google/generative-ai';
import { Request } from '@cloudflare/workers-types';


// Signature every route module must export: handlers receive the request plus the
// shared vector store, embedder, Gemini model, worker env, and optional execution context.
type RouteHandler = (request: Request, store: CloudflareVectorizeStore, embeddings: OpenAIEmbeddings, model: GenerativeModel, env: Env, ctx?: ExecutionContext) => Promise<Response>;

// Path -> (HTTP method -> handler) dispatch table consumed by the worker's fetch entry point.
const routeMap = new Map<string, Record<string, RouteHandler>>([
	['/add', { POST: apiAdd.POST }],
	['/query', { GET: apiQuery.GET }],
	['/ask', { POST: apiAsk.POST }],
]);

// Add more route mappings as needed
// routeMap.set('/api/otherRoute', { ... });

export default routeMap;
36 changes: 36 additions & 0 deletions apps/cf-ai-backend/src/routes/add.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import { Request } from "@cloudflare/workers-types";
import { type CloudflareVectorizeStore } from "@langchain/cloudflare";

/**
 * POST /add — embeds a page and stores it in the vectorize index.
 *
 * Body: { pageContent, url, user, title?, description?, category? }.
 * Returns 400 when pageContent, url, or user is missing; 200 on success.
 * The URL doubles as the document id, so re-adding the same page overwrites it.
 */
export async function POST(request: Request, store: CloudflareVectorizeStore) {
	const body = await request.json() as {
		pageContent: string,
		title?: string,
		description?: string,
		category?: string,
		url: string,
		user: string
	};

	if (!body.pageContent || !body.url) {
		return new Response(JSON.stringify({ message: "Invalid Page Content" }), { status: 400 });
	}

	// Without a user the stored vector is unreachable: /query filters on metadata.user.
	if (!body.user) {
		return new Response(JSON.stringify({ message: "Invalid User" }), { status: 400 });
	}

	// Default optional fields so the embedded text never contains the literal "undefined".
	const newPageContent = `Title: ${body.title ?? ""}\nDescription: ${body.description ?? ""}\nURL: ${body.url}\nContent: ${body.pageContent}`;

	await store.addDocuments([
		{
			pageContent: newPageContent,
			metadata: {
				title: body.title ?? "",
				description: body.description ?? "",
				category: body.category ?? "",
				url: body.url,
				user: body.user,
			},
		},
	], {
		ids: [`${body.url}`]
	});

	return new Response(JSON.stringify({ message: "Document Added" }), { status: 200 });
}
35 changes: 35 additions & 0 deletions apps/cf-ai-backend/src/routes/ask.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import { GenerativeModel } from "@google/generative-ai";
import { OpenAIEmbeddings } from "../OpenAIEmbedder";
import { CloudflareVectorizeStore } from "@langchain/cloudflare";
import { Request } from "@cloudflare/workers-types";

/**
 * POST /ask — streams a Gemini answer for the supplied query as server-sent events.
 *
 * Body: { query }. Returns 400 when query is missing.
 * The SSE stream emits `data: {"response": "..."}` chunks followed by `data: [DONE]`.
 * NOTE(review): the prompt uses body.query as both the context and the question — confirm intended.
 */
export async function POST(request: Request, _: CloudflareVectorizeStore, embeddings: OpenAIEmbeddings, model: GenerativeModel, env?: Env) {
	const body = await request.json() as {
		query: string
	};

	if (!body.query) {
		// Fixed: previously reported "Invalid Page Content" for a missing query.
		return new Response(JSON.stringify({ message: "Invalid Query" }), { status: 400 });
	}

	const prompt = `You are an agent that answers a question based on the query. don't say 'based on the context'.\n\n Context:\n${body.query} \nAnswer this question based on the context. Question: ${body.query}\nAnswer:`;
	const output = await model.generateContentStream(prompt);

	// Re-encode the model's streamed chunks as server-sent events.
	return new Response(
		new ReadableStream({
			async start(controller) {
				const encoder = new TextEncoder();
				for await (const chunk of output.stream) {
					const chunkText = await chunk.text();
					controller.enqueue(encoder.encode("data: " + JSON.stringify({ "response": chunkText }) + "\n\n"));
				}
				controller.enqueue(encoder.encode("data: [DONE]"));
				controller.close();
			}
		}),
		// Declare the SSE content type so clients parse the stream correctly.
		{ headers: { "Content-Type": "text/event-stream" } }
	);
}
72 changes: 72 additions & 0 deletions apps/cf-ai-backend/src/routes/query.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import { GenerativeModel } from "@google/generative-ai";
import { OpenAIEmbeddings } from "../OpenAIEmbedder";
import { CloudflareVectorizeStore } from "@langchain/cloudflare";
import { Request } from "@cloudflare/workers-types";

/**
 * GET /query — semantic search over the user's stored pages, answered by Gemini as an SSE stream.
 *
 * Query params:
 *   q           — search query (required)
 *   user        — scopes results to this user's documents via the metadata filter (required)
 *   topK        — max matches to retrieve (default 5; non-numeric values fall back to 5)
 *   sourcesOnly — "true" to return only matching document ids, skipping generation
 */
export async function GET(request: Request, _: CloudflareVectorizeStore, embeddings: OpenAIEmbeddings, model: GenerativeModel, env?: Env) {
	const queryparams = new URL(request.url).searchParams;
	const query = queryparams.get("q");
	// parseInt yields NaN for non-numeric input; fall back to the default instead of
	// passing NaN to the vectorize query.
	const parsedTopK = parseInt(queryparams.get("topK") ?? "5");
	const topK = Number.isNaN(parsedTopK) ? 5 : parsedTopK;
	const user = queryparams.get("user");
	const sourcesOnly = queryparams.get("sourcesOnly") ?? "false";

	if (!user) {
		return new Response(JSON.stringify({ message: "Invalid User" }), { status: 400 });
	}

	if (!query) {
		return new Response(JSON.stringify({ message: "Invalid Query" }), { status: 400 });
	}

	// env is optional in the shared handler signature but required here; fail explicitly
	// rather than relying on `env!` non-null assertions.
	if (!env) {
		return new Response(JSON.stringify({ message: "Missing Environment" }), { status: 500 });
	}

	// Only surface documents owned by the requesting user.
	const filter: VectorizeVectorMetadataFilter = {
		user: {
			$eq: user
		}
	};

	const queryAsVector = await embeddings.embedQuery(query);

	const resp = await env.VECTORIZE_INDEX.query(queryAsVector, {
		topK,
		filter
	});

	if (resp.count === 0) {
		return new Response(JSON.stringify({ message: "No Results Found" }), { status: 400 });
	}

	// 0.3 is a low-similarity cutoff — NOTE(review): threshold looks empirical; confirm.
	const highScoreIds = resp.matches.filter(({ score }) => score > 0.3).map(({ id }) => id);

	if (sourcesOnly === "true") {
		return new Response(JSON.stringify({ ids: highScoreIds }), { status: 200 });
	}

	// Nothing cleared the score threshold — avoid a pointless getByIds round-trip.
	if (highScoreIds.length === 0) {
		return new Response(JSON.stringify({ message: "No Results Found" }), { status: 400 });
	}

	const vec = await env.VECTORIZE_INDEX.getByIds(highScoreIds);

	if (vec.length === 0 || !vec[0].metadata) {
		return new Response(JSON.stringify({ message: "No Results Found" }), { status: 400 });
	}

	// Cap the prompt context at the top three documents.
	const preparedContext = vec.slice(0, 3).map(({ metadata }) => `Website title: ${metadata!.title}\nDescription: ${metadata!.description}\nURL: ${metadata!.url}\nContent: ${metadata!.text}`).join("\n\n");

	const prompt = `You are an agent that summarizes a page based on the query. Be direct and concise, don't say 'based on the context'.\n\n Context:\n${preparedContext} \nAnswer this question based on the context. Question: ${query}\nAnswer:`;
	const output = await model.generateContentStream(prompt);

	// Re-encode the model's streamed chunks as server-sent events.
	return new Response(
		new ReadableStream({
			async start(controller) {
				const encoder = new TextEncoder();
				for await (const chunk of output.stream) {
					const chunkText = await chunk.text();
					controller.enqueue(encoder.encode("data: " + JSON.stringify({ "response": chunkText }) + "\n\n"));
				}
				controller.enqueue(encoder.encode("data: [DONE]"));
				controller.close();
			}
		}),
		// Declare the SSE content type so clients parse the stream correctly.
		{ headers: { "Content-Type": "text/event-stream" } }
	);
}

0 comments on commit 4380ea8

Please sign in to comment.