Skip to content

Commit

Permalink
[fix] Replace nativeEnum in ChromaDB plugin; Improve metadata fetch…
Browse files Browse the repository at this point in the history
…ing (#1255)
  • Loading branch information
ssbushi authored Nov 12, 2024
1 parent 8b37692 commit a1f60d2
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 15 deletions.
2 changes: 1 addition & 1 deletion js/plugins/chroma/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
"license": "Apache-2.0",
"dependencies": {
"ts-md5": "^1.3.1",
"chromadb": "^1.7.3"
"chromadb": "1.8.1"
},
"peerDependencies": {
"genkit": "workspace:*"
Expand Down
55 changes: 51 additions & 4 deletions js/plugins/chroma/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import {
ChromaClient,
Collection,
CollectionMetadata,
Embeddings,
IEmbeddingFunction,
IncludeEnum,
Metadata,
Expand All @@ -42,8 +43,13 @@ export { IncludeEnum };
// Schemas for Chroma's `where` / `whereDocument` filters. At runtime they
// accept any value (z.any()); the annotations only give them the static types
// `Where` and `WhereDocument`.
const WhereSchema: z.ZodType<Where> = z.any();
const WhereDocumentSchema: z.ZodType<WhereDocument> = z.any();

// Optional list of extra result fields to request, validated against the four
// plain string literals below rather than against a native enum.
const IncludeOptionSchema = z
  .enum(['documents', 'embeddings', 'metadatas', 'distances'])
  .array()
  .optional();
// Static counterpart of the schema above.
type IncludeOption = z.infer<typeof IncludeOptionSchema>;

// Retriever options: the common retriever options extended with the
// Chroma-specific `include`, `where` and `whereDocument` fields, each optional.
const ChromaRetrieverOptionsSchema = CommonRetrieverOptionsSchema.extend({
// NOTE(review): two `include` entries appear here. This looks like a rendered
// diff in which the nativeEnum line is the removed version and the
// IncludeOptionSchema line is its replacement — confirm against the real file.
include: z.array(z.nativeEnum(IncludeEnum)).optional(),
include: IncludeOptionSchema,
where: WhereSchema.optional(),
whereDocument: WhereDocumentSchema.optional(),
});
Expand Down Expand Up @@ -142,21 +148,23 @@ export function chromaRetriever<EmbedderCustomOptions extends z.ZodTypeAny>(
});
const results = await collection.query({
nResults: options?.k,
include: options?.include,
include: getIncludes(options?.include),
where: options?.where,
whereDocument: options?.whereDocument,
queryEmbeddings: embedding,
});

const documents = results.documents[0];
const metadatas = results.metadatas[0];
const metadatas = results.metadatas;
const embeddings = results.embeddings;
const distances = results.distances;

const combined = documents
.map((d, i) => {
if (d !== null) {
return {
document: d,
metadata: metadatas[i] ?? undefined,
metadata: constructMetadata(i, metadatas, embeddings, distances),
};
}
return undefined;
Expand All @@ -174,6 +182,45 @@ export function chromaRetriever<EmbedderCustomOptions extends z.ZodTypeAny>(
);
}

/**
 * Resolves the `include` list that is handed to the collection query.
 *
 * When the caller supplies no list, `undefined` is returned so that nothing is
 * forced onto the query. Otherwise the result always starts with
 * `IncludeEnum.Documents`, followed by the requested entries in their original
 * order with duplicates dropped.
 *
 * @param includes The `include` retriever option, if any.
 * @returns The de-duplicated include list, or `undefined` when none was given.
 */
function getIncludes(includes: IncludeOption): IncludeEnum[] | undefined {
  // Nothing requested: keep the default behaviour.
  if (!includes) return undefined;

  // A Set keeps first-insertion order, so documents stays first and any
  // repeated entry (including an explicit 'documents') collapses to one.
  const unique = new Set<IncludeEnum>([
    IncludeEnum.Documents,
    ...(includes as IncludeEnum[]),
  ]);
  return [...unique];
}

/**
 * Builds the metadata object attached to the `i`-th retrieved document,
 * folding in whichever optional {@link IncludeEnum} fields the query returned.
 *
 * The arrays are the un-indexed query results (`results.metadatas`,
 * `results.embeddings`, `results.distances`), i.e. one row per query, while
 * `i` is a position within the first query's documents
 * (`results.documents[0]`). The values belonging to document `i` therefore
 * live at `[0][i]`. Indexing with `[i]` alone would select a whole query row
 * (every document's values for `i === 0`, nothing for later documents).
 *
 * @param i Position of the document within the first query's results.
 * @param metadatas Stored metadata, one row per query.
 * @param embeddings Embeddings, one row per query, or null when absent.
 * @param distances Distances, one row per query, or null when absent.
 * @returns An object whose `metadata`, `embedding` and `distances` keys are
 *   each present only when a value exists for document `i`. The key names are
 *   kept as-is for compatibility; `distances` holds this document's distance.
 */
function constructMetadata(
  i: number,
  metadatas: (Metadata | null)[][],
  embeddings: Embeddings[] | null,
  distances: number[][] | null
): any {
  const fullMetadata: Record<string, any> = {};

  const metadata = metadatas?.[0]?.[i];
  if (metadata) {
    fullMetadata.metadata = metadata;
  }

  const embedding = embeddings?.[0]?.[i];
  if (embedding != null) {
    fullMetadata.embedding = embedding;
  }

  // Nullish check rather than truthiness: a distance of 0 is a valid value.
  const distance = distances?.[0]?.[i];
  if (distance != null) {
    fullMetadata.distances = distance;
  }

  return fullMetadata;
}

/**
* Configures a Chroma indexer.
*/
Expand Down
45 changes: 35 additions & 10 deletions js/pnpm-lock.yaml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit a1f60d2

Please sign in to comment.