diff --git a/Cargo.toml b/Cargo.toml
index 768cd3e..df72e78 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -24,7 +24,7 @@ anyhow = "1.0.98"
wasm-bindgen = { version = "0.2.100", optional = true }
serde-wasm-bindgen = { version = "0.6.5", optional = true }
console_error_panic_hook = { version = "0.1.7", optional = true }
-getrandom = { version = "0.3", features = ["wasm_js"], optional = true }
+js-sys = { version = "0.3", optional = true }
hf-hub = { version = "0.4.3", optional = true, default-features = false, features = [
"ureq",
"rustls-tls",
@@ -35,6 +35,9 @@ pyo3 = { version = "0.25.1", optional = true, features = ["extension-module"] }
ndarray = { version = "0.16.1", optional = true }
numpy = { version = "0.25.0", optional = true }
+[target.'cfg(target_family = "wasm")'.dependencies]
+getrandom = { version = "0.2", features = ["js"] }
+
[features]
default = ["tokenizers/onig", "hf-hub"]
@@ -42,7 +45,7 @@ wasm = [
"dep:wasm-bindgen",
"dep:serde-wasm-bindgen",
"dep:console_error_panic_hook",
- "dep:getrandom",
+ "dep:js-sys",
"tokenizers/unstable_wasm",
]
diff --git a/docs/index.html b/docs/index.html
index d3599e8..b6c271a 100644
--- a/docs/index.html
+++ b/docs/index.html
@@ -892,6 +892,7 @@
Pooled Embeddings
"lightonai/GTE-ModernColBERT-v1",
"lightonai/colbertv2.0",
"lightonai/Reason-ModernColBERT",
+ "mixedbread-ai/mxbai-edge-colbert-v0-17m",
],
REQUIRED_FILES: [
'tokenizer.json', 'model.safetensors', 'config.json',
@@ -1006,6 +1007,28 @@ Pooled Embeddings
pooledScore: document.getElementById('pooled-score'),
},
+ /**
+ * Retry helper with exponential backoff
+ */
+ async retryWithBackoff(fn, maxRetries = 3, baseDelay = 100) {
+ for (let i = 0; i < maxRetries; i++) {
+ try {
+ return await fn();
+ } catch (error) {
+ const isLastAttempt = i === maxRetries - 1;
+ const isTableGrowError = error.message && error.message.includes('Table.grow');
+
+ if (isLastAttempt || !isTableGrowError) {
+ throw error;
+ }
+
+ const delay = baseDelay * Math.pow(2, i);
+ console.warn(`WASM init failed (attempt ${i + 1}/${maxRetries}), retrying in ${delay}ms...`, error.message);
+ await new Promise(resolve => setTimeout(resolve, delay));
+ }
+ }
+ },
+
/**
* Initializes the application, sets up event listeners, and loads the initial model.
*/
@@ -1016,11 +1039,14 @@ Pooled Embeddings
hljs.highlightAll();
try {
- await init();
- this.loadModel(this.state.currentModelName);
+ await this.retryWithBackoff(() => init(), 5, 200);
+ console.log("✅ WASM module initialized successfully");
+ // Add a small delay to ensure WASM is fully ready
+ await new Promise(resolve => setTimeout(resolve, 100));
+ await this.loadModel(this.state.currentModelName);
} catch (e) {
- console.error("Fatal: Failed to initialize WASM module.", e);
- this.updateAllStatuses(`Error: Could not initialize the application.`, 'error');
+ console.error("Fatal: Failed to initialize WASM module after retries.", e);
+ this.updateAllStatuses(`Error: Could not initialize the application. Please refresh the page.`, 'error');
}
},
@@ -1149,9 +1175,64 @@ Pooled Embeddings
console.warn(`Unable to download the model.`, e);
}
this.updateAllStatuses(`Initializing ${modelRepo}...`, 'loading');
+
+ // Try to fetch all Dense layers (1_Dense, 2_Dense, 3_Dense, ...) until we don't find one
+ const projectionLayers = [];
+ let i = 1;
+ while (true) {
+ try {
+ const layerName = `${i}_Dense`;
+ const weightsResponse = await fetch(`https://huggingface.co/${modelRepo}/resolve/main/${layerName}/model.safetensors`);
+ const configResponse = await fetch(`https://huggingface.co/${modelRepo}/resolve/main/${layerName}/config.json`);
+
+ if (weightsResponse.ok && configResponse.ok) {
+ const weights = new Uint8Array(await weightsResponse.arrayBuffer());
+ const config = new Uint8Array(await configResponse.arrayBuffer());
+ projectionLayers.push({
+ config: config,
+ weights: weights,
+ name: layerName
+ });
+ console.log(`✅ Found ${layerName} layer`);
+ i++;
+ } else if (i === 1) {
+ throw new Error("1_Dense layer is required but not found");
+ } else {
+ // Stop at first missing layer
+ break;
+ }
+ } catch (e) {
+ if (i === 1) {
+ throw new Error("1_Dense layer is required: " + e.message);
+ }
+ break;
+ }
+ }
+
const [tokenizer, model, config, stConfig, dense, denseConfig, tokensConfig] = modelFiles;
- this.state.colbertModel = new ColBERT(model, dense, tokenizer, config, stConfig, denseConfig, tokensConfig, 32);
+ console.log("Creating ColBERT instance with:", {
+ hasModel: !!model,
+ projectionLayerCount: projectionLayers.length,
+ hasTokenizer: !!tokenizer,
+ hasConfig: !!config,
+ hasStConfig: !!stConfig,
+ hasTokensConfig: !!tokensConfig
+ });
+
+ // New constructor signature: (weights, projection_layers, tokenizer, config,
+ // sentence_transformers_config, special_tokens_map, batch_size)
+ this.state.colbertModel = new ColBERT(
+ model,
+ projectionLayers, // Array of {config, weights, name} objects
+ tokenizer,
+ config,
+ stConfig,
+ tokensConfig,
+ 32
+ );
+
+ console.log("✅ ColBERT instance created successfully");
this.updateAllStatuses(`✅ ${modelRepo} loaded successfully.`, 'success');
this.runAllDemos();
} catch (error) {
@@ -1167,15 +1248,34 @@ Pooled Embeddings
* Runs all interactive demos after a model is loaded.
*/
runAllDemos() {
- this.handleSimilarityDemo();
- this.handleMatrixDemo();
- this.handlePoolingDemo();
+ if (!this.state.colbertModel) {
+ console.warn("Model not loaded yet, skipping demos");
+ return;
+ }
+
+ try {
+ this.handleSimilarityDemo();
+ } catch (e) {
+ console.error("Similarity demo failed:", e);
+ }
+
+ try {
+ this.handleMatrixDemo();
+ } catch (e) {
+ console.error("Matrix demo failed:", e);
+ }
+
+ try {
+ this.handlePoolingDemo();
+ } catch (e) {
+ console.error("Pooling demo failed:", e);
+ }
},
/**
* SIMILARITY DEMO: Calculates and displays similarity scores.
*/
- handleSimilarityDemo() {
+ async handleSimilarityDemo() {
if (!this.state.colbertModel) return;
const query = this.ui.queryInput.value;
const documents = this.ui.documentInput.value.split('\n').filter(doc => doc.trim());
@@ -1187,13 +1287,34 @@ Pooled Embeddings
}
try {
- const {
- data
- } = this.state.colbertModel.similarity({
+ console.log("Running similarity with:", { query, documentsCount: documents.length });
+
+ // Call the similarity method - now returns JSON string
+ let resultString = this.state.colbertModel.similarity({
queries: [query],
documents
});
- const scores = data[0];
+
+ console.log("Similarity result type:", typeof resultString);
+ console.log("Similarity result string:", resultString);
+
+ // Parse JSON string to object
+ let result;
+ if (typeof resultString === 'string') {
+ result = JSON.parse(resultString);
+ } else {
+ result = resultString; // Fallback if it's already an object
+ }
+
+ console.log("Parsed result:", result);
+
+ if (!result || !result.data || !result.data[0]) {
+ console.error("Invalid similarity result:", result);
+ this.ui.resultsList.innerHTML = `Error: Model returned invalid results. Check console for details.
`;
+ return;
+ }
+
+ const scores = result.data[0];
const results = documents.map((doc, i) => ({
text: doc,
score: scores[i]
@@ -1238,10 +1359,11 @@ Pooled Embeddings
}
try {
- const rawResult = this.state.colbertModel.raw_similarity_matrix({
+ const rawResultString = this.state.colbertModel.raw_similarity_matrix({
queries: [query],
documents: [doc]
});
+ const rawResult = JSON.parse(rawResultString);
const {
matrix,
queryTokens,
@@ -1426,23 +1548,26 @@ Pooled Embeddings
}
try {
- const originalDocResult = this.state.colbertModel.encode({
+ const originalDocResultString = this.state.colbertModel.encode({
sentences: [docText]
}, false);
+ const originalDocResult = JSON.parse(originalDocResultString);
const originalDocEmbeddings = originalDocResult.embeddings[0] || [];
this.renderEmbeddingBlocks(this.ui.originalEmbeddingsVis, originalDocEmbeddings.length, this.ui.originalTokenCount);
- const pooledResult = hierarchical_pooling({
+ const pooledResultString = hierarchical_pooling({
embeddings: [originalDocEmbeddings],
pool_factor: poolFactor
});
+ const pooledResult = JSON.parse(pooledResultString);
const pooledDocEmbeddings = pooledResult.embeddings[0] || [];
this.renderEmbeddingBlocks(this.ui.pooledEmbeddingsVis, pooledDocEmbeddings.length, this.ui.pooledTokenCount);
if (queryText.trim()) {
- const queryResult = this.state.colbertModel.encode({
+ const queryResultString = this.state.colbertModel.encode({
sentences: [queryText]
}, false);
+ const queryResult = JSON.parse(queryResultString);
const queryEmbeddings = queryResult.embeddings[0] || [];
const calculateScore = (qVecs, dVecs) => {
diff --git a/docs/pkg/pylate_rs.d.ts b/docs/pkg/pylate_rs.d.ts
index 23265cd..3c4c6de 100644
--- a/docs/pkg/pylate_rs.d.ts
+++ b/docs/pkg/pylate_rs.d.ts
@@ -3,7 +3,7 @@
/**
* WASM-compatible version of the `hierarchical_pooling` function.
*/
-export function hierarchical_pooling(input: any): any;
+export function hierarchical_pooling(input: any): string;
/**
* The main ColBERT model structure.
*
@@ -13,33 +13,37 @@ export function hierarchical_pooling(input: any): any;
*/
export class ColBERT {
free(): void;
+ [Symbol.dispose](): void;
/**
* WASM-compatible constructor.
+ *
+ * # Arguments
+ * * `projection_layers` - JsValue containing array of {config: Uint8Array, weights: Uint8Array, name: string}
*/
- constructor(weights: Uint8Array, dense_weights: Uint8Array, tokenizer: Uint8Array, config: Uint8Array, sentence_transformers_config: Uint8Array, dense_config: Uint8Array, special_tokens_map: Uint8Array, batch_size?: number | null);
+ constructor(weights: Uint8Array, projection_layers: any, tokenizer: Uint8Array, config: Uint8Array, sentence_transformers_config: Uint8Array, special_tokens_map: Uint8Array, batch_size?: number | null);
/**
* WASM-compatible version of the `encode` method.
*/
- encode(input: any, is_query: boolean): any;
+ encode(input: any, is_query: boolean): string;
/**
* WASM-compatible version of the `similarity` method.
*/
- similarity(input: any): any;
+ similarity(input: any): string;
/**
* WASM-compatible method to get the raw similarity matrix and tokens.
*/
- raw_similarity_matrix(input: any): any;
+ raw_similarity_matrix(input: any): string;
}
export type InitInput = RequestInfo | URL | Response | BufferSource | WebAssembly.Module;
export interface InitOutput {
readonly memory: WebAssembly.Memory;
- readonly colbert_from_bytes: (a: number, b: number, c: number, d: number, e: number, f: number, g: number, h: number, i: number, j: number, k: number, l: number, m: number, n: number, o: number) => [number, number, number];
- readonly colbert_encode: (a: number, b: any, c: number) => [number, number, number];
- readonly colbert_similarity: (a: number, b: any) => [number, number, number];
- readonly colbert_raw_similarity_matrix: (a: number, b: any) => [number, number, number];
- readonly hierarchical_pooling: (a: any) => [number, number, number];
+ readonly colbert_from_bytes: (a: number, b: number, c: any, d: number, e: number, f: number, g: number, h: number, i: number, j: number, k: number, l: number) => [number, number, number];
+ readonly colbert_encode: (a: number, b: any, c: number) => [number, number, number, number];
+ readonly colbert_similarity: (a: number, b: any) => [number, number, number, number];
+ readonly colbert_raw_similarity_matrix: (a: number, b: any) => [number, number, number, number];
+ readonly hierarchical_pooling: (a: any) => [number, number, number, number];
readonly __wbg_colbert_free: (a: number, b: number) => void;
readonly __wbindgen_malloc: (a: number, b: number) => number;
readonly __wbindgen_realloc: (a: number, b: number, c: number, d: number) => number;
diff --git a/docs/pkg/pylate_rs.js b/docs/pkg/pylate_rs.js
index ca4b9cf..33776c8 100644
--- a/docs/pkg/pylate_rs.js
+++ b/docs/pkg/pylate_rs.js
@@ -9,16 +9,16 @@ function getUint8ArrayMemory0() {
return cachedUint8ArrayMemory0;
}
-let cachedTextDecoder = (typeof TextDecoder !== 'undefined' ? new TextDecoder('utf-8', { ignoreBOM: true, fatal: true }) : { decode: () => { throw Error('TextDecoder not available') } } );
+let cachedTextDecoder = new TextDecoder('utf-8', { ignoreBOM: true, fatal: true });
-if (typeof TextDecoder !== 'undefined') { cachedTextDecoder.decode(); };
+cachedTextDecoder.decode();
const MAX_SAFARI_DECODE_BYTES = 2146435072;
let numBytesDecoded = 0;
function decodeText(ptr, len) {
numBytesDecoded += len;
if (numBytesDecoded >= MAX_SAFARI_DECODE_BYTES) {
- cachedTextDecoder = (typeof TextDecoder !== 'undefined' ? new TextDecoder('utf-8', { ignoreBOM: true, fatal: true }) : { decode: () => { throw Error('TextDecoder not available') } } );
+ cachedTextDecoder = new TextDecoder('utf-8', { ignoreBOM: true, fatal: true });
cachedTextDecoder.decode();
numBytesDecoded = len;
}
@@ -32,20 +32,18 @@ function getStringFromWasm0(ptr, len) {
let WASM_VECTOR_LEN = 0;
-const cachedTextEncoder = (typeof TextEncoder !== 'undefined' ? new TextEncoder('utf-8') : { encode: () => { throw Error('TextEncoder not available') } } );
+const cachedTextEncoder = new TextEncoder();
-const encodeString = (typeof cachedTextEncoder.encodeInto === 'function'
- ? function (arg, view) {
- return cachedTextEncoder.encodeInto(arg, view);
+if (!('encodeInto' in cachedTextEncoder)) {
+ cachedTextEncoder.encodeInto = function (arg, view) {
+ const buf = cachedTextEncoder.encode(arg);
+ view.set(buf);
+ return {
+ read: arg.length,
+ written: buf.length
+ };
+ }
}
- : function (arg, view) {
- const buf = cachedTextEncoder.encode(arg);
- view.set(buf);
- return {
- read: arg.length,
- written: buf.length
- };
-});
function passStringToWasm0(arg, malloc, realloc) {
@@ -76,7 +74,7 @@ function passStringToWasm0(arg, malloc, realloc) {
}
ptr = realloc(ptr, len, len = offset + arg.length * 3, 1) >>> 0;
const view = getUint8ArrayMemory0().subarray(ptr + offset, ptr + len);
- const ret = encodeString(arg, view);
+ const ret = cachedTextEncoder.encodeInto(arg, view);
offset += ret.written;
ptr = realloc(ptr, len, offset, 1) >>> 0;
@@ -199,14 +197,25 @@ function takeFromExternrefTable0(idx) {
/**
* WASM-compatible version of the `hierarchical_pooling` function.
* @param {any} input
- * @returns {any}
+ * @returns {string}
*/
export function hierarchical_pooling(input) {
- const ret = wasm.hierarchical_pooling(input);
- if (ret[2]) {
- throw takeFromExternrefTable0(ret[1]);
+ let deferred2_0;
+ let deferred2_1;
+ try {
+ const ret = wasm.hierarchical_pooling(input);
+ var ptr1 = ret[0];
+ var len1 = ret[1];
+ if (ret[3]) {
+ ptr1 = 0; len1 = 0;
+ throw takeFromExternrefTable0(ret[2]);
+ }
+ deferred2_0 = ptr1;
+ deferred2_1 = len1;
+ return getStringFromWasm0(ptr1, len1);
+ } finally {
+ wasm.__wbindgen_free(deferred2_0, deferred2_1, 1);
}
- return takeFromExternrefTable0(ret[0]);
}
const ColBERTFinalization = (typeof FinalizationRegistry === 'undefined')
@@ -234,31 +243,29 @@ export class ColBERT {
}
/**
* WASM-compatible constructor.
+ *
+ * # Arguments
+ * * `projection_layers` - JsValue containing array of {config: Uint8Array, weights: Uint8Array, name: string}
* @param {Uint8Array} weights
- * @param {Uint8Array} dense_weights
+ * @param {any} projection_layers
* @param {Uint8Array} tokenizer
* @param {Uint8Array} config
* @param {Uint8Array} sentence_transformers_config
- * @param {Uint8Array} dense_config
* @param {Uint8Array} special_tokens_map
* @param {number | null} [batch_size]
*/
- constructor(weights, dense_weights, tokenizer, config, sentence_transformers_config, dense_config, special_tokens_map, batch_size) {
+ constructor(weights, projection_layers, tokenizer, config, sentence_transformers_config, special_tokens_map, batch_size) {
const ptr0 = passArray8ToWasm0(weights, wasm.__wbindgen_malloc);
const len0 = WASM_VECTOR_LEN;
- const ptr1 = passArray8ToWasm0(dense_weights, wasm.__wbindgen_malloc);
+ const ptr1 = passArray8ToWasm0(tokenizer, wasm.__wbindgen_malloc);
const len1 = WASM_VECTOR_LEN;
- const ptr2 = passArray8ToWasm0(tokenizer, wasm.__wbindgen_malloc);
+ const ptr2 = passArray8ToWasm0(config, wasm.__wbindgen_malloc);
const len2 = WASM_VECTOR_LEN;
- const ptr3 = passArray8ToWasm0(config, wasm.__wbindgen_malloc);
+ const ptr3 = passArray8ToWasm0(sentence_transformers_config, wasm.__wbindgen_malloc);
const len3 = WASM_VECTOR_LEN;
- const ptr4 = passArray8ToWasm0(sentence_transformers_config, wasm.__wbindgen_malloc);
+ const ptr4 = passArray8ToWasm0(special_tokens_map, wasm.__wbindgen_malloc);
const len4 = WASM_VECTOR_LEN;
- const ptr5 = passArray8ToWasm0(dense_config, wasm.__wbindgen_malloc);
- const len5 = WASM_VECTOR_LEN;
- const ptr6 = passArray8ToWasm0(special_tokens_map, wasm.__wbindgen_malloc);
- const len6 = WASM_VECTOR_LEN;
- const ret = wasm.colbert_from_bytes(ptr0, len0, ptr1, len1, ptr2, len2, ptr3, len3, ptr4, len4, ptr5, len5, ptr6, len6, isLikeNone(batch_size) ? 0x100000001 : (batch_size) >>> 0);
+ const ret = wasm.colbert_from_bytes(ptr0, len0, projection_layers, ptr1, len1, ptr2, len2, ptr3, len3, ptr4, len4, isLikeNone(batch_size) ? 0x100000001 : (batch_size) >>> 0);
if (ret[2]) {
throw takeFromExternrefTable0(ret[1]);
}
@@ -270,40 +277,74 @@ export class ColBERT {
* WASM-compatible version of the `encode` method.
* @param {any} input
* @param {boolean} is_query
- * @returns {any}
+ * @returns {string}
*/
encode(input, is_query) {
- const ret = wasm.colbert_encode(this.__wbg_ptr, input, is_query);
- if (ret[2]) {
- throw takeFromExternrefTable0(ret[1]);
+ let deferred2_0;
+ let deferred2_1;
+ try {
+ const ret = wasm.colbert_encode(this.__wbg_ptr, input, is_query);
+ var ptr1 = ret[0];
+ var len1 = ret[1];
+ if (ret[3]) {
+ ptr1 = 0; len1 = 0;
+ throw takeFromExternrefTable0(ret[2]);
+ }
+ deferred2_0 = ptr1;
+ deferred2_1 = len1;
+ return getStringFromWasm0(ptr1, len1);
+ } finally {
+ wasm.__wbindgen_free(deferred2_0, deferred2_1, 1);
}
- return takeFromExternrefTable0(ret[0]);
}
/**
* WASM-compatible version of the `similarity` method.
* @param {any} input
- * @returns {any}
+ * @returns {string}
*/
similarity(input) {
- const ret = wasm.colbert_similarity(this.__wbg_ptr, input);
- if (ret[2]) {
- throw takeFromExternrefTable0(ret[1]);
+ let deferred2_0;
+ let deferred2_1;
+ try {
+ const ret = wasm.colbert_similarity(this.__wbg_ptr, input);
+ var ptr1 = ret[0];
+ var len1 = ret[1];
+ if (ret[3]) {
+ ptr1 = 0; len1 = 0;
+ throw takeFromExternrefTable0(ret[2]);
+ }
+ deferred2_0 = ptr1;
+ deferred2_1 = len1;
+ return getStringFromWasm0(ptr1, len1);
+ } finally {
+ wasm.__wbindgen_free(deferred2_0, deferred2_1, 1);
}
- return takeFromExternrefTable0(ret[0]);
}
/**
* WASM-compatible method to get the raw similarity matrix and tokens.
* @param {any} input
- * @returns {any}
+ * @returns {string}
*/
raw_similarity_matrix(input) {
- const ret = wasm.colbert_raw_similarity_matrix(this.__wbg_ptr, input);
- if (ret[2]) {
- throw takeFromExternrefTable0(ret[1]);
+ let deferred2_0;
+ let deferred2_1;
+ try {
+ const ret = wasm.colbert_raw_similarity_matrix(this.__wbg_ptr, input);
+ var ptr1 = ret[0];
+ var len1 = ret[1];
+ if (ret[3]) {
+ ptr1 = 0; len1 = 0;
+ throw takeFromExternrefTable0(ret[2]);
+ }
+ deferred2_0 = ptr1;
+ deferred2_1 = len1;
+ return getStringFromWasm0(ptr1, len1);
+ } finally {
+ wasm.__wbindgen_free(deferred2_0, deferred2_1, 1);
}
- return takeFromExternrefTable0(ret[0]);
}
}
+if (Symbol.dispose) ColBERT.prototype[Symbol.dispose] = ColBERT.prototype.free;
const EXPECTED_RESPONSE_TYPES = new Set(['basic', 'cors', 'default']);
@@ -343,10 +384,14 @@ async function __wbg_load(module, imports) {
function __wbg_get_imports() {
const imports = {};
imports.wbg = {};
- imports.wbg.__wbg_Error_0497d5bdba9362e5 = function(arg0, arg1) {
+ imports.wbg.__wbg_Error_e17e777aac105295 = function(arg0, arg1) {
const ret = Error(getStringFromWasm0(arg0, arg1));
return ret;
};
+ imports.wbg.__wbg_Number_998bea33bd87c3e0 = function(arg0) {
+ const ret = Number(arg0);
+ return ret;
+ };
imports.wbg.__wbg_String_8f0eb39a4a4c2f66 = function(arg0, arg1) {
const ret = String(arg1);
const ptr1 = passStringToWasm0(ret, wasm.__wbindgen_malloc, wasm.__wbindgen_realloc);
@@ -354,15 +399,11 @@ function __wbg_get_imports() {
getDataViewMemory0().setInt32(arg0 + 4 * 1, len1, true);
getDataViewMemory0().setInt32(arg0 + 4 * 0, ptr1, true);
};
- imports.wbg.__wbg_buffer_a1a27a0dfa70165d = function(arg0) {
- const ret = arg0.buffer;
- return ret;
- };
- imports.wbg.__wbg_call_fbe8be8bf6436ce5 = function() { return handleError(function (arg0, arg1) {
+ imports.wbg.__wbg_call_13410aac570ffff7 = function() { return handleError(function (arg0, arg1) {
const ret = arg0.call(arg1);
return ret;
}, arguments) };
- imports.wbg.__wbg_done_4d01f352bade43b7 = function(arg0) {
+ imports.wbg.__wbg_done_75ed0ee6dd243d9d = function(arg0) {
const ret = arg0.done;
return ret;
};
@@ -377,22 +418,26 @@ function __wbg_get_imports() {
wasm.__wbindgen_free(deferred0_0, deferred0_1, 1);
}
};
- imports.wbg.__wbg_getRandomValues_3c9c0d586e575a16 = function() { return handleError(function (arg0, arg1) {
- globalThis.crypto.getRandomValues(getArrayU8FromWasm0(arg0, arg1));
- }, arguments) };
- imports.wbg.__wbg_get_92470be87867c2e5 = function() { return handleError(function (arg0, arg1) {
- const ret = Reflect.get(arg0, arg1);
+ imports.wbg.__wbg_from_88bc52ce20ba6318 = function(arg0) {
+ const ret = Array.from(arg0);
return ret;
+ };
+ imports.wbg.__wbg_getRandomValues_1c61fac11405ffdc = function() { return handleError(function (arg0, arg1) {
+ globalThis.crypto.getRandomValues(getArrayU8FromWasm0(arg0, arg1));
}, arguments) };
- imports.wbg.__wbg_get_a131a44bd1eb6979 = function(arg0, arg1) {
+ imports.wbg.__wbg_get_0da715ceaecea5c8 = function(arg0, arg1) {
const ret = arg0[arg1 >>> 0];
return ret;
};
+ imports.wbg.__wbg_get_458e874b43b18b25 = function() { return handleError(function (arg0, arg1) {
+ const ret = Reflect.get(arg0, arg1);
+ return ret;
+ }, arguments) };
imports.wbg.__wbg_getwithrefkey_1dc361bd10053bfe = function(arg0, arg1) {
const ret = arg0[arg1];
return ret;
};
- imports.wbg.__wbg_instanceof_ArrayBuffer_a8b6f580b363f2bc = function(arg0) {
+ imports.wbg.__wbg_instanceof_ArrayBuffer_67f3012529f6a2dd = function(arg0) {
let result;
try {
result = arg0 instanceof ArrayBuffer;
@@ -402,7 +447,7 @@ function __wbg_get_imports() {
const ret = result;
return ret;
};
- imports.wbg.__wbg_instanceof_Uint8Array_ca460677bc155827 = function(arg0) {
+ imports.wbg.__wbg_instanceof_Uint8Array_9a8378d955933db7 = function(arg0) {
let result;
try {
result = arg0 instanceof Uint8Array;
@@ -412,58 +457,44 @@ function __wbg_get_imports() {
const ret = result;
return ret;
};
- imports.wbg.__wbg_isArray_5f090bed72bd4f89 = function(arg0) {
+ imports.wbg.__wbg_isArray_030cce220591fb41 = function(arg0) {
const ret = Array.isArray(arg0);
return ret;
};
- imports.wbg.__wbg_isSafeInteger_90d7c4674047d684 = function(arg0) {
+ imports.wbg.__wbg_isSafeInteger_1c0d1af5542e102a = function(arg0) {
const ret = Number.isSafeInteger(arg0);
return ret;
};
- imports.wbg.__wbg_iterator_4068add5b2aef7a6 = function() {
+ imports.wbg.__wbg_iterator_f370b34483c71a1c = function() {
const ret = Symbol.iterator;
return ret;
};
- imports.wbg.__wbg_length_ab6d22b5ead75c72 = function(arg0) {
+ imports.wbg.__wbg_length_186546c51cd61acd = function(arg0) {
const ret = arg0.length;
return ret;
};
- imports.wbg.__wbg_length_f00ec12454a5d9fd = function(arg0) {
+ imports.wbg.__wbg_length_6bb7e81f9d7713e4 = function(arg0) {
const ret = arg0.length;
return ret;
};
- imports.wbg.__wbg_new_07b483f72211fd66 = function() {
- const ret = new Object();
- return ret;
- };
- imports.wbg.__wbg_new_58353953ad2097cc = function() {
- const ret = new Array();
+ imports.wbg.__wbg_new_638ebfaedbf32a5e = function(arg0) {
+ const ret = new Uint8Array(arg0);
return ret;
};
imports.wbg.__wbg_new_8a6f238a6ece86ea = function() {
const ret = new Error();
return ret;
};
- imports.wbg.__wbg_new_e52b3efaaa774f96 = function(arg0) {
- const ret = new Uint8Array(arg0);
- return ret;
- };
- imports.wbg.__wbg_next_8bb824d217961b5d = function(arg0) {
+ imports.wbg.__wbg_next_5b3530e612fde77d = function(arg0) {
const ret = arg0.next;
return ret;
};
- imports.wbg.__wbg_next_e2da48d8fff7439a = function() { return handleError(function (arg0) {
+ imports.wbg.__wbg_next_692e82279131b03c = function() { return handleError(function (arg0) {
const ret = arg0.next();
return ret;
}, arguments) };
- imports.wbg.__wbg_set_3f1d0b984ed272ed = function(arg0, arg1, arg2) {
- arg0[arg1] = arg2;
- };
- imports.wbg.__wbg_set_7422acbe992d64ab = function(arg0, arg1, arg2) {
- arg0[arg1 >>> 0] = arg2;
- };
- imports.wbg.__wbg_set_fe4e79d1ed3b0e9b = function(arg0, arg1, arg2) {
- arg0.set(arg1, arg2 >>> 0);
+ imports.wbg.__wbg_prototypesetcall_3d4a26c1ed734349 = function(arg0, arg1, arg2) {
+ Uint8Array.prototype.set.call(getArrayU8FromWasm0(arg0, arg1), arg2);
};
imports.wbg.__wbg_stack_0ed75d68575b0f3c = function(arg0, arg1) {
const ret = arg1.stack;
@@ -472,90 +503,64 @@ function __wbg_get_imports() {
getDataViewMemory0().setInt32(arg0 + 4 * 1, len1, true);
getDataViewMemory0().setInt32(arg0 + 4 * 0, ptr1, true);
};
- imports.wbg.__wbg_value_17b896954e14f896 = function(arg0) {
+ imports.wbg.__wbg_value_dd9372230531eade = function(arg0) {
const ret = arg0.value;
return ret;
};
- imports.wbg.__wbindgen_as_number = function(arg0) {
- const ret = +arg0;
- return ret;
- };
- imports.wbg.__wbindgen_bigint_from_u64 = function(arg0) {
- const ret = BigInt.asUintN(64, arg0);
- return ret;
- };
- imports.wbg.__wbindgen_bigint_get_as_i64 = function(arg0, arg1) {
+ imports.wbg.__wbg_wbindgenbigintgetasi64_ac743ece6ab9bba1 = function(arg0, arg1) {
const v = arg1;
const ret = typeof(v) === 'bigint' ? v : undefined;
getDataViewMemory0().setBigInt64(arg0 + 8 * 1, isLikeNone(ret) ? BigInt(0) : ret, true);
getDataViewMemory0().setInt32(arg0 + 4 * 0, !isLikeNone(ret), true);
};
- imports.wbg.__wbindgen_boolean_get = function(arg0) {
+ imports.wbg.__wbg_wbindgenbooleanget_3fe6f642c7d97746 = function(arg0) {
const v = arg0;
- const ret = typeof(v) === 'boolean' ? (v ? 1 : 0) : 2;
- return ret;
+ const ret = typeof(v) === 'boolean' ? v : undefined;
+ return isLikeNone(ret) ? 0xFFFFFF : ret ? 1 : 0;
};
- imports.wbg.__wbindgen_debug_string = function(arg0, arg1) {
+ imports.wbg.__wbg_wbindgendebugstring_99ef257a3ddda34d = function(arg0, arg1) {
const ret = debugString(arg1);
const ptr1 = passStringToWasm0(ret, wasm.__wbindgen_malloc, wasm.__wbindgen_realloc);
const len1 = WASM_VECTOR_LEN;
getDataViewMemory0().setInt32(arg0 + 4 * 1, len1, true);
getDataViewMemory0().setInt32(arg0 + 4 * 0, ptr1, true);
};
- imports.wbg.__wbindgen_in = function(arg0, arg1) {
+ imports.wbg.__wbg_wbindgenin_d7a1ee10933d2d55 = function(arg0, arg1) {
const ret = arg0 in arg1;
return ret;
};
- imports.wbg.__wbindgen_init_externref_table = function() {
- const table = wasm.__wbindgen_export_4;
- const offset = table.grow(4);
- table.set(0, undefined);
- table.set(offset + 0, undefined);
- table.set(offset + 1, null);
- table.set(offset + 2, true);
- table.set(offset + 3, false);
- ;
- };
- imports.wbg.__wbindgen_is_bigint = function(arg0) {
+ imports.wbg.__wbg_wbindgenisbigint_ecb90cc08a5a9154 = function(arg0) {
const ret = typeof(arg0) === 'bigint';
return ret;
};
- imports.wbg.__wbindgen_is_function = function(arg0) {
+ imports.wbg.__wbg_wbindgenisfunction_8cee7dce3725ae74 = function(arg0) {
const ret = typeof(arg0) === 'function';
return ret;
};
- imports.wbg.__wbindgen_is_object = function(arg0) {
+ imports.wbg.__wbg_wbindgenisobject_307a53c6bd97fbf8 = function(arg0) {
const val = arg0;
const ret = typeof(val) === 'object' && val !== null;
return ret;
};
- imports.wbg.__wbindgen_is_undefined = function(arg0) {
+ imports.wbg.__wbg_wbindgenisundefined_c4b71d073b92f3c5 = function(arg0) {
const ret = arg0 === undefined;
return ret;
};
- imports.wbg.__wbindgen_jsval_eq = function(arg0, arg1) {
+ imports.wbg.__wbg_wbindgenjsvaleq_e6f2ad59ccae1b58 = function(arg0, arg1) {
const ret = arg0 === arg1;
return ret;
};
- imports.wbg.__wbindgen_jsval_loose_eq = function(arg0, arg1) {
+ imports.wbg.__wbg_wbindgenjsvallooseeq_9bec8c9be826bed1 = function(arg0, arg1) {
const ret = arg0 == arg1;
return ret;
};
- imports.wbg.__wbindgen_memory = function() {
- const ret = wasm.memory;
- return ret;
- };
- imports.wbg.__wbindgen_number_get = function(arg0, arg1) {
+ imports.wbg.__wbg_wbindgennumberget_f74b4c7525ac05cb = function(arg0, arg1) {
const obj = arg1;
const ret = typeof(obj) === 'number' ? obj : undefined;
getDataViewMemory0().setFloat64(arg0 + 8 * 1, isLikeNone(ret) ? 0 : ret, true);
getDataViewMemory0().setInt32(arg0 + 4 * 0, !isLikeNone(ret), true);
};
- imports.wbg.__wbindgen_number_new = function(arg0) {
- const ret = arg0;
- return ret;
- };
- imports.wbg.__wbindgen_string_get = function(arg0, arg1) {
+ imports.wbg.__wbg_wbindgenstringget_0f16a6ddddef376f = function(arg0, arg1) {
const obj = arg1;
const ret = typeof(obj) === 'string' ? obj : undefined;
var ptr1 = isLikeNone(ret) ? 0 : passStringToWasm0(ret, wasm.__wbindgen_malloc, wasm.__wbindgen_realloc);
@@ -563,12 +568,28 @@ function __wbg_get_imports() {
getDataViewMemory0().setInt32(arg0 + 4 * 1, len1, true);
getDataViewMemory0().setInt32(arg0 + 4 * 0, ptr1, true);
};
- imports.wbg.__wbindgen_string_new = function(arg0, arg1) {
+ imports.wbg.__wbg_wbindgenthrow_451ec1a8469d7eb6 = function(arg0, arg1) {
+ throw new Error(getStringFromWasm0(arg0, arg1));
+ };
+ imports.wbg.__wbindgen_cast_2241b6af4c4b2941 = function(arg0, arg1) {
+ // Cast intrinsic for `Ref(String) -> Externref`.
const ret = getStringFromWasm0(arg0, arg1);
return ret;
};
- imports.wbg.__wbindgen_throw = function(arg0, arg1) {
- throw new Error(getStringFromWasm0(arg0, arg1));
+ imports.wbg.__wbindgen_cast_4625c577ab2ec9ee = function(arg0) {
+ // Cast intrinsic for `U64 -> Externref`.
+ const ret = BigInt.asUintN(64, arg0);
+ return ret;
+ };
+ imports.wbg.__wbindgen_init_externref_table = function() {
+ const table = wasm.__wbindgen_export_4;
+ const offset = table.grow(4);
+ table.set(0, undefined);
+ table.set(offset + 0, undefined);
+ table.set(offset + 1, null);
+ table.set(offset + 2, true);
+ table.set(offset + 3, false);
+ ;
};
return imports;
diff --git a/docs/pkg/pylate_rs_bg.wasm b/docs/pkg/pylate_rs_bg.wasm
index e256ffa..6112480 100644
Binary files a/docs/pkg/pylate_rs_bg.wasm and b/docs/pkg/pylate_rs_bg.wasm differ
diff --git a/docs/pkg/pylate_rs_bg.wasm.d.ts b/docs/pkg/pylate_rs_bg.wasm.d.ts
index 96243de..ded9f30 100644
--- a/docs/pkg/pylate_rs_bg.wasm.d.ts
+++ b/docs/pkg/pylate_rs_bg.wasm.d.ts
@@ -1,11 +1,11 @@
/* tslint:disable */
/* eslint-disable */
export const memory: WebAssembly.Memory;
-export const colbert_from_bytes: (a: number, b: number, c: number, d: number, e: number, f: number, g: number, h: number, i: number, j: number, k: number, l: number, m: number, n: number, o: number) => [number, number, number];
-export const colbert_encode: (a: number, b: any, c: number) => [number, number, number];
-export const colbert_similarity: (a: number, b: any) => [number, number, number];
-export const colbert_raw_similarity_matrix: (a: number, b: any) => [number, number, number];
-export const hierarchical_pooling: (a: any) => [number, number, number];
+export const colbert_from_bytes: (a: number, b: number, c: any, d: number, e: number, f: number, g: number, h: number, i: number, j: number, k: number, l: number) => [number, number, number];
+export const colbert_encode: (a: number, b: any, c: number) => [number, number, number, number];
+export const colbert_similarity: (a: number, b: any) => [number, number, number, number];
+export const colbert_raw_similarity_matrix: (a: number, b: any) => [number, number, number, number];
+export const hierarchical_pooling: (a: any) => [number, number, number, number];
export const __wbg_colbert_free: (a: number, b: number) => void;
export const __wbindgen_malloc: (a: number, b: number) => number;
export const __wbindgen_realloc: (a: number, b: number, c: number, d: number) => number;
diff --git a/src/builder.rs b/src/builder.rs
index 439cf06..758d211 100644
--- a/src/builder.rs
+++ b/src/builder.rs
@@ -104,22 +104,46 @@ impl TryFrom for ColBERT {
let device = builder.device.unwrap_or(Device::Cpu);
let local_path = PathBuf::from(&builder.repo_id);
+
+ // Auto-detect all *_Dense directories (1_Dense, 2_Dense, 3_Dense, etc.)
let (
tokenizer_path,
weights_path,
config_path,
st_config_path,
- dense_config_path,
- dense_weights_path,
+ projection_layer_paths,
special_tokens_map_path,
) = if local_path.is_dir() {
+ let mut layer_paths = Vec::new();
+
+ // Scan for numbered Dense layers (1_Dense, 2_Dense, ...) until we don't find one
+ let mut i = 1;
+ loop {
+ let layer_dir = format!("{}_Dense", i);
+ let config_path = local_path.join(&layer_dir).join("config.json");
+ let weights_path = local_path.join(&layer_dir).join("model.safetensors");
+
+ if config_path.exists() && weights_path.exists() {
+ layer_paths.push((config_path, weights_path, layer_dir));
+ i += 1;
+ } else if i == 1 {
+ // 1_Dense is required
+ return Err(ColbertError::Io(std::io::Error::new(
+ std::io::ErrorKind::NotFound,
+ "1_Dense layer is required but not found",
+ )));
+ } else {
+ // Stop scanning after first missing layer
+ break;
+ }
+ }
+
(
local_path.join("tokenizer.json"),
local_path.join("model.safetensors"),
local_path.join("config.json"),
local_path.join("config_sentence_transformers.json"),
- local_path.join("1_Dense/config.json"),
- local_path.join("1_Dense/model.safetensors"),
+ layer_paths,
local_path.join("special_tokens_map.json"),
)
} else {
@@ -129,13 +153,39 @@ impl TryFrom for ColBERT {
RepoType::Model,
"main".to_string(),
));
+
+ let mut layer_paths = Vec::new();
+
+ // Scan for numbered Dense layers from HuggingFace Hub until we don't find one
+ let mut i = 1;
+ loop {
+ let layer_dir = format!("{}_Dense", i);
+ let config_file = format!("{}/config.json", layer_dir);
+ let weights_file = format!("{}/model.safetensors", layer_dir);
+
+ match (repo.get(&config_file), repo.get(&weights_file)) {
+ (Ok(config_path), Ok(weights_path)) => {
+ layer_paths.push((config_path, weights_path, layer_dir));
+ i += 1;
+ },
+ _ => {
+ if i == 1 {
+ return Err(ColbertError::Operation(
+ "1_Dense layer is required but not found".into(),
+ ));
+ } else {
+ break;
+ }
+ },
+ }
+ }
+
(
repo.get("tokenizer.json")?,
repo.get("model.safetensors")?,
repo.get("config.json")?,
repo.get("config_sentence_transformers.json")?,
- repo.get("1_Dense/config.json")?,
- repo.get("1_Dense/model.safetensors")?,
+ layer_paths,
repo.get("special_tokens_map.json")?,
)
};
@@ -146,8 +196,6 @@ impl TryFrom for ColBERT {
&weights_path,
&config_path,
&st_config_path,
- &dense_config_path,
- &dense_weights_path,
&special_tokens_map_path,
] {
if !path.exists() {
@@ -163,8 +211,15 @@ impl TryFrom for ColBERT {
let weights_bytes = fs::read(weights_path)?;
let config_bytes = fs::read(config_path)?;
let st_config_bytes = fs::read(st_config_path)?;
- let dense_config_bytes = fs::read(dense_config_path)?;
- let dense_weights_bytes = fs::read(dense_weights_path)?;
+
+ // Read all projection layer files
+ let mut projection_layers_data = Vec::new();
+ for (config_path, weights_path, layer_name) in projection_layer_paths {
+ let config_bytes = fs::read(config_path)?;
+ let weights_bytes = fs::read(weights_path)?;
+ projection_layers_data.push((config_bytes, weights_bytes, layer_name));
+ }
+
let special_tokens_map_bytes = fs::read(special_tokens_map_path)?;
let st_config: serde_json::Value = serde_json::from_slice(&st_config_bytes)?;
@@ -191,11 +246,9 @@ impl TryFrom for ColBERT {
.to_string()
});
- let final_do_query_expansion = builder.do_query_expansion.unwrap_or_else(|| {
- st_config["do_query_expansion"]
- .as_bool()
- .unwrap_or(true)
- });
+ let final_do_query_expansion = builder
+ .do_query_expansion
+ .unwrap_or_else(|| st_config["do_query_expansion"].as_bool().unwrap_or(true));
let final_attend_to_expansion_tokens =
builder.attend_to_expansion_tokens.unwrap_or_else(|| {
@@ -212,10 +265,9 @@ impl TryFrom for ColBERT {
ColBERT::new(
weights_bytes,
- dense_weights_bytes,
+ projection_layers_data,
tokenizer_bytes,
config_bytes,
- dense_config_bytes,
final_query_prefix,
final_document_prefix,
mask_token,
diff --git a/src/model.rs b/src/model.rs
index 8547943..58bc038 100644
--- a/src/model.rs
+++ b/src/model.rs
@@ -9,6 +9,107 @@ use candle_nn::{Linear, Module, VarBuilder};
use candle_transformers::models::bert::{BertModel, Config as BertConfig};
use tokenizers::Tokenizer;
+/// Activation functions that can be applied after linear layers.
+#[derive(Debug, Clone, Copy)]
+pub(crate) enum Activation {
+ Identity,
+ Relu,
+ Gelu,
+ GeluErf,
+ Tanh,
+ Silu,
+}
+
+impl Activation {
+ /// Parse activation function from PyTorch class name
+ fn from_pytorch_name(name: &str) -> Result {
+ match name {
+ "torch.nn.modules.linear.Identity" => Ok(Activation::Identity),
+ "torch.nn.modules.activation.ReLU" => Ok(Activation::Relu),
+ "torch.nn.modules.activation.GELU" => Ok(Activation::GeluErf),
+ "torch.nn.modules.activation.Tanh" => Ok(Activation::Tanh),
+ "torch.nn.modules.activation.SiLU" => Ok(Activation::Silu),
+ _ => Err(ColbertError::Operation(format!(
+ "Unsupported activation function: {}",
+ name
+ ))),
+ }
+ }
+
+ /// Apply the activation function to a tensor
+ fn apply(&self, tensor: &Tensor) -> Result {
+ match self {
+ Activation::Identity => Ok(tensor.clone()),
+ Activation::Relu => tensor.relu(),
+ Activation::Gelu => tensor.gelu(),
+ Activation::GeluErf => tensor.gelu_erf(),
+ Activation::Tanh => tensor.tanh(),
+ Activation::Silu => tensor.silu(),
+ }
+ }
+}
+
+/// A single projection layer with its associated activation function.
+#[derive(Clone)]
+pub(crate) struct ProjectionLayer {
+ linear: Linear,
+ activation: Activation,
+}
+
+impl ProjectionLayer {
+ /// Creates a new projection layer from config and weights
+ fn new(
+ config_bytes: &[u8],
+ weights_bytes: Vec,
+ layer_name: &str,
+ device: &Device,
+ ) -> Result {
+ let config: serde_json::Value = serde_json::from_slice(config_bytes)?;
+ let vb = VarBuilder::from_buffered_safetensors(weights_bytes, DType::F32, device)?;
+
+ let in_features = config["in_features"]
+ .as_u64()
+ .map(|v| v as usize)
+ .ok_or_else(|| {
+ ColbertError::Operation(format!("Missing 'in_features' in {} config", layer_name))
+ })?;
+
+ let out_features = config["out_features"]
+ .as_u64()
+ .map(|v| v as usize)
+ .ok_or_else(|| {
+ ColbertError::Operation(format!("Missing 'out_features' in {} config", layer_name))
+ })?;
+
+ // Parse activation function from config, default to Identity if not specified
+ let activation = config["activation_function"]
+ .as_str()
+ .map(Activation::from_pytorch_name)
+ .transpose()?
+ .unwrap_or(Activation::Identity);
+
+ let linear = candle_nn::linear_no_bias(in_features, out_features, vb.pp("linear"))?;
+
+ Ok(Self { linear, activation })
+ }
+
+ /// Apply the projection layer (linear + activation) to a tensor
+ fn forward(&self, tensor: &Tensor) -> Result {
+ let output = self.linear.forward(tensor)?;
+ self.activation.apply(&output)
+ }
+
+ /// Get output dimension of this layer
+ fn out_features(&self) -> usize {
+ self.linear.weight().dims()[0]
+ }
+
+ /// Get input dimension of this layer
+ fn in_features(&self) -> usize {
+ self.linear.weight().dims()[1]
+ }
+}
+
#[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))]
use rayon::prelude::*;
@@ -53,7 +154,7 @@ impl BaseModel {
#[cfg_attr(feature = "wasm", wasm_bindgen)]
pub struct ColBERT {
pub(crate) model: BaseModel,
- pub(crate) linear: Linear,
+ pub(crate) projection_layers: Vec,
pub(crate) tokenizer: Tokenizer,
pub(crate) mask_token_id: u32,
pub(crate) mask_token: String,
@@ -71,12 +172,14 @@ pub struct ColBERT {
impl ColBERT {
/// Creates a new instance of the `ColBERT` model from byte buffers.
+ ///
+ /// # Arguments
+ /// * `projection_layers_data` - Vec of (config_bytes, weights_bytes, layer_name) tuples for each Dense layer
pub fn new(
weights: Vec,
- dense_weights: Vec,
+ projection_layers_data: Vec<(Vec, Vec, String)>,
tokenizer_bytes: Vec,
config_bytes: Vec,
- dense_config_bytes: Vec,
query_prefix: String,
document_prefix: String,
mask_token: String,
@@ -92,7 +195,7 @@ impl ColBERT {
let config_value: serde_json::Value = serde_json::from_slice(&config_bytes)?;
let architectures = config_value["architectures"]
.as_array()
- .and_then(|arr| arr.get(0))
+ .and_then(|arr| arr.first())
.and_then(|v| v.as_str())
.ok_or_else(|| {
ColbertError::Operation("Missing or invalid 'architectures' in config.json".into())
@@ -117,7 +220,6 @@ impl ColBERT {
},
};
- let dense_config: serde_json::Value = serde_json::from_slice(&dense_config_bytes)?;
let tokenizer = Tokenizer::from_bytes(&tokenizer_bytes)?;
let mask_token_id = tokenizer.token_to_id(mask_token.as_str()).ok_or_else(|| {
@@ -127,21 +229,38 @@ impl ColBERT {
))
})?;
- let dense_vb = VarBuilder::from_buffered_safetensors(dense_weights, DType::F32, device)?;
- let in_features = dense_config["in_features"]
- .as_u64()
- .map(|v| v as usize)
- .ok_or_else(|| {
- ColbertError::Operation("Missing 'in_features' in dense config".into())
- })?;
- let out_features = dense_config["out_features"]
- .as_u64()
- .map(|v| v as usize)
- .ok_or_else(|| {
- ColbertError::Operation("Missing 'out_features' in dense config".into())
- })?;
+ // Validate that we have at least one projection layer
+ if projection_layers_data.is_empty() {
+ return Err(ColbertError::Operation(
+ "At least one projection layer is required".into(),
+ ));
+ }
- let linear = candle_nn::linear_no_bias(in_features, out_features, dense_vb.pp("linear"))?;
+ // Build projection layers and validate dimensions
+ let mut projection_layers = Vec::with_capacity(projection_layers_data.len());
+ let mut prev_out_features: Option = None;
+
+ for (i, (config_bytes, weights_bytes, layer_name)) in
+ projection_layers_data.into_iter().enumerate()
+ {
+ let layer = ProjectionLayer::new(&config_bytes, weights_bytes, &layer_name, device)?;
+
+ // Validate dimension compatibility with previous layer
+ if let Some(prev_out) = prev_out_features {
+ if prev_out != layer.in_features() {
+ return Err(ColbertError::Operation(format!(
+ "Dimension mismatch between layer {} and {}: output {} != input {}",
+ i - 1,
+ layer_name,
+ prev_out,
+ layer.in_features()
+ )));
+ }
+ }
+
+ prev_out_features = Some(layer.out_features());
+ projection_layers.push(layer);
+ }
// If do_query_expansion is false, attend_to_expansion_tokens should also be false
let final_attend_to_expansion_tokens = if !do_query_expansion {
@@ -152,10 +271,10 @@ impl ColBERT {
Ok(Self {
model,
- linear,
+ projection_layers,
tokenizer,
mask_token_id,
- mask_token: mask_token,
+ mask_token,
query_prefix,
document_prefix,
do_query_expansion,
@@ -268,7 +387,12 @@ impl ColBERT {
let token_embeddings =
self.model
.forward(&token_ids, &attention_mask, &token_type_ids)?;
- let projected_embeddings = self.linear.forward(&token_embeddings)?;
+
+ // Apply all projection layers sequentially
+ let mut projected_embeddings = token_embeddings;
+ for layer in &self.projection_layers {
+ projected_embeddings = layer.forward(&projected_embeddings)?;
+ }
if !self.do_query_expansion || !is_query {
// Apply filtering, normalization, and padding.
@@ -298,7 +422,11 @@ impl ColBERT {
self.model
.forward(&token_ids, &attention_mask, &token_type_ids)?;
- let projected_embeddings = self.linear.forward(&token_embeddings)?;
+ // Apply all projection layers sequentially
+ let mut projected_embeddings = token_embeddings;
+ for layer in &self.projection_layers {
+ projected_embeddings = layer.forward(&projected_embeddings)?;
+ }
let final_embeddings = if !self.do_query_expansion || !is_query {
// Apply filtering, normalization, and padding.
diff --git a/src/wasm.rs b/src/wasm.rs
index ede1754..1662dcc 100644
--- a/src/wasm.rs
+++ b/src/wasm.rs
@@ -2,7 +2,7 @@ use crate::{
error::ColbertError,
model::ColBERT,
pooling::hierarchical_pooling,
- types::{EncodeInput, EncodeOutput, RawSimilarityOutput, SimilarityInput},
+ types::{EncodeInput, EncodeOutput, RawSimilarityOutput, Similarities, SimilarityInput},
};
use candle_core::{Device, IndexOp, Tensor};
use wasm_bindgen::prelude::*;
@@ -10,19 +10,54 @@ use wasm_bindgen::prelude::*;
#[wasm_bindgen]
impl ColBERT {
/// WASM-compatible constructor.
+ ///
+ /// # Arguments
+ /// * `projection_layers` - JsValue containing array of {config: Uint8Array, weights: Uint8Array, name: string}
#[wasm_bindgen(constructor)]
pub fn from_bytes(
weights: Vec,
- dense_weights: Vec,
+ projection_layers: JsValue,
tokenizer: Vec,
config: Vec,
sentence_transformers_config: Vec,
- dense_config: Vec,
special_tokens_map: Vec,
batch_size: Option,
) -> Result {
console_error_panic_hook::set_once();
+ // Parse projection layers from JsValue
+ // Expected format: [{config: Uint8Array, weights: Uint8Array, name: string}, ...]
+ use js_sys::{Array, Object, Reflect, Uint8Array};
+
+ let layers_array = Array::from(&projection_layers);
+ let mut projection_layers_data = Vec::new();
+
+ for i in 0..layers_array.length() {
+ let layer = layers_array.get(i);
+ let layer_obj = Object::from(layer);
+
+ // Extract config as Uint8Array
+ let config_js = Reflect::get(&layer_obj, &JsValue::from_str("config"))
+ .map_err(|e| JsValue::from_str(&format!("Failed to get config: {:?}", e)))?;
+ let config_array = Uint8Array::from(config_js);
+ let config_bytes = config_array.to_vec();
+
+ // Extract weights as Uint8Array
+ let weights_js = Reflect::get(&layer_obj, &JsValue::from_str("weights"))
+ .map_err(|e| JsValue::from_str(&format!("Failed to get weights: {:?}", e)))?;
+ let weights_array = Uint8Array::from(weights_js);
+ let weights_bytes = weights_array.to_vec();
+
+ // Extract name as string
+ let name_js = Reflect::get(&layer_obj, &JsValue::from_str("name"))
+ .map_err(|e| JsValue::from_str(&format!("Failed to get name: {:?}", e)))?;
+ let layer_name = name_js
+ .as_string()
+ .ok_or_else(|| JsValue::from_str("Layer name must be a string"))?;
+
+ projection_layers_data.push((config_bytes, weights_bytes, layer_name));
+ }
+
let st_config: serde_json::Value =
serde_json::from_slice(&sentence_transformers_config).map_err(ColbertError::from)?;
@@ -37,9 +72,7 @@ impl ColBERT {
.as_str()
.unwrap_or("[D]")
.to_string();
- let do_query_expansion = st_config["do_query_expansion"]
- .as_bool()
- .unwrap_or(true);
+ let do_query_expansion = st_config["do_query_expansion"].as_bool().unwrap_or(true);
let attend_to_expansion_tokens = st_config["attend_to_expansion_tokens"]
.as_bool()
.unwrap_or(false);
@@ -55,10 +88,9 @@ impl ColBERT {
Self::new(
weights,
- dense_weights,
+ projection_layers_data,
tokenizer,
config,
- dense_config,
query_prefix,
document_prefix,
mask_token,
@@ -74,7 +106,7 @@ impl ColBERT {
/// WASM-compatible version of the `encode` method.
#[wasm_bindgen(js_name = "encode")]
- pub fn encode_wasm(&mut self, input: JsValue, is_query: bool) -> Result {
+ pub fn encode_wasm(&mut self, input: JsValue, is_query: bool) -> Result {
let params: EncodeInput = serde_wasm_bindgen::from_value(input)?;
// Override model's batch_size if provided in the input
if let Some(batch_size) = params.batch_size {
@@ -87,22 +119,27 @@ impl ColBERT {
let result = EncodeOutput {
embeddings: embeddings_data,
};
- serde_wasm_bindgen::to_value(&result).map_err(Into::into)
+ // Return as JSON string to avoid serde-wasm-bindgen issues
+ serde_json::to_string(&result)
+ .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
}
/// WASM-compatible version of the `similarity` method.
#[wasm_bindgen(js_name = "similarity")]
- pub fn similarity_wasm(&mut self, input: JsValue) -> Result {
+ pub fn similarity_wasm(&mut self, input: JsValue) -> Result {
let params: SimilarityInput = serde_wasm_bindgen::from_value(input)?;
let queries_embeddings = self.encode(¶ms.queries, true)?;
let documents_embeddings = self.encode(¶ms.documents, false)?;
let result = self.similarity(&queries_embeddings, &documents_embeddings)?;
- serde_wasm_bindgen::to_value(&result).map_err(Into::into)
+
+ // Return as JSON string to avoid serde-wasm-bindgen issues
+ serde_json::to_string(&result)
+ .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
}
/// WASM-compatible method to get the raw similarity matrix and tokens.
#[wasm_bindgen(js_name = "raw_similarity_matrix")]
- pub fn raw_similarity_matrix_wasm(&mut self, input: JsValue) -> Result {
+ pub fn raw_similarity_matrix_wasm(&mut self, input: JsValue) -> Result {
let params: SimilarityInput = serde_wasm_bindgen::from_value(input)?;
let (query_ids_tensor, _, _) = self.tokenize(¶ms.queries, true)?;
@@ -151,7 +188,9 @@ impl ColBERT {
document_tokens,
};
- serde_wasm_bindgen::to_value(&result).map_err(Into::into)
+ // Return as JSON string to avoid serde-wasm-bindgen issues
+ serde_json::to_string(&result)
+ .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
}
}
@@ -165,13 +204,14 @@ struct PoolingInput {
/// WASM-compatible version of the `hierarchical_pooling` function.
#[cfg(feature = "wasm")]
#[wasm_bindgen(js_name = hierarchical_pooling)]
-pub fn hierarchical_pooling_wasm(input: JsValue) -> Result {
+pub fn hierarchical_pooling_wasm(input: JsValue) -> Result {
console_error_panic_hook::set_once();
let params: PoolingInput = serde_wasm_bindgen::from_value(input)?;
if params.embeddings.is_empty() {
let result = EncodeOutput { embeddings: vec![] };
- return serde_wasm_bindgen::to_value(&result).map_err(Into::into);
+ return serde_json::to_string(&result)
+ .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)));
}
let batch_size = params.embeddings.len();
@@ -181,7 +221,8 @@ pub fn hierarchical_pooling_wasm(input: JsValue) -> Result {
let result = EncodeOutput {
embeddings: params.embeddings,
};
- return serde_wasm_bindgen::to_value(&result).map_err(Into::into);
+ return serde_json::to_string(&result)
+ .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)));
}
let embedding_dim = params.embeddings[0][0].len();
@@ -202,5 +243,7 @@ pub fn hierarchical_pooling_wasm(input: JsValue) -> Result {
embeddings: embeddings_data,
};
- serde_wasm_bindgen::to_value(&result).map_err(Into::into)
+ // Return as JSON string to avoid serde-wasm-bindgen issues
+ serde_json::to_string(&result)
+ .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e)))
}