Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ anyhow = "1.0.98"
wasm-bindgen = { version = "0.2.100", optional = true }
serde-wasm-bindgen = { version = "0.6.5", optional = true }
console_error_panic_hook = { version = "0.1.7", optional = true }
getrandom = { version = "0.3", features = ["wasm_js"], optional = true }
js-sys = { version = "0.3", optional = true }
hf-hub = { version = "0.4.3", optional = true, default-features = false, features = [
"ureq",
"rustls-tls",
Expand All @@ -35,14 +35,17 @@ pyo3 = { version = "0.25.1", optional = true, features = ["extension-module"] }
ndarray = { version = "0.16.1", optional = true }
numpy = { version = "0.25.0", optional = true }

[target.'cfg(target_family = "wasm")'.dependencies]
getrandom = { version = "0.2", features = ["js"] }

[features]
default = ["tokenizers/onig", "hf-hub"]

wasm = [
"dep:wasm-bindgen",
"dep:serde-wasm-bindgen",
"dep:console_error_panic_hook",
"dep:getrandom",
"dep:js-sys",
"tokenizers/unstable_wasm",
]

Expand Down
159 changes: 142 additions & 17 deletions docs/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -892,6 +892,7 @@ <h3 class="text-base font-semibold text-gray-800 mb-3">Pooled Embeddings</h3>
"lightonai/GTE-ModernColBERT-v1",
"lightonai/colbertv2.0",
"lightonai/Reason-ModernColBERT",
"mixedbread-ai/mxbai-edge-colbert-v0-17m",
],
REQUIRED_FILES: [
'tokenizer.json', 'model.safetensors', 'config.json',
Expand Down Expand Up @@ -1006,6 +1007,28 @@ <h3 class="text-base font-semibold text-gray-800 mb-3">Pooled Embeddings</h3>
pooledScore: document.getElementById('pooled-score'),
},

/**
* Retry helper with exponential backoff
*/
async retryWithBackoff(fn, maxRetries = 3, baseDelay = 100) {
for (let i = 0; i < maxRetries; i++) {
try {
return await fn();
} catch (error) {
const isLastAttempt = i === maxRetries - 1;
const isTableGrowError = error.message && error.message.includes('Table.grow');

if (isLastAttempt || !isTableGrowError) {
throw error;
}

const delay = baseDelay * Math.pow(2, i);
console.warn(`WASM init failed (attempt ${i + 1}/${maxRetries}), retrying in ${delay}ms...`, error.message);
await new Promise(resolve => setTimeout(resolve, delay));
}
}
},

/**
* Initializes the application, sets up event listeners, and loads the initial model.
*/
Expand All @@ -1016,11 +1039,14 @@ <h3 class="text-base font-semibold text-gray-800 mb-3">Pooled Embeddings</h3>
hljs.highlightAll();

try {
await init();
this.loadModel(this.state.currentModelName);
await this.retryWithBackoff(() => init(), 5, 200);
console.log("✅ WASM module initialized successfully");
// Add a small delay to ensure WASM is fully ready
await new Promise(resolve => setTimeout(resolve, 100));
await this.loadModel(this.state.currentModelName);
} catch (e) {
console.error("Fatal: Failed to initialize WASM module.", e);
this.updateAllStatuses(`Error: Could not initialize the application.`, 'error');
console.error("Fatal: Failed to initialize WASM module after retries.", e);
this.updateAllStatuses(`Error: Could not initialize the application. Please refresh the page.`, 'error');
}
},

Expand Down Expand Up @@ -1149,9 +1175,64 @@ <h3 class="text-base font-semibold text-gray-800 mb-3">Pooled Embeddings</h3>
console.warn(`Unable to download the model.`, e);
}
this.updateAllStatuses(`Initializing ${modelRepo}...`, 'loading');

// Try to fetch all Dense layers (1_Dense, 2_Dense, 3_Dense, ...) until we don't find one
const projectionLayers = [];
let i = 1;
while (true) {
try {
const layerName = `${i}_Dense`;
const weightsResponse = await fetch(`https://huggingface.co/${modelRepo}/resolve/main/${layerName}/model.safetensors`);
const configResponse = await fetch(`https://huggingface.co/${modelRepo}/resolve/main/${layerName}/config.json`);

if (weightsResponse.ok && configResponse.ok) {
const weights = new Uint8Array(await weightsResponse.arrayBuffer());
const config = new Uint8Array(await configResponse.arrayBuffer());
projectionLayers.push({
config: config,
weights: weights,
name: layerName
});
console.log(`✅ Found ${layerName} layer`);
i++;
} else if (i === 1) {
throw new Error("1_Dense layer is required but not found");
} else {
// Stop at first missing layer
break;
}
} catch (e) {
if (i === 1) {
throw new Error("1_Dense layer is required: " + e.message);
}
break;
}
}

