System Info
Version: transformers.js 3.4.0
OS: Windows 11 24H2
Browser: Chrome 135
Environment: node.js
Hardware: NVIDIA Tesla P4, Intel UHD 770
Environment/Platform
- Website/web-app
- Browser extension
- Server-side (e.g., Node.js, Deno, Bun)
- Desktop app (e.g., Electron)
- Other (e.g., VSCode extension)
Description
I am opening this issue because my previous issue #1286 about this problem may have been a bit confusing.
Observation
Both the Whisper ONNX models (tested with tiny, base, small, large-v3-turbo) and the nllb-200 ONNX model (tested with distilled-600M) produce gibberish output when the backend is WebGPU and the decoder is q8, even though the same models produce correct output when the backend is WASM.
dtype: { encoder_model: 'q8', decoder_model_merged: 'q8' }, device: 'webgpu'
produces gibberish for nllb-200-distilled-600M:
dtype: { encoder_model: 'q8', decoder_model_merged: 'q8' }, device: 'webgpu'
produces gibberish for whisper-small:
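For reference, the only difference between the working and the broken configuration is the device option; a minimal loading sketch using the same model id and options as the full reproduction below:

import { AutoModelForSeq2SeqLM } from '@xenova/transformers';

// Works: q8 encoder/decoder on the WASM backend
const wasmModel = await AutoModelForSeq2SeqLM.from_pretrained('Xenova/nllb-200-distilled-600M', {
  dtype: { encoder_model: 'q8', decoder_model_merged: 'q8' },
  device: 'wasm',
});

// Produces gibberish: same q8 dtype, only the device changed to WebGPU
const webgpuModel = await AutoModelForSeq2SeqLM.from_pretrained('Xenova/nllb-200-distilled-600M', {
  dtype: { encoder_model: 'q8', decoder_model_merged: 'q8' },
  device: 'webgpu',
});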
What I have tried in addition
I have tried quantizing facebook/nllb-200-distilled-600M to q4 using bitsandbytes and then converting the encoder and decoder to ONNX on my own.
I also tried saving the weights as external data.
Both attempts give "Error: [WebGPU] Kernel "[MatMul] /layers.0/self_attn/MatMul_1" failed. Error: shared dimension does not match.".
I tried using the q8 models from this PR, which give me an error as well.
The q4f16 models here cannot be loaded properly: ONNX Runtime throws an error from minified code, and the error code is just a number that I cannot interpret.
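For context, this is roughly how I load the self-converted q4 ONNX files; the local model path and folder name are placeholders for illustration, and use_external_data_format mirrors the option that is commented out in the reproduction code below:

import { env, AutoModelForSeq2SeqLM } from '@xenova/transformers';

// Point transformers.js at the locally converted files (path is a placeholder)
env.allowRemoteModels = false;
env.localModelPath = '/models/';

const model = await AutoModelForSeq2SeqLM.from_pretrained('nllb-200-distilled-600M-q4', {
  dtype: { encoder_model: 'q4', decoder_model_merged: 'q4' },
  use_external_data_format: true, // also tried without external data
  device: 'webgpu',
});
// Fails with: Error: [WebGPU] Kernel "[MatMul] /layers.0/self_attn/MatMul_1" failed.
//             Error: shared dimension does not match.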
Reproduction
Since I can avoid the problem for Whisper by not using q8, I will only post the code for nllb-200.
TranslationPipeline.js:
import { InferenceSession, Tensor } from 'onnxruntime-web';
import { AutoTokenizer, AutoModelForSeq2SeqLM } from '@xenova/transformers';
// Translation Pipeline
class TranslationPipeline {
static instance = null;
static model_id = 'Xenova/nllb-200-distilled-600M';
static tokenizer = null;
static model = null;
static async getInstance(progress_callback = null) {
if (this.instance === null) {
this.instance = new TranslationPipeline();
await this.instance.init(progress_callback);
}
return Promise.all([this.instance.tokenizer, this.instance.model]);
}
async init(progress_callback) {
try {
this.tokenizer ??= await AutoTokenizer.from_pretrained(this.constructor.model_id, {
progress_callback,
});
this.model ??= await AutoModelForSeq2SeqLM.from_pretrained(this.constructor.model_id, {
dtype: {
encoder_model: 'q8', // wasm: xenova q8 / debug q4f16, webgpu: debug q4f16
decoder_model_merged: 'q8', // wasm: xenova q8 / debug q4f16, webgpu: debug q4f16
},
//use_external_data_format: true,
device: 'wasm', // both wasm and webgpu work
progress_callback,
});
} catch (error) {
console.error('Error during load:', error);
self.postMessage({
status: 'error',
error: error.message || error.toString(),
});
}
}
}
export default TranslationPipeline;
TranslateWorker.js:
import { TextStreamer, full, env, } from '@xenova/transformers';
import TranslationPipeline from './TranslationPipeline';
// Configure ONNX runtime
env.backends.onnx.wasm = { numThreads: 4, wasmPaths: 'https://MY_LAN_IP:3002/ort-wasm/' };
let processing = false;
async function translate({ text, src_lang, tgt_lang }) {
if (processing) return;
processing = true; // Lock the model instance
// Retrieve the tokenizer and model
const [tokenizer, model] = await TranslationPipeline.getInstance();
try {
// Get the target language token ID
const tgtLangTokenId = tokenizer.encode(tgt_lang)[1];
// Tokenize the input text
const { input_ids, attention_mask } = tokenizer(text, {
return_tensors: 'np',
max_length: 512,
truncation: true,
});
const output = await model.generate({
input_ids: input_ids,
attention_mask: attention_mask,
forced_bos_token_id: tgtLangTokenId, // Force the target language
max_length: 300, // default: 512
num_beams: 5,
early_stopping: true,
//do_sample: true,
//temperature: 0.1,
callback_function: (x) => {
// Decode partial output and send it to the main thread
if (x && x[0] && x[0].output_token_ids) {
const partialOutput = tokenizer.decode(x[0].output_token_ids, { skip_special_tokens: true });
self.postMessage({
status: 'update',
output: partialOutput,
});
}
},
});
const bigIntArray = output[0].ort_tensor.cpuData;
const intArray = Array.from(bigIntArray, (bigIntValue) => Number(bigIntValue));
const finalOutput = tokenizer.decode(intArray, { skip_special_tokens: true });
self.postMessage({
status: 'translate_complete',
output: finalOutput,
});
} catch (error) {
self.postMessage({
status: 'error',
data: error.message,
});
//console.warn(error.message)
} finally {
processing = false; // Release the lock
}
}
let loaded = false;
async function load() {
if (loaded) return;
loaded = true;
// Notify the main thread that loading has started
self.postMessage({
status: 'loading',
data: '[Worker] Loading model...'
});
// Load the model and tokenizer
const [tokenizer, model] = await TranslationPipeline.getInstance(x => {
// Forward progress updates to the main thread
self.postMessage(x);
});
// Notify the main thread that the model is loaded
self.postMessage({
status: 'loading',
data: '[Worker] Model loaded.',
});
// This is a warm-up run to compile and cache the shaders
const dummyText = '你好,世界!';
const tgt_lang = 'eng_Latn';
const tgtLangTokenId = tokenizer.encode(tgt_lang)[1];
const { input_ids, attention_mask } = tokenizer(dummyText, { return_tensors: 'np', max_length: 512, truncation: true});
console.log('[Model] Warm-up task: ', dummyText);
const output = await model.generate({
input_ids: input_ids,
attention_mask: attention_mask,
forced_bos_token_id: tgtLangTokenId,
max_length: 10,
num_beams: 5,
early_stopping: true,
});
const bigIntArray = output[0].ort_tensor.cpuData;
const intArray = Array.from(bigIntArray, (bigIntValue) => Number(bigIntValue));
const finalOutput = tokenizer.decode(intArray, { skip_special_tokens: true });
console.log('[Model] Warm-up result: ', finalOutput);
// Notify the main thread that the model is ready
self.postMessage({ status: 'ready', workerType: 'translation' });
}
self.addEventListener('message', async(e) => {
const { type, data } = e.data;
switch(type) {
case 'load':
load();
break;
case 'translate':
translate(data);
break;
}
});
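For completeness, the worker is driven from the main thread roughly like this (the file URL and message payloads are illustrative, not part of the files above):

// main.js (sketch): spawn the worker and send the message types it handles
const worker = new Worker(new URL('./TranslateWorker.js', import.meta.url), { type: 'module' });

worker.addEventListener('message', (e) => {
  // progress events, 'loading', 'ready', 'update', 'translate_complete', 'error'
  console.log(e.data);
});

worker.postMessage({ type: 'load' });
worker.postMessage({
  type: 'translate',
  data: { text: '你好,世界!', src_lang: 'zho_Hans', tgt_lang: 'eng_Latn' },
});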