diff --git a/Cargo.toml b/Cargo.toml index 768cd3e..df72e78 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -24,7 +24,7 @@ anyhow = "1.0.98" wasm-bindgen = { version = "0.2.100", optional = true } serde-wasm-bindgen = { version = "0.6.5", optional = true } console_error_panic_hook = { version = "0.1.7", optional = true } -getrandom = { version = "0.3", features = ["wasm_js"], optional = true } +js-sys = { version = "0.3", optional = true } hf-hub = { version = "0.4.3", optional = true, default-features = false, features = [ "ureq", "rustls-tls", @@ -35,6 +35,9 @@ pyo3 = { version = "0.25.1", optional = true, features = ["extension-module"] } ndarray = { version = "0.16.1", optional = true } numpy = { version = "0.25.0", optional = true } +[target.'cfg(target_family = "wasm")'.dependencies] +getrandom = { version = "0.2", features = ["js"] } + [features] default = ["tokenizers/onig", "hf-hub"] @@ -42,7 +45,7 @@ wasm = [ "dep:wasm-bindgen", "dep:serde-wasm-bindgen", "dep:console_error_panic_hook", - "dep:getrandom", + "dep:js-sys", "tokenizers/unstable_wasm", ] diff --git a/docs/index.html b/docs/index.html index d3599e8..b6c271a 100644 --- a/docs/index.html +++ b/docs/index.html @@ -892,6 +892,7 @@

Pooled Embeddings

"lightonai/GTE-ModernColBERT-v1", "lightonai/colbertv2.0", "lightonai/Reason-ModernColBERT", + "mixedbread-ai/mxbai-edge-colbert-v0-17m", ], REQUIRED_FILES: [ 'tokenizer.json', 'model.safetensors', 'config.json', @@ -1006,6 +1007,28 @@

Pooled Embeddings

pooledScore: document.getElementById('pooled-score'), }, + /** + * Retry helper with exponential backoff + */ + async retryWithBackoff(fn, maxRetries = 3, baseDelay = 100) { + for (let i = 0; i < maxRetries; i++) { + try { + return await fn(); + } catch (error) { + const isLastAttempt = i === maxRetries - 1; + const isTableGrowError = error.message && error.message.includes('Table.grow'); + + if (isLastAttempt || !isTableGrowError) { + throw error; + } + + const delay = baseDelay * Math.pow(2, i); + console.warn(`WASM init failed (attempt ${i + 1}/${maxRetries}), retrying in ${delay}ms...`, error.message); + await new Promise(resolve => setTimeout(resolve, delay)); + } + } + }, + /** * Initializes the application, sets up event listeners, and loads the initial model. */ @@ -1016,11 +1039,14 @@

Pooled Embeddings

hljs.highlightAll(); try { - await init(); - this.loadModel(this.state.currentModelName); + await this.retryWithBackoff(() => init(), 5, 200); + console.log("✅ WASM module initialized successfully"); + // Add a small delay to ensure WASM is fully ready + await new Promise(resolve => setTimeout(resolve, 100)); + await this.loadModel(this.state.currentModelName); } catch (e) { - console.error("Fatal: Failed to initialize WASM module.", e); - this.updateAllStatuses(`Error: Could not initialize the application.`, 'error'); + console.error("Fatal: Failed to initialize WASM module after retries.", e); + this.updateAllStatuses(`Error: Could not initialize the application. Please refresh the page.`, 'error'); } }, @@ -1149,9 +1175,64 @@

Pooled Embeddings

console.warn(`Unable to download the model.`, e); } this.updateAllStatuses(`Initializing ${modelRepo}...`, 'loading'); + + // Try to fetch all Dense layers (1_Dense, 2_Dense, 3_Dense, ...) until we don't find one + const projectionLayers = []; + let i = 1; + while (true) { + try { + const layerName = `${i}_Dense`; + const weightsResponse = await fetch(`https://huggingface.co/${modelRepo}/resolve/main/${layerName}/model.safetensors`); + const configResponse = await fetch(`https://huggingface.co/${modelRepo}/resolve/main/${layerName}/config.json`); + + if (weightsResponse.ok && configResponse.ok) { + const weights = new Uint8Array(await weightsResponse.arrayBuffer()); + const config = new Uint8Array(await configResponse.arrayBuffer()); + projectionLayers.push({ + config: config, + weights: weights, + name: layerName + }); + console.log(`✅ Found ${layerName} layer`); + i++; + } else if (i === 1) { + throw new Error("1_Dense layer is required but not found"); + } else { + // Stop at first missing layer + break; + } + } catch (e) { + if (i === 1) { + throw new Error("1_Dense layer is required: " + e.message); + } + break; + } + } + const [tokenizer, model, config, stConfig, dense, denseConfig, tokensConfig] = modelFiles; - this.state.colbertModel = new ColBERT(model, dense, tokenizer, config, stConfig, denseConfig, tokensConfig, 32); + console.log("Creating ColBERT instance with:", { + hasModel: !!model, + projectionLayerCount: projectionLayers.length, + hasTokenizer: !!tokenizer, + hasConfig: !!config, + hasStConfig: !!stConfig, + hasTokensConfig: !!tokensConfig + }); + + // New constructor signature: (weights, projection_layers, tokenizer, config, + // sentence_transformers_config, special_tokens_map, batch_size) + this.state.colbertModel = new ColBERT( + model, + projectionLayers, // Array of {config, weights, name} objects + tokenizer, + config, + stConfig, + tokensConfig, + 32 + ); + + console.log("✅ ColBERT instance created successfully"); 
this.updateAllStatuses(`✅ ${modelRepo} loaded successfully.`, 'success'); this.runAllDemos(); } catch (error) { @@ -1167,15 +1248,34 @@

Pooled Embeddings

* Runs all interactive demos after a model is loaded. */ runAllDemos() { - this.handleSimilarityDemo(); - this.handleMatrixDemo(); - this.handlePoolingDemo(); + if (!this.state.colbertModel) { + console.warn("Model not loaded yet, skipping demos"); + return; + } + + try { + this.handleSimilarityDemo(); + } catch (e) { + console.error("Similarity demo failed:", e); + } + + try { + this.handleMatrixDemo(); + } catch (e) { + console.error("Matrix demo failed:", e); + } + + try { + this.handlePoolingDemo(); + } catch (e) { + console.error("Pooling demo failed:", e); + } }, /** * SIMILARITY DEMO: Calculates and displays similarity scores. */ - handleSimilarityDemo() { + async handleSimilarityDemo() { if (!this.state.colbertModel) return; const query = this.ui.queryInput.value; const documents = this.ui.documentInput.value.split('\n').filter(doc => doc.trim()); @@ -1187,13 +1287,34 @@

Pooled Embeddings

} try { - const { - data - } = this.state.colbertModel.similarity({ + console.log("Running similarity with:", { query, documentsCount: documents.length }); + + // Call the similarity method - now returns JSON string + let resultString = this.state.colbertModel.similarity({ queries: [query], documents }); - const scores = data[0]; + + console.log("Similarity result type:", typeof resultString); + console.log("Similarity result string:", resultString); + + // Parse JSON string to object + let result; + if (typeof resultString === 'string') { + result = JSON.parse(resultString); + } else { + result = resultString; // Fallback if it's already an object + } + + console.log("Parsed result:", result); + + if (!result || !result.data || !result.data[0]) { + console.error("Invalid similarity result:", result); + this.ui.resultsList.innerHTML = `

Error: Model returned invalid results. Check console for details.

`; + return; + } + + const scores = result.data[0]; const results = documents.map((doc, i) => ({ text: doc, score: scores[i] @@ -1238,10 +1359,11 @@

Pooled Embeddings

} try { - const rawResult = this.state.colbertModel.raw_similarity_matrix({ + const rawResultString = this.state.colbertModel.raw_similarity_matrix({ queries: [query], documents: [doc] }); + const rawResult = JSON.parse(rawResultString); const { matrix, queryTokens, @@ -1426,23 +1548,26 @@

Pooled Embeddings

} try { - const originalDocResult = this.state.colbertModel.encode({ + const originalDocResultString = this.state.colbertModel.encode({ sentences: [docText] }, false); + const originalDocResult = JSON.parse(originalDocResultString); const originalDocEmbeddings = originalDocResult.embeddings[0] || []; this.renderEmbeddingBlocks(this.ui.originalEmbeddingsVis, originalDocEmbeddings.length, this.ui.originalTokenCount); - const pooledResult = hierarchical_pooling({ + const pooledResultString = hierarchical_pooling({ embeddings: [originalDocEmbeddings], pool_factor: poolFactor }); + const pooledResult = JSON.parse(pooledResultString); const pooledDocEmbeddings = pooledResult.embeddings[0] || []; this.renderEmbeddingBlocks(this.ui.pooledEmbeddingsVis, pooledDocEmbeddings.length, this.ui.pooledTokenCount); if (queryText.trim()) { - const queryResult = this.state.colbertModel.encode({ + const queryResultString = this.state.colbertModel.encode({ sentences: [queryText] }, false); + const queryResult = JSON.parse(queryResultString); const queryEmbeddings = queryResult.embeddings[0] || []; const calculateScore = (qVecs, dVecs) => { diff --git a/docs/pkg/pylate_rs.d.ts b/docs/pkg/pylate_rs.d.ts index 23265cd..3c4c6de 100644 --- a/docs/pkg/pylate_rs.d.ts +++ b/docs/pkg/pylate_rs.d.ts @@ -3,7 +3,7 @@ /** * WASM-compatible version of the `hierarchical_pooling` function. */ -export function hierarchical_pooling(input: any): any; +export function hierarchical_pooling(input: any): string; /** * The main ColBERT model structure. * @@ -13,33 +13,37 @@ export function hierarchical_pooling(input: any): any; */ export class ColBERT { free(): void; + [Symbol.dispose](): void; /** * WASM-compatible constructor. 
+ * + * # Arguments + * * `projection_layers` - JsValue containing array of {config: Uint8Array, weights: Uint8Array, name: string} */ - constructor(weights: Uint8Array, dense_weights: Uint8Array, tokenizer: Uint8Array, config: Uint8Array, sentence_transformers_config: Uint8Array, dense_config: Uint8Array, special_tokens_map: Uint8Array, batch_size?: number | null); + constructor(weights: Uint8Array, projection_layers: any, tokenizer: Uint8Array, config: Uint8Array, sentence_transformers_config: Uint8Array, special_tokens_map: Uint8Array, batch_size?: number | null); /** * WASM-compatible version of the `encode` method. */ - encode(input: any, is_query: boolean): any; + encode(input: any, is_query: boolean): string; /** * WASM-compatible version of the `similarity` method. */ - similarity(input: any): any; + similarity(input: any): string; /** * WASM-compatible method to get the raw similarity matrix and tokens. */ - raw_similarity_matrix(input: any): any; + raw_similarity_matrix(input: any): string; } export type InitInput = RequestInfo | URL | Response | BufferSource | WebAssembly.Module; export interface InitOutput { readonly memory: WebAssembly.Memory; - readonly colbert_from_bytes: (a: number, b: number, c: number, d: number, e: number, f: number, g: number, h: number, i: number, j: number, k: number, l: number, m: number, n: number, o: number) => [number, number, number]; - readonly colbert_encode: (a: number, b: any, c: number) => [number, number, number]; - readonly colbert_similarity: (a: number, b: any) => [number, number, number]; - readonly colbert_raw_similarity_matrix: (a: number, b: any) => [number, number, number]; - readonly hierarchical_pooling: (a: any) => [number, number, number]; + readonly colbert_from_bytes: (a: number, b: number, c: any, d: number, e: number, f: number, g: number, h: number, i: number, j: number, k: number, l: number) => [number, number, number]; + readonly colbert_encode: (a: number, b: any, c: number) => [number, number, 
number, number]; + readonly colbert_similarity: (a: number, b: any) => [number, number, number, number]; + readonly colbert_raw_similarity_matrix: (a: number, b: any) => [number, number, number, number]; + readonly hierarchical_pooling: (a: any) => [number, number, number, number]; readonly __wbg_colbert_free: (a: number, b: number) => void; readonly __wbindgen_malloc: (a: number, b: number) => number; readonly __wbindgen_realloc: (a: number, b: number, c: number, d: number) => number; diff --git a/docs/pkg/pylate_rs.js b/docs/pkg/pylate_rs.js index ca4b9cf..33776c8 100644 --- a/docs/pkg/pylate_rs.js +++ b/docs/pkg/pylate_rs.js @@ -9,16 +9,16 @@ function getUint8ArrayMemory0() { return cachedUint8ArrayMemory0; } -let cachedTextDecoder = (typeof TextDecoder !== 'undefined' ? new TextDecoder('utf-8', { ignoreBOM: true, fatal: true }) : { decode: () => { throw Error('TextDecoder not available') } } ); +let cachedTextDecoder = new TextDecoder('utf-8', { ignoreBOM: true, fatal: true }); -if (typeof TextDecoder !== 'undefined') { cachedTextDecoder.decode(); }; +cachedTextDecoder.decode(); const MAX_SAFARI_DECODE_BYTES = 2146435072; let numBytesDecoded = 0; function decodeText(ptr, len) { numBytesDecoded += len; if (numBytesDecoded >= MAX_SAFARI_DECODE_BYTES) { - cachedTextDecoder = (typeof TextDecoder !== 'undefined' ? new TextDecoder('utf-8', { ignoreBOM: true, fatal: true }) : { decode: () => { throw Error('TextDecoder not available') } } ); + cachedTextDecoder = new TextDecoder('utf-8', { ignoreBOM: true, fatal: true }); cachedTextDecoder.decode(); numBytesDecoded = len; } @@ -32,20 +32,18 @@ function getStringFromWasm0(ptr, len) { let WASM_VECTOR_LEN = 0; -const cachedTextEncoder = (typeof TextEncoder !== 'undefined' ? new TextEncoder('utf-8') : { encode: () => { throw Error('TextEncoder not available') } } ); +const cachedTextEncoder = new TextEncoder(); -const encodeString = (typeof cachedTextEncoder.encodeInto === 'function' - ? 
function (arg, view) { - return cachedTextEncoder.encodeInto(arg, view); +if (!('encodeInto' in cachedTextEncoder)) { + cachedTextEncoder.encodeInto = function (arg, view) { + const buf = cachedTextEncoder.encode(arg); + view.set(buf); + return { + read: arg.length, + written: buf.length + }; + } } - : function (arg, view) { - const buf = cachedTextEncoder.encode(arg); - view.set(buf); - return { - read: arg.length, - written: buf.length - }; -}); function passStringToWasm0(arg, malloc, realloc) { @@ -76,7 +74,7 @@ function passStringToWasm0(arg, malloc, realloc) { } ptr = realloc(ptr, len, len = offset + arg.length * 3, 1) >>> 0; const view = getUint8ArrayMemory0().subarray(ptr + offset, ptr + len); - const ret = encodeString(arg, view); + const ret = cachedTextEncoder.encodeInto(arg, view); offset += ret.written; ptr = realloc(ptr, len, offset, 1) >>> 0; @@ -199,14 +197,25 @@ function takeFromExternrefTable0(idx) { /** * WASM-compatible version of the `hierarchical_pooling` function. * @param {any} input - * @returns {any} + * @returns {string} */ export function hierarchical_pooling(input) { - const ret = wasm.hierarchical_pooling(input); - if (ret[2]) { - throw takeFromExternrefTable0(ret[1]); + let deferred2_0; + let deferred2_1; + try { + const ret = wasm.hierarchical_pooling(input); + var ptr1 = ret[0]; + var len1 = ret[1]; + if (ret[3]) { + ptr1 = 0; len1 = 0; + throw takeFromExternrefTable0(ret[2]); + } + deferred2_0 = ptr1; + deferred2_1 = len1; + return getStringFromWasm0(ptr1, len1); + } finally { + wasm.__wbindgen_free(deferred2_0, deferred2_1, 1); } - return takeFromExternrefTable0(ret[0]); } const ColBERTFinalization = (typeof FinalizationRegistry === 'undefined') @@ -234,31 +243,29 @@ export class ColBERT { } /** * WASM-compatible constructor. 
+ * + * # Arguments + * * `projection_layers` - JsValue containing array of {config: Uint8Array, weights: Uint8Array, name: string} * @param {Uint8Array} weights - * @param {Uint8Array} dense_weights + * @param {any} projection_layers * @param {Uint8Array} tokenizer * @param {Uint8Array} config * @param {Uint8Array} sentence_transformers_config - * @param {Uint8Array} dense_config * @param {Uint8Array} special_tokens_map * @param {number | null} [batch_size] */ - constructor(weights, dense_weights, tokenizer, config, sentence_transformers_config, dense_config, special_tokens_map, batch_size) { + constructor(weights, projection_layers, tokenizer, config, sentence_transformers_config, special_tokens_map, batch_size) { const ptr0 = passArray8ToWasm0(weights, wasm.__wbindgen_malloc); const len0 = WASM_VECTOR_LEN; - const ptr1 = passArray8ToWasm0(dense_weights, wasm.__wbindgen_malloc); + const ptr1 = passArray8ToWasm0(tokenizer, wasm.__wbindgen_malloc); const len1 = WASM_VECTOR_LEN; - const ptr2 = passArray8ToWasm0(tokenizer, wasm.__wbindgen_malloc); + const ptr2 = passArray8ToWasm0(config, wasm.__wbindgen_malloc); const len2 = WASM_VECTOR_LEN; - const ptr3 = passArray8ToWasm0(config, wasm.__wbindgen_malloc); + const ptr3 = passArray8ToWasm0(sentence_transformers_config, wasm.__wbindgen_malloc); const len3 = WASM_VECTOR_LEN; - const ptr4 = passArray8ToWasm0(sentence_transformers_config, wasm.__wbindgen_malloc); + const ptr4 = passArray8ToWasm0(special_tokens_map, wasm.__wbindgen_malloc); const len4 = WASM_VECTOR_LEN; - const ptr5 = passArray8ToWasm0(dense_config, wasm.__wbindgen_malloc); - const len5 = WASM_VECTOR_LEN; - const ptr6 = passArray8ToWasm0(special_tokens_map, wasm.__wbindgen_malloc); - const len6 = WASM_VECTOR_LEN; - const ret = wasm.colbert_from_bytes(ptr0, len0, ptr1, len1, ptr2, len2, ptr3, len3, ptr4, len4, ptr5, len5, ptr6, len6, isLikeNone(batch_size) ? 
0x100000001 : (batch_size) >>> 0); + const ret = wasm.colbert_from_bytes(ptr0, len0, projection_layers, ptr1, len1, ptr2, len2, ptr3, len3, ptr4, len4, isLikeNone(batch_size) ? 0x100000001 : (batch_size) >>> 0); if (ret[2]) { throw takeFromExternrefTable0(ret[1]); } @@ -270,40 +277,74 @@ export class ColBERT { * WASM-compatible version of the `encode` method. * @param {any} input * @param {boolean} is_query - * @returns {any} + * @returns {string} */ encode(input, is_query) { - const ret = wasm.colbert_encode(this.__wbg_ptr, input, is_query); - if (ret[2]) { - throw takeFromExternrefTable0(ret[1]); + let deferred2_0; + let deferred2_1; + try { + const ret = wasm.colbert_encode(this.__wbg_ptr, input, is_query); + var ptr1 = ret[0]; + var len1 = ret[1]; + if (ret[3]) { + ptr1 = 0; len1 = 0; + throw takeFromExternrefTable0(ret[2]); + } + deferred2_0 = ptr1; + deferred2_1 = len1; + return getStringFromWasm0(ptr1, len1); + } finally { + wasm.__wbindgen_free(deferred2_0, deferred2_1, 1); } - return takeFromExternrefTable0(ret[0]); } /** * WASM-compatible version of the `similarity` method. * @param {any} input - * @returns {any} + * @returns {string} */ similarity(input) { - const ret = wasm.colbert_similarity(this.__wbg_ptr, input); - if (ret[2]) { - throw takeFromExternrefTable0(ret[1]); + let deferred2_0; + let deferred2_1; + try { + const ret = wasm.colbert_similarity(this.__wbg_ptr, input); + var ptr1 = ret[0]; + var len1 = ret[1]; + if (ret[3]) { + ptr1 = 0; len1 = 0; + throw takeFromExternrefTable0(ret[2]); + } + deferred2_0 = ptr1; + deferred2_1 = len1; + return getStringFromWasm0(ptr1, len1); + } finally { + wasm.__wbindgen_free(deferred2_0, deferred2_1, 1); } - return takeFromExternrefTable0(ret[0]); } /** * WASM-compatible method to get the raw similarity matrix and tokens. 
* @param {any} input - * @returns {any} + * @returns {string} */ raw_similarity_matrix(input) { - const ret = wasm.colbert_raw_similarity_matrix(this.__wbg_ptr, input); - if (ret[2]) { - throw takeFromExternrefTable0(ret[1]); + let deferred2_0; + let deferred2_1; + try { + const ret = wasm.colbert_raw_similarity_matrix(this.__wbg_ptr, input); + var ptr1 = ret[0]; + var len1 = ret[1]; + if (ret[3]) { + ptr1 = 0; len1 = 0; + throw takeFromExternrefTable0(ret[2]); + } + deferred2_0 = ptr1; + deferred2_1 = len1; + return getStringFromWasm0(ptr1, len1); + } finally { + wasm.__wbindgen_free(deferred2_0, deferred2_1, 1); } - return takeFromExternrefTable0(ret[0]); } } +if (Symbol.dispose) ColBERT.prototype[Symbol.dispose] = ColBERT.prototype.free; const EXPECTED_RESPONSE_TYPES = new Set(['basic', 'cors', 'default']); @@ -343,10 +384,14 @@ async function __wbg_load(module, imports) { function __wbg_get_imports() { const imports = {}; imports.wbg = {}; - imports.wbg.__wbg_Error_0497d5bdba9362e5 = function(arg0, arg1) { + imports.wbg.__wbg_Error_e17e777aac105295 = function(arg0, arg1) { const ret = Error(getStringFromWasm0(arg0, arg1)); return ret; }; + imports.wbg.__wbg_Number_998bea33bd87c3e0 = function(arg0) { + const ret = Number(arg0); + return ret; + }; imports.wbg.__wbg_String_8f0eb39a4a4c2f66 = function(arg0, arg1) { const ret = String(arg1); const ptr1 = passStringToWasm0(ret, wasm.__wbindgen_malloc, wasm.__wbindgen_realloc); @@ -354,15 +399,11 @@ function __wbg_get_imports() { getDataViewMemory0().setInt32(arg0 + 4 * 1, len1, true); getDataViewMemory0().setInt32(arg0 + 4 * 0, ptr1, true); }; - imports.wbg.__wbg_buffer_a1a27a0dfa70165d = function(arg0) { - const ret = arg0.buffer; - return ret; - }; - imports.wbg.__wbg_call_fbe8be8bf6436ce5 = function() { return handleError(function (arg0, arg1) { + imports.wbg.__wbg_call_13410aac570ffff7 = function() { return handleError(function (arg0, arg1) { const ret = arg0.call(arg1); return ret; }, arguments) }; - 
imports.wbg.__wbg_done_4d01f352bade43b7 = function(arg0) { + imports.wbg.__wbg_done_75ed0ee6dd243d9d = function(arg0) { const ret = arg0.done; return ret; }; @@ -377,22 +418,26 @@ function __wbg_get_imports() { wasm.__wbindgen_free(deferred0_0, deferred0_1, 1); } }; - imports.wbg.__wbg_getRandomValues_3c9c0d586e575a16 = function() { return handleError(function (arg0, arg1) { - globalThis.crypto.getRandomValues(getArrayU8FromWasm0(arg0, arg1)); - }, arguments) }; - imports.wbg.__wbg_get_92470be87867c2e5 = function() { return handleError(function (arg0, arg1) { - const ret = Reflect.get(arg0, arg1); + imports.wbg.__wbg_from_88bc52ce20ba6318 = function(arg0) { + const ret = Array.from(arg0); return ret; + }; + imports.wbg.__wbg_getRandomValues_1c61fac11405ffdc = function() { return handleError(function (arg0, arg1) { + globalThis.crypto.getRandomValues(getArrayU8FromWasm0(arg0, arg1)); }, arguments) }; - imports.wbg.__wbg_get_a131a44bd1eb6979 = function(arg0, arg1) { + imports.wbg.__wbg_get_0da715ceaecea5c8 = function(arg0, arg1) { const ret = arg0[arg1 >>> 0]; return ret; }; + imports.wbg.__wbg_get_458e874b43b18b25 = function() { return handleError(function (arg0, arg1) { + const ret = Reflect.get(arg0, arg1); + return ret; + }, arguments) }; imports.wbg.__wbg_getwithrefkey_1dc361bd10053bfe = function(arg0, arg1) { const ret = arg0[arg1]; return ret; }; - imports.wbg.__wbg_instanceof_ArrayBuffer_a8b6f580b363f2bc = function(arg0) { + imports.wbg.__wbg_instanceof_ArrayBuffer_67f3012529f6a2dd = function(arg0) { let result; try { result = arg0 instanceof ArrayBuffer; @@ -402,7 +447,7 @@ function __wbg_get_imports() { const ret = result; return ret; }; - imports.wbg.__wbg_instanceof_Uint8Array_ca460677bc155827 = function(arg0) { + imports.wbg.__wbg_instanceof_Uint8Array_9a8378d955933db7 = function(arg0) { let result; try { result = arg0 instanceof Uint8Array; @@ -412,58 +457,44 @@ function __wbg_get_imports() { const ret = result; return ret; }; - 
imports.wbg.__wbg_isArray_5f090bed72bd4f89 = function(arg0) { + imports.wbg.__wbg_isArray_030cce220591fb41 = function(arg0) { const ret = Array.isArray(arg0); return ret; }; - imports.wbg.__wbg_isSafeInteger_90d7c4674047d684 = function(arg0) { + imports.wbg.__wbg_isSafeInteger_1c0d1af5542e102a = function(arg0) { const ret = Number.isSafeInteger(arg0); return ret; }; - imports.wbg.__wbg_iterator_4068add5b2aef7a6 = function() { + imports.wbg.__wbg_iterator_f370b34483c71a1c = function() { const ret = Symbol.iterator; return ret; }; - imports.wbg.__wbg_length_ab6d22b5ead75c72 = function(arg0) { + imports.wbg.__wbg_length_186546c51cd61acd = function(arg0) { const ret = arg0.length; return ret; }; - imports.wbg.__wbg_length_f00ec12454a5d9fd = function(arg0) { + imports.wbg.__wbg_length_6bb7e81f9d7713e4 = function(arg0) { const ret = arg0.length; return ret; }; - imports.wbg.__wbg_new_07b483f72211fd66 = function() { - const ret = new Object(); - return ret; - }; - imports.wbg.__wbg_new_58353953ad2097cc = function() { - const ret = new Array(); + imports.wbg.__wbg_new_638ebfaedbf32a5e = function(arg0) { + const ret = new Uint8Array(arg0); return ret; }; imports.wbg.__wbg_new_8a6f238a6ece86ea = function() { const ret = new Error(); return ret; }; - imports.wbg.__wbg_new_e52b3efaaa774f96 = function(arg0) { - const ret = new Uint8Array(arg0); - return ret; - }; - imports.wbg.__wbg_next_8bb824d217961b5d = function(arg0) { + imports.wbg.__wbg_next_5b3530e612fde77d = function(arg0) { const ret = arg0.next; return ret; }; - imports.wbg.__wbg_next_e2da48d8fff7439a = function() { return handleError(function (arg0) { + imports.wbg.__wbg_next_692e82279131b03c = function() { return handleError(function (arg0) { const ret = arg0.next(); return ret; }, arguments) }; - imports.wbg.__wbg_set_3f1d0b984ed272ed = function(arg0, arg1, arg2) { - arg0[arg1] = arg2; - }; - imports.wbg.__wbg_set_7422acbe992d64ab = function(arg0, arg1, arg2) { - arg0[arg1 >>> 0] = arg2; - }; - 
imports.wbg.__wbg_set_fe4e79d1ed3b0e9b = function(arg0, arg1, arg2) { - arg0.set(arg1, arg2 >>> 0); + imports.wbg.__wbg_prototypesetcall_3d4a26c1ed734349 = function(arg0, arg1, arg2) { + Uint8Array.prototype.set.call(getArrayU8FromWasm0(arg0, arg1), arg2); }; imports.wbg.__wbg_stack_0ed75d68575b0f3c = function(arg0, arg1) { const ret = arg1.stack; @@ -472,90 +503,64 @@ function __wbg_get_imports() { getDataViewMemory0().setInt32(arg0 + 4 * 1, len1, true); getDataViewMemory0().setInt32(arg0 + 4 * 0, ptr1, true); }; - imports.wbg.__wbg_value_17b896954e14f896 = function(arg0) { + imports.wbg.__wbg_value_dd9372230531eade = function(arg0) { const ret = arg0.value; return ret; }; - imports.wbg.__wbindgen_as_number = function(arg0) { - const ret = +arg0; - return ret; - }; - imports.wbg.__wbindgen_bigint_from_u64 = function(arg0) { - const ret = BigInt.asUintN(64, arg0); - return ret; - }; - imports.wbg.__wbindgen_bigint_get_as_i64 = function(arg0, arg1) { + imports.wbg.__wbg_wbindgenbigintgetasi64_ac743ece6ab9bba1 = function(arg0, arg1) { const v = arg1; const ret = typeof(v) === 'bigint' ? v : undefined; getDataViewMemory0().setBigInt64(arg0 + 8 * 1, isLikeNone(ret) ? BigInt(0) : ret, true); getDataViewMemory0().setInt32(arg0 + 4 * 0, !isLikeNone(ret), true); }; - imports.wbg.__wbindgen_boolean_get = function(arg0) { + imports.wbg.__wbg_wbindgenbooleanget_3fe6f642c7d97746 = function(arg0) { const v = arg0; - const ret = typeof(v) === 'boolean' ? (v ? 1 : 0) : 2; - return ret; + const ret = typeof(v) === 'boolean' ? v : undefined; + return isLikeNone(ret) ? 0xFFFFFF : ret ? 
1 : 0; }; - imports.wbg.__wbindgen_debug_string = function(arg0, arg1) { + imports.wbg.__wbg_wbindgendebugstring_99ef257a3ddda34d = function(arg0, arg1) { const ret = debugString(arg1); const ptr1 = passStringToWasm0(ret, wasm.__wbindgen_malloc, wasm.__wbindgen_realloc); const len1 = WASM_VECTOR_LEN; getDataViewMemory0().setInt32(arg0 + 4 * 1, len1, true); getDataViewMemory0().setInt32(arg0 + 4 * 0, ptr1, true); }; - imports.wbg.__wbindgen_in = function(arg0, arg1) { + imports.wbg.__wbg_wbindgenin_d7a1ee10933d2d55 = function(arg0, arg1) { const ret = arg0 in arg1; return ret; }; - imports.wbg.__wbindgen_init_externref_table = function() { - const table = wasm.__wbindgen_export_4; - const offset = table.grow(4); - table.set(0, undefined); - table.set(offset + 0, undefined); - table.set(offset + 1, null); - table.set(offset + 2, true); - table.set(offset + 3, false); - ; - }; - imports.wbg.__wbindgen_is_bigint = function(arg0) { + imports.wbg.__wbg_wbindgenisbigint_ecb90cc08a5a9154 = function(arg0) { const ret = typeof(arg0) === 'bigint'; return ret; }; - imports.wbg.__wbindgen_is_function = function(arg0) { + imports.wbg.__wbg_wbindgenisfunction_8cee7dce3725ae74 = function(arg0) { const ret = typeof(arg0) === 'function'; return ret; }; - imports.wbg.__wbindgen_is_object = function(arg0) { + imports.wbg.__wbg_wbindgenisobject_307a53c6bd97fbf8 = function(arg0) { const val = arg0; const ret = typeof(val) === 'object' && val !== null; return ret; }; - imports.wbg.__wbindgen_is_undefined = function(arg0) { + imports.wbg.__wbg_wbindgenisundefined_c4b71d073b92f3c5 = function(arg0) { const ret = arg0 === undefined; return ret; }; - imports.wbg.__wbindgen_jsval_eq = function(arg0, arg1) { + imports.wbg.__wbg_wbindgenjsvaleq_e6f2ad59ccae1b58 = function(arg0, arg1) { const ret = arg0 === arg1; return ret; }; - imports.wbg.__wbindgen_jsval_loose_eq = function(arg0, arg1) { + imports.wbg.__wbg_wbindgenjsvallooseeq_9bec8c9be826bed1 = function(arg0, arg1) { const ret = arg0 == 
arg1; return ret; }; - imports.wbg.__wbindgen_memory = function() { - const ret = wasm.memory; - return ret; - }; - imports.wbg.__wbindgen_number_get = function(arg0, arg1) { + imports.wbg.__wbg_wbindgennumberget_f74b4c7525ac05cb = function(arg0, arg1) { const obj = arg1; const ret = typeof(obj) === 'number' ? obj : undefined; getDataViewMemory0().setFloat64(arg0 + 8 * 1, isLikeNone(ret) ? 0 : ret, true); getDataViewMemory0().setInt32(arg0 + 4 * 0, !isLikeNone(ret), true); }; - imports.wbg.__wbindgen_number_new = function(arg0) { - const ret = arg0; - return ret; - }; - imports.wbg.__wbindgen_string_get = function(arg0, arg1) { + imports.wbg.__wbg_wbindgenstringget_0f16a6ddddef376f = function(arg0, arg1) { const obj = arg1; const ret = typeof(obj) === 'string' ? obj : undefined; var ptr1 = isLikeNone(ret) ? 0 : passStringToWasm0(ret, wasm.__wbindgen_malloc, wasm.__wbindgen_realloc); @@ -563,12 +568,28 @@ function __wbg_get_imports() { getDataViewMemory0().setInt32(arg0 + 4 * 1, len1, true); getDataViewMemory0().setInt32(arg0 + 4 * 0, ptr1, true); }; - imports.wbg.__wbindgen_string_new = function(arg0, arg1) { + imports.wbg.__wbg_wbindgenthrow_451ec1a8469d7eb6 = function(arg0, arg1) { + throw new Error(getStringFromWasm0(arg0, arg1)); + }; + imports.wbg.__wbindgen_cast_2241b6af4c4b2941 = function(arg0, arg1) { + // Cast intrinsic for `Ref(String) -> Externref`. const ret = getStringFromWasm0(arg0, arg1); return ret; }; - imports.wbg.__wbindgen_throw = function(arg0, arg1) { - throw new Error(getStringFromWasm0(arg0, arg1)); + imports.wbg.__wbindgen_cast_4625c577ab2ec9ee = function(arg0) { + // Cast intrinsic for `U64 -> Externref`. 
+ const ret = BigInt.asUintN(64, arg0); + return ret; + }; + imports.wbg.__wbindgen_init_externref_table = function() { + const table = wasm.__wbindgen_export_4; + const offset = table.grow(4); + table.set(0, undefined); + table.set(offset + 0, undefined); + table.set(offset + 1, null); + table.set(offset + 2, true); + table.set(offset + 3, false); + ; }; return imports; diff --git a/docs/pkg/pylate_rs_bg.wasm b/docs/pkg/pylate_rs_bg.wasm index e256ffa..6112480 100644 Binary files a/docs/pkg/pylate_rs_bg.wasm and b/docs/pkg/pylate_rs_bg.wasm differ diff --git a/docs/pkg/pylate_rs_bg.wasm.d.ts b/docs/pkg/pylate_rs_bg.wasm.d.ts index 96243de..ded9f30 100644 --- a/docs/pkg/pylate_rs_bg.wasm.d.ts +++ b/docs/pkg/pylate_rs_bg.wasm.d.ts @@ -1,11 +1,11 @@ /* tslint:disable */ /* eslint-disable */ export const memory: WebAssembly.Memory; -export const colbert_from_bytes: (a: number, b: number, c: number, d: number, e: number, f: number, g: number, h: number, i: number, j: number, k: number, l: number, m: number, n: number, o: number) => [number, number, number]; -export const colbert_encode: (a: number, b: any, c: number) => [number, number, number]; -export const colbert_similarity: (a: number, b: any) => [number, number, number]; -export const colbert_raw_similarity_matrix: (a: number, b: any) => [number, number, number]; -export const hierarchical_pooling: (a: any) => [number, number, number]; +export const colbert_from_bytes: (a: number, b: number, c: any, d: number, e: number, f: number, g: number, h: number, i: number, j: number, k: number, l: number) => [number, number, number]; +export const colbert_encode: (a: number, b: any, c: number) => [number, number, number, number]; +export const colbert_similarity: (a: number, b: any) => [number, number, number, number]; +export const colbert_raw_similarity_matrix: (a: number, b: any) => [number, number, number, number]; +export const hierarchical_pooling: (a: any) => [number, number, number, number]; export const 
__wbg_colbert_free: (a: number, b: number) => void; export const __wbindgen_malloc: (a: number, b: number) => number; export const __wbindgen_realloc: (a: number, b: number, c: number, d: number) => number; diff --git a/src/builder.rs b/src/builder.rs index 439cf06..758d211 100644 --- a/src/builder.rs +++ b/src/builder.rs @@ -104,22 +104,46 @@ impl TryFrom for ColBERT { let device = builder.device.unwrap_or(Device::Cpu); let local_path = PathBuf::from(&builder.repo_id); + + // Auto-detect all *_Dense directories (1_Dense, 2_Dense, 3_Dense, etc.) let ( tokenizer_path, weights_path, config_path, st_config_path, - dense_config_path, - dense_weights_path, + projection_layer_paths, special_tokens_map_path, ) = if local_path.is_dir() { + let mut layer_paths = Vec::new(); + + // Scan for numbered Dense layers (1_Dense, 2_Dense, ...) until we don't find one + let mut i = 1; + loop { + let layer_dir = format!("{}_Dense", i); + let config_path = local_path.join(&layer_dir).join("config.json"); + let weights_path = local_path.join(&layer_dir).join("model.safetensors"); + + if config_path.exists() && weights_path.exists() { + layer_paths.push((config_path, weights_path, layer_dir)); + i += 1; + } else if i == 1 { + // 1_Dense is required + return Err(ColbertError::Io(std::io::Error::new( + std::io::ErrorKind::NotFound, + "1_Dense layer is required but not found", + ))); + } else { + // Stop scanning after first missing layer + break; + } + } + ( local_path.join("tokenizer.json"), local_path.join("model.safetensors"), local_path.join("config.json"), local_path.join("config_sentence_transformers.json"), - local_path.join("1_Dense/config.json"), - local_path.join("1_Dense/model.safetensors"), + layer_paths, local_path.join("special_tokens_map.json"), ) } else { @@ -129,13 +153,39 @@ impl TryFrom for ColBERT { RepoType::Model, "main".to_string(), )); + + let mut layer_paths = Vec::new(); + + // Scan for numbered Dense layers from HuggingFace Hub until we don't find one + let mut i 
= 1; + loop { + let layer_dir = format!("{}_Dense", i); + let config_file = format!("{}/config.json", layer_dir); + let weights_file = format!("{}/model.safetensors", layer_dir); + + match (repo.get(&config_file), repo.get(&weights_file)) { + (Ok(config_path), Ok(weights_path)) => { + layer_paths.push((config_path, weights_path, layer_dir)); + i += 1; + }, + _ => { + if i == 1 { + return Err(ColbertError::Operation( + "1_Dense layer is required but not found".into(), + )); + } else { + break; + } + }, + } + } + ( repo.get("tokenizer.json")?, repo.get("model.safetensors")?, repo.get("config.json")?, repo.get("config_sentence_transformers.json")?, - repo.get("1_Dense/config.json")?, - repo.get("1_Dense/model.safetensors")?, + layer_paths, repo.get("special_tokens_map.json")?, ) }; @@ -146,8 +196,6 @@ impl TryFrom for ColBERT { &weights_path, &config_path, &st_config_path, - &dense_config_path, - &dense_weights_path, &special_tokens_map_path, ] { if !path.exists() { @@ -163,8 +211,15 @@ impl TryFrom for ColBERT { let weights_bytes = fs::read(weights_path)?; let config_bytes = fs::read(config_path)?; let st_config_bytes = fs::read(st_config_path)?; - let dense_config_bytes = fs::read(dense_config_path)?; - let dense_weights_bytes = fs::read(dense_weights_path)?; + + // Read all projection layer files + let mut projection_layers_data = Vec::new(); + for (config_path, weights_path, layer_name) in projection_layer_paths { + let config_bytes = fs::read(config_path)?; + let weights_bytes = fs::read(weights_path)?; + projection_layers_data.push((config_bytes, weights_bytes, layer_name)); + } + let special_tokens_map_bytes = fs::read(special_tokens_map_path)?; let st_config: serde_json::Value = serde_json::from_slice(&st_config_bytes)?; @@ -191,11 +246,9 @@ impl TryFrom for ColBERT { .to_string() }); - let final_do_query_expansion = builder.do_query_expansion.unwrap_or_else(|| { - st_config["do_query_expansion"] - .as_bool() - .unwrap_or(true) - }); + let 
final_do_query_expansion = builder + .do_query_expansion + .unwrap_or_else(|| st_config["do_query_expansion"].as_bool().unwrap_or(true)); let final_attend_to_expansion_tokens = builder.attend_to_expansion_tokens.unwrap_or_else(|| { @@ -212,10 +265,9 @@ impl TryFrom for ColBERT { ColBERT::new( weights_bytes, - dense_weights_bytes, + projection_layers_data, tokenizer_bytes, config_bytes, - dense_config_bytes, final_query_prefix, final_document_prefix, mask_token, diff --git a/src/model.rs b/src/model.rs index 8547943..58bc038 100644 --- a/src/model.rs +++ b/src/model.rs @@ -9,6 +9,107 @@ use candle_nn::{Linear, Module, VarBuilder}; use candle_transformers::models::bert::{BertModel, Config as BertConfig}; use tokenizers::Tokenizer; +/// Activation functions that can be applied after linear layers. +#[derive(Debug, Clone, Copy)] +pub(crate) enum Activation { + Identity, + Relu, + Gelu, + GeluErf, + Tanh, + Silu, +} + +impl Activation { + /// Parse activation function from PyTorch class name + fn from_pytorch_name(name: &str) -> Result { + match name { + "torch.nn.modules.linear.Identity" => Ok(Activation::Identity), + "torch.nn.modules.activation.ReLU" => Ok(Activation::Relu), + "torch.nn.modules.activation.GELU" => Ok(Activation::GeluErf), + "torch.nn.modules.activation.Tanh" => Ok(Activation::Tanh), + "torch.nn.modules.activation.SiLU" => Ok(Activation::Silu), + _ => Err(ColbertError::Operation(format!( + "Unsupported activation function: {}", + name + ))), + } + } + + /// Apply the activation function to a tensor + fn apply(&self, tensor: &Tensor) -> Result { + match self { + Activation::Identity => Ok(tensor.clone()), + Activation::Relu => tensor.relu(), + Activation::Gelu => tensor.gelu(), + Activation::GeluErf => tensor.gelu_erf(), + Activation::Tanh => tensor.tanh(), + Activation::Silu => tensor.silu(), + } + } +} + +/// A single projection layer with its associated activation function. 
+#[derive(Clone)] +pub(crate) struct ProjectionLayer { + linear: Linear, + activation: Activation, +} + +impl ProjectionLayer { + /// Creates a new projection layer from config and weights + fn new( + config_bytes: &[u8], + weights_bytes: Vec, + layer_name: &str, + device: &Device, + ) -> Result { + let config: serde_json::Value = serde_json::from_slice(config_bytes)?; + let vb = VarBuilder::from_buffered_safetensors(weights_bytes, DType::F32, device)?; + + let in_features = config["in_features"] + .as_u64() + .map(|v| v as usize) + .ok_or_else(|| { + ColbertError::Operation(format!("Missing 'in_features' in {} config", layer_name)) + })?; + + let out_features = config["out_features"] + .as_u64() + .map(|v| v as usize) + .ok_or_else(|| { + ColbertError::Operation(format!("Missing 'out_features' in {} config", layer_name)) + })?; + + // Parse activation function from config, default to Identity if not specified + let activation = config["activation_function"] + .as_str() + .map(Activation::from_pytorch_name) + .transpose()? 
+ .unwrap_or(Activation::Identity); + + let linear = candle_nn::linear_no_bias(in_features, out_features, vb.pp("linear"))?; + + Ok(Self { linear, activation }) + } + + /// Apply the projection layer (linear + activation) to a tensor + fn forward(&self, tensor: &Tensor) -> Result { + let output = self.linear.forward(tensor)?; + self.activation.apply(&output) + } + + /// Get output dimension of this layer + fn out_features(&self) -> usize { + self.linear.weight().dims()[0] + } + + /// Get input dimension of this layer + fn in_features(&self) -> usize { + self.linear.weight().dims()[1] + } +} + #[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))] use rayon::prelude::*; @@ -53,7 +154,7 @@ impl BaseModel { #[cfg_attr(feature = "wasm", wasm_bindgen)] pub struct ColBERT { pub(crate) model: BaseModel, - pub(crate) linear: Linear, + pub(crate) projection_layers: Vec, pub(crate) tokenizer: Tokenizer, pub(crate) mask_token_id: u32, pub(crate) mask_token: String, @@ -71,12 +172,14 @@ pub struct ColBERT { impl ColBERT { /// Creates a new instance of the `ColBERT` model from byte buffers. 
+ /// + /// # Arguments + /// * `projection_layers_data` - Vec of (config_bytes, weights_bytes, layer_name) tuples for each Dense layer pub fn new( weights: Vec, - dense_weights: Vec, + projection_layers_data: Vec<(Vec, Vec, String)>, tokenizer_bytes: Vec, config_bytes: Vec, - dense_config_bytes: Vec, query_prefix: String, document_prefix: String, mask_token: String, @@ -92,7 +195,7 @@ impl ColBERT { let config_value: serde_json::Value = serde_json::from_slice(&config_bytes)?; let architectures = config_value["architectures"] .as_array() - .and_then(|arr| arr.get(0)) + .and_then(|arr| arr.first()) .and_then(|v| v.as_str()) .ok_or_else(|| { ColbertError::Operation("Missing or invalid 'architectures' in config.json".into()) @@ -117,7 +220,6 @@ impl ColBERT { }, }; - let dense_config: serde_json::Value = serde_json::from_slice(&dense_config_bytes)?; let tokenizer = Tokenizer::from_bytes(&tokenizer_bytes)?; let mask_token_id = tokenizer.token_to_id(mask_token.as_str()).ok_or_else(|| { @@ -127,21 +229,38 @@ impl ColBERT { )) })?; - let dense_vb = VarBuilder::from_buffered_safetensors(dense_weights, DType::F32, device)?; - let in_features = dense_config["in_features"] - .as_u64() - .map(|v| v as usize) - .ok_or_else(|| { - ColbertError::Operation("Missing 'in_features' in dense config".into()) - })?; - let out_features = dense_config["out_features"] - .as_u64() - .map(|v| v as usize) - .ok_or_else(|| { - ColbertError::Operation("Missing 'out_features' in dense config".into()) - })?; + // Validate that we have at least one projection layer + if projection_layers_data.is_empty() { + return Err(ColbertError::Operation( + "At least one projection layer is required".into(), + )); + } - let linear = candle_nn::linear_no_bias(in_features, out_features, dense_vb.pp("linear"))?; + // Build projection layers and validate dimensions + let mut projection_layers = Vec::with_capacity(projection_layers_data.len()); + let mut prev_out_features: Option = None; + + for (i, (config_bytes, 
weights_bytes, layer_name)) in + projection_layers_data.into_iter().enumerate() + { + let layer = ProjectionLayer::new(&config_bytes, weights_bytes, &layer_name, device)?; + + // Validate dimension compatibility with previous layer + if let Some(prev_out) = prev_out_features { + if prev_out != layer.in_features() { + return Err(ColbertError::Operation(format!( + "Dimension mismatch between layer {} and {}: output {} != input {}", + i - 1, + layer_name, + prev_out, + layer.in_features() + ))); + } + } + + prev_out_features = Some(layer.out_features()); + projection_layers.push(layer); + } // If do_query_expansion is false, attend_to_expansion_tokens should also be false let final_attend_to_expansion_tokens = if !do_query_expansion { @@ -152,10 +271,10 @@ impl ColBERT { Ok(Self { model, - linear, + projection_layers, tokenizer, mask_token_id, - mask_token: mask_token, + mask_token, query_prefix, document_prefix, do_query_expansion, @@ -268,7 +387,12 @@ impl ColBERT { let token_embeddings = self.model .forward(&token_ids, &attention_mask, &token_type_ids)?; - let projected_embeddings = self.linear.forward(&token_embeddings)?; + + // Apply all projection layers sequentially + let mut projected_embeddings = token_embeddings; + for layer in &self.projection_layers { + projected_embeddings = layer.forward(&projected_embeddings)?; + } if !self.do_query_expansion || !is_query { // Apply filtering, normalization, and padding. @@ -298,7 +422,11 @@ impl ColBERT { self.model .forward(&token_ids, &attention_mask, &token_type_ids)?; - let projected_embeddings = self.linear.forward(&token_embeddings)?; + // Apply all projection layers sequentially + let mut projected_embeddings = token_embeddings; + for layer in &self.projection_layers { + projected_embeddings = layer.forward(&projected_embeddings)?; + } let final_embeddings = if !self.do_query_expansion || !is_query { // Apply filtering, normalization, and padding. 
diff --git a/src/wasm.rs b/src/wasm.rs index ede1754..1662dcc 100644 --- a/src/wasm.rs +++ b/src/wasm.rs @@ -2,7 +2,7 @@ use crate::{ error::ColbertError, model::ColBERT, pooling::hierarchical_pooling, - types::{EncodeInput, EncodeOutput, RawSimilarityOutput, SimilarityInput}, + types::{EncodeInput, EncodeOutput, RawSimilarityOutput, Similarities, SimilarityInput}, }; use candle_core::{Device, IndexOp, Tensor}; use wasm_bindgen::prelude::*; @@ -10,19 +10,54 @@ use wasm_bindgen::prelude::*; #[wasm_bindgen] impl ColBERT { /// WASM-compatible constructor. + /// + /// # Arguments + /// * `projection_layers` - JsValue containing array of {config: Uint8Array, weights: Uint8Array, name: string} #[wasm_bindgen(constructor)] pub fn from_bytes( weights: Vec, - dense_weights: Vec, + projection_layers: JsValue, tokenizer: Vec, config: Vec, sentence_transformers_config: Vec, - dense_config: Vec, special_tokens_map: Vec, batch_size: Option, ) -> Result { console_error_panic_hook::set_once(); + // Parse projection layers from JsValue + // Expected format: [{config: Uint8Array, weights: Uint8Array, name: string}, ...] 
+ use js_sys::{Array, Object, Reflect, Uint8Array}; + + let layers_array = Array::from(&projection_layers); + let mut projection_layers_data = Vec::new(); + + for i in 0..layers_array.length() { + let layer = layers_array.get(i); + let layer_obj = Object::from(layer); + + // Extract config as Uint8Array + let config_js = Reflect::get(&layer_obj, &JsValue::from_str("config")) + .map_err(|e| JsValue::from_str(&format!("Failed to get config: {:?}", e)))?; + let config_array = Uint8Array::from(config_js); + let config_bytes = config_array.to_vec(); + + // Extract weights as Uint8Array + let weights_js = Reflect::get(&layer_obj, &JsValue::from_str("weights")) + .map_err(|e| JsValue::from_str(&format!("Failed to get weights: {:?}", e)))?; + let weights_array = Uint8Array::from(weights_js); + let weights_bytes = weights_array.to_vec(); + + // Extract name as string + let name_js = Reflect::get(&layer_obj, &JsValue::from_str("name")) + .map_err(|e| JsValue::from_str(&format!("Failed to get name: {:?}", e)))?; + let layer_name = name_js + .as_string() + .ok_or_else(|| JsValue::from_str("Layer name must be a string"))?; + + projection_layers_data.push((config_bytes, weights_bytes, layer_name)); + } + let st_config: serde_json::Value = serde_json::from_slice(&sentence_transformers_config).map_err(ColbertError::from)?; @@ -37,9 +72,7 @@ impl ColBERT { .as_str() .unwrap_or("[D]") .to_string(); - let do_query_expansion = st_config["do_query_expansion"] - .as_bool() - .unwrap_or(true); + let do_query_expansion = st_config["do_query_expansion"].as_bool().unwrap_or(true); let attend_to_expansion_tokens = st_config["attend_to_expansion_tokens"] .as_bool() .unwrap_or(false); @@ -55,10 +88,9 @@ impl ColBERT { Self::new( weights, - dense_weights, + projection_layers_data, tokenizer, config, - dense_config, query_prefix, document_prefix, mask_token, @@ -74,7 +106,7 @@ impl ColBERT { /// WASM-compatible version of the `encode` method. 
#[wasm_bindgen(js_name = "encode")] - pub fn encode_wasm(&mut self, input: JsValue, is_query: bool) -> Result { + pub fn encode_wasm(&mut self, input: JsValue, is_query: bool) -> Result { let params: EncodeInput = serde_wasm_bindgen::from_value(input)?; // Override model's batch_size if provided in the input if let Some(batch_size) = params.batch_size { @@ -87,22 +119,27 @@ impl ColBERT { let result = EncodeOutput { embeddings: embeddings_data, }; - serde_wasm_bindgen::to_value(&result).map_err(Into::into) + // Return as JSON string to avoid serde-wasm-bindgen issues + serde_json::to_string(&result) + .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e))) } /// WASM-compatible version of the `similarity` method. #[wasm_bindgen(js_name = "similarity")] - pub fn similarity_wasm(&mut self, input: JsValue) -> Result { + pub fn similarity_wasm(&mut self, input: JsValue) -> Result { let params: SimilarityInput = serde_wasm_bindgen::from_value(input)?; let queries_embeddings = self.encode(&params.queries, true)?; let documents_embeddings = self.encode(&params.documents, false)?; let result = self.similarity(&queries_embeddings, &documents_embeddings)?; - serde_wasm_bindgen::to_value(&result).map_err(Into::into) + + // Return as JSON string to avoid serde-wasm-bindgen issues + serde_json::to_string(&result) + .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e))) } /// WASM-compatible method to get the raw similarity matrix and tokens.
#[wasm_bindgen(js_name = "raw_similarity_matrix")] - pub fn raw_similarity_matrix_wasm(&mut self, input: JsValue) -> Result { + pub fn raw_similarity_matrix_wasm(&mut self, input: JsValue) -> Result { let params: SimilarityInput = serde_wasm_bindgen::from_value(input)?; let (query_ids_tensor, _, _) = self.tokenize(&params.queries, true)?; @@ -151,7 +188,9 @@ impl ColBERT { document_tokens, }; - serde_wasm_bindgen::to_value(&result).map_err(Into::into) + // Return as JSON string to avoid serde-wasm-bindgen issues + serde_json::to_string(&result) + .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e))) } } @@ -165,13 +204,14 @@ struct PoolingInput { /// WASM-compatible version of the `hierarchical_pooling` function. #[cfg(feature = "wasm")] #[wasm_bindgen(js_name = hierarchical_pooling)] -pub fn hierarchical_pooling_wasm(input: JsValue) -> Result { +pub fn hierarchical_pooling_wasm(input: JsValue) -> Result { console_error_panic_hook::set_once(); let params: PoolingInput = serde_wasm_bindgen::from_value(input)?; if params.embeddings.is_empty() { let result = EncodeOutput { embeddings: vec![] }; - return serde_wasm_bindgen::to_value(&result).map_err(Into::into); + return serde_json::to_string(&result) + .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e))); } let batch_size = params.embeddings.len(); @@ -181,7 +221,8 @@ pub fn hierarchical_pooling_wasm(input: JsValue) -> Result { let result = EncodeOutput { embeddings: params.embeddings, }; - return serde_wasm_bindgen::to_value(&result).map_err(Into::into); + return serde_json::to_string(&result) + .map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e))); } let embedding_dim = params.embeddings[0][0].len(); @@ -202,5 +243,7 @@ pub fn hierarchical_pooling_wasm(input: JsValue) -> Result { embeddings: embeddings_data, }; - serde_wasm_bindgen::to_value(&result).map_err(Into::into) + // Return as JSON string to avoid serde-wasm-bindgen issues + serde_json::to_string(&result) +
.map_err(|e| JsValue::from_str(&format!("Serialization error: {}", e))) }