const [tokenizer, model, config, stConfig, dense, denseConfig, tokensConfig] = modelFiles;
this.state.colbertModel = new ColBERT(model, dense, tokenizer, config, stConfig, denseConfig, tokensConfig, 32);

console.log("Creating ColBERT instance with:", {
hasModel: !!model,
projectionLayerCount: projectionLayers.length,
hasTokenizer: !!tokenizer,
hasConfig: !!config,
hasStConfig: !!stConfig,
hasTokensConfig: !!tokensConfig
});

// New constructor signature: (weights, projection_layers, tokenizer, config,
// sentence_transformers_config, special_tokens_map, batch_size)
this.state.colbertModel = new ColBERT(
model,
projectionLayers, // Array of {config, weights, name} objects
tokenizer,
config,
stConfig,
tokensConfig,
32
);

console.log("✅ ColBERT instance created successfully");
this.updateAllStatuses(`✅ ${modelRepo} loaded successfully.`, 'success');
this.runAllDemos();
} catch (error) {
Expand All @@ -1167,15 +1248,34 @@ <h3 class="text-base font-semibold text-gray-800 mb-3">Pooled Embeddings</h3>
* Runs all interactive demos after a model is loaded.
*/
runAllDemos() {
this.handleSimilarityDemo();
this.handleMatrixDemo();
this.handlePoolingDemo();
if (!this.state.colbertModel) {
console.warn("Model not loaded yet, skipping demos");
return;
}

try {
this.handleSimilarityDemo();
} catch (e) {
console.error("Similarity demo failed:", e);
}

try {
this.handleMatrixDemo();
} catch (e) {
console.error("Matrix demo failed:", e);
}

try {
this.handlePoolingDemo();
} catch (e) {
console.error("Pooling demo failed:", e);
}
},

/**
* SIMILARITY DEMO: Calculates and displays similarity scores.
*/
handleSimilarityDemo() {
async handleSimilarityDemo() {
if (!this.state.colbertModel) return;
const query = this.ui.queryInput.value;
const documents = this.ui.documentInput.value.split('\n').filter(doc => doc.trim());
Expand All @@ -1187,13 +1287,34 @@ <h3 class="text-base font-semibold text-gray-800 mb-3">Pooled Embeddings</h3>
}

try {
const {
data
} = this.state.colbertModel.similarity({
console.log("Running similarity with:", { query, documentsCount: documents.length });

// Call the similarity method - now returns JSON string
let resultString = this.state.colbertModel.similarity({
queries: [query],
documents
});
const scores = data[0];

console.log("Similarity result type:", typeof resultString);
console.log("Similarity result string:", resultString);

// Parse JSON string to object
let result;
if (typeof resultString === 'string') {
result = JSON.parse(resultString);
} else {
result = resultString; // Fallback if it's already an object
}

console.log("Parsed result:", result);

if (!result || !result.data || !result.data[0]) {
console.error("Invalid similarity result:", result);
this.ui.resultsList.innerHTML = `<p class="text-red-500 p-2">Error: Model returned invalid results. Check console for details.</p>`;
return;
}

const scores = result.data[0];
const results = documents.map((doc, i) => ({
text: doc,
score: scores[i]
Expand Down Expand Up @@ -1238,10 +1359,11 @@ <h3 class="text-base font-semibold text-gray-800 mb-3">Pooled Embeddings</h3>
}

try {
const rawResult = this.state.colbertModel.raw_similarity_matrix({
const rawResultString = this.state.colbertModel.raw_similarity_matrix({
queries: [query],
documents: [doc]
});
const rawResult = JSON.parse(rawResultString);
const {
matrix,
queryTokens,
Expand Down Expand Up @@ -1426,23 +1548,26 @@ <h3 class="text-base font-semibold text-gray-800 mb-3">Pooled Embeddings</h3>
}

try {
const originalDocResult = this.state.colbertModel.encode({
const originalDocResultString = this.state.colbertModel.encode({
sentences: [docText]
}, false);
const originalDocResult = JSON.parse(originalDocResultString);
const originalDocEmbeddings = originalDocResult.embeddings[0] || [];
this.renderEmbeddingBlocks(this.ui.originalEmbeddingsVis, originalDocEmbeddings.length, this.ui.originalTokenCount);

const pooledResult = hierarchical_pooling({
const pooledResultString = hierarchical_pooling({
embeddings: [originalDocEmbeddings],
pool_factor: poolFactor
});
const pooledResult = JSON.parse(pooledResultString);
const pooledDocEmbeddings = pooledResult.embeddings[0] || [];
this.renderEmbeddingBlocks(this.ui.pooledEmbeddingsVis, pooledDocEmbeddings.length, this.ui.pooledTokenCount);

if (queryText.trim()) {
const queryResult = this.state.colbertModel.encode({
const queryResultString = this.state.colbertModel.encode({
sentences: [queryText]
}, false);
const queryResult = JSON.parse(queryResultString);
const queryEmbeddings = queryResult.embeddings[0] || [];

const calculateScore = (qVecs, dVecs) => {
Expand Down
24 changes: 14 additions & 10 deletions docs/pkg/pylate_rs.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
/**
* WASM-compatible version of the `hierarchical_pooling` function.
*/
export function hierarchical_pooling(input: any): any;
export function hierarchical_pooling(input: any): string;
/**
* The main ColBERT model structure.
*
Expand All @@ -13,33 +13,37 @@ export function hierarchical_pooling(input: any): any;
*/
export class ColBERT {
free(): void;
[Symbol.dispose](): void;
/**
* WASM-compatible constructor.
*
* # Arguments
* * `projection_layers` - JsValue containing array of {config: Uint8Array, weights: Uint8Array, name: string}
*/
constructor(weights: Uint8Array, dense_weights: Uint8Array, tokenizer: Uint8Array, config: Uint8Array, sentence_transformers_config: Uint8Array, dense_config: Uint8Array, special_tokens_map: Uint8Array, batch_size?: number | null);
constructor(weights: Uint8Array, projection_layers: any, tokenizer: Uint8Array, config: Uint8Array, sentence_transformers_config: Uint8Array, special_tokens_map: Uint8Array, batch_size?: number | null);
/**
* WASM-compatible version of the `encode` method.
*/
encode(input: any, is_query: boolean): any;
encode(input: any, is_query: boolean): string;
/**
* WASM-compatible version of the `similarity` method.
*/
similarity(input: any): any;
similarity(input: any): string;
/**
* WASM-compatible method to get the raw similarity matrix and tokens.
*/
raw_similarity_matrix(input: any): any;
raw_similarity_matrix(input: any): string;
}

export type InitInput = RequestInfo | URL | Response | BufferSource | WebAssembly.Module;

export interface InitOutput {
readonly memory: WebAssembly.Memory;
readonly colbert_from_bytes: (a: number, b: number, c: number, d: number, e: number, f: number, g: number, h: number, i: number, j: number, k: number, l: number, m: number, n: number, o: number) => [number, number, number];
readonly colbert_encode: (a: number, b: any, c: number) => [number, number, number];
readonly colbert_similarity: (a: number, b: any) => [number, number, number];
readonly colbert_raw_similarity_matrix: (a: number, b: any) => [number, number, number];
readonly hierarchical_pooling: (a: any) => [number, number, number];
readonly colbert_from_bytes: (a: number, b: number, c: any, d: number, e: number, f: number, g: number, h: number, i: number, j: number, k: number, l: number) => [number, number, number];
readonly colbert_encode: (a: number, b: any, c: number) => [number, number, number, number];
readonly colbert_similarity: (a: number, b: any) => [number, number, number, number];
readonly colbert_raw_similarity_matrix: (a: number, b: any) => [number, number, number, number];
readonly hierarchical_pooling: (a: any) => [number, number, number, number];
readonly __wbg_colbert_free: (a: number, b: number) => void;
readonly __wbindgen_malloc: (a: number, b: number) => number;
readonly __wbindgen_realloc: (a: number, b: number, c: number, d: number) => number;
Expand Down
Loading