
WebGPU does not work with q8 decoders (AutoModelForSeq2SeqLM, WhisperForConditionalGeneration) #1317

System Info

Version: transformers 3.4.0
OS: Windows 11 24H2
Browser: Chrome 135
node.js environment
Hardware: NVIDIA Tesla P4, Intel UHD 770

Environment/Platform

  • Website/web-app
  • Browser extension
  • Server-side (e.g., Node.js, Deno, Bun)
  • Desktop app (e.g., Electron)
  • Other (e.g., VSCode extension)

Description

I am opening this issue because my previous issue #1286 about this problem may have been a bit confusing.

Observation

Both the Whisper ONNX models (tested with tiny, base, small, and large-v3-turbo) and the nllb-200 ONNX models (tested with distilled-600M) produce gibberish output when the backend is WebGPU and the decoder is q8, even though the same models produce correct output when the backend is WASM.

dtype: { encoder_model: 'q8', decoder_model_merged: 'q8' }, device: 'webgpu' produces gibberish for nllb-200-distilled-600M:
[screenshot: gibberish translation output]

dtype: { encoder_model: 'q8', decoder_model_merged: 'q8' }, device: 'webgpu' produces gibberish for whisper-small:
[screenshot: gibberish transcription output]
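
For reference, a minimal sketch of the Whisper configuration that shows the same behaviour (the pipeline-based loading, model id, and audio input here are illustrative placeholders; my actual code uses WhisperForConditionalGeneration with the same dtype/device options):

import { pipeline } from '@xenova/transformers';

// Placeholder model id; tiny, base, small and large-v3-turbo all behave the same for me.
const transcriber = await pipeline('automatic-speech-recognition', 'Xenova/whisper-small', {
    dtype: { encoder_model: 'q8', decoder_model_merged: 'q8' },
    device: 'webgpu', // gibberish output; device: 'wasm' transcribes correctly
});

// `audio` is assumed to be a 16 kHz mono Float32Array.
const result = await transcriber(audio);
console.log(result.text);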

What else I have tried

I have tried quantizing facebook/nllb-200-distilled-600M to q4 with bitsandbytes and then converting the encoder and decoder to ONNX myself.
I also tried saving the converted model with external data.
Both give "Error: [WebGPU] Kernel "[MatMul] /layers.0/self_attn/MatMul_1" failed. Error: shared dimension does not match.".
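
For context, this is roughly how I load the locally converted q4 model (the local path, folder layout, and model name below are placeholders, not my exact setup):

import { env, AutoModelForSeq2SeqLM } from '@xenova/transformers';

// Assumption: the self-converted encoder/decoder live under ./models/nllb-200-q4/onnx/
env.allowLocalModels = true;
env.localModelPath = './models/';

const model = await AutoModelForSeq2SeqLM.from_pretrained('nllb-200-q4', {
    dtype: { encoder_model: 'q4', decoder_model_merged: 'q4' },
    use_external_data_format: true, // also tried without external data
    device: 'webgpu', // fails with the MatMul "shared dimension does not match" error
});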

I tried using the q8 models from this PR, which give me an error as well.
The q4f16 models there fail to load properly: ONNX Runtime throws an error, but the runtime build is minified, so the error code is just a number that I cannot interpret.
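
In case it helps with debugging, a small sketch of how I would raise the ONNX Runtime log level to get more detail than the numeric error code, assuming env.backends.onnx exposes the onnxruntime-web env object (the wasm settings in the worker below suggest it does):

import { env } from '@xenova/transformers';

// Assumption: env.backends.onnx is the onnxruntime-web `env`, so its standard
// logLevel setting applies and 'verbose' should print additional detail.
env.backends.onnx.logLevel = 'verbose';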

Reproduction

Since I can avoid the problem for Whisper by simply not using q8, I will only post the code for nllb-200.

TranslationPipeline.js:

import { InferenceSession, Tensor } from 'onnxruntime-web';
import { AutoTokenizer, AutoModelForSeq2SeqLM } from '@xenova/transformers';

// Translation Pipeline
class TranslationPipeline {
    static instance = null;
    static model_id = 'Xenova/nllb-200-distilled-600M';
    static tokenizer = null;
    static model = null;

    static async getInstance(progress_callback = null) {
        if (this.instance === null) {
            this.instance = new TranslationPipeline();
            await this.instance.init(progress_callback);
        }
        return Promise.all([this.instance.tokenizer, this.instance.model]);
    }

    async init(progress_callback) {
        try {
            this.tokenizer ??= await AutoTokenizer.from_pretrained(this.constructor.model_id, {
                progress_callback,
            });
            this.model ??= await AutoModelForSeq2SeqLM.from_pretrained(this.constructor.model_id, {
                dtype: {
                    encoder_model: 'q8', // wasm: xenova q8 / debug q4f16, webgpu: debug q4f16
                    decoder_model_merged: 'q8', // wasm: xenova q8 / debug q4f16, webgpu: debug q4f16
                },
                //use_external_data_format: true,
                device: 'wasm', // both wasm and webgpu work
                progress_callback,
            });
        } catch (error) {
            console.error('Error during load:', error);
            self.postMessage({
                status: 'error',
                error: error.message || error.toString(),
            });
        }
    }
}

export default TranslationPipeline;

TranslateWorker.js:

import { TextStreamer, full, env, } from '@xenova/transformers';
import TranslationPipeline from './TranslationPipeline';

// Configure ONNX runtime
env.backends.onnx.wasm = { numThreads: 4, wasmPaths: 'https://MY_LAN_IP:3002/ort-wasm/' };

let processing = false;
async function translate({ text, src_lang, tgt_lang }) {
    if (processing) return;
    processing = true; // Lock the model instance

    // Retrieve the tokenizer and model
    const [tokenizer, model] = await TranslationPipeline.getInstance();

    try {
        // Get the target language token ID
        const tgtLangTokenId = tokenizer.encode(tgt_lang)[1];

        // Tokenize the input text
        const { input_ids, attention_mask } = tokenizer(text, {
            return_tensors: 'np',
            max_length: 512,
            truncation: true,
        });

        const output = await model.generate({
            input_ids: input_ids,
            attention_mask: attention_mask,
            forced_bos_token_id: tgtLangTokenId, // Force the target language
            max_length: 300, // default: 512
            num_beams: 5,
            early_stopping: true,
            //do_sample: true,
            //temperature: 0.1,
            callback_function: (x) => {
                // Decode partial output and send it to the main thread
                if (x && x[0] && x[0].output_token_ids) {
                    const partialOutput = tokenizer.decode(x[0].output_token_ids, { skip_special_tokens: true });
                    self.postMessage({
                        status: 'update',
                        output: partialOutput,
                    });
                }
            },
        });

        // Convert the BigInt64 output ids to regular numbers before decoding
        const bigIntArray = output[0].ort_tensor.cpuData;
        const intArray = Array.from(bigIntArray, (bigIntValue) => Number(bigIntValue));
        const finalOutput = tokenizer.decode(intArray, { skip_special_tokens: true });

        self.postMessage({
            status: 'translate_complete',
            output: finalOutput,
        });
    } catch (error) {
        self.postMessage({
            status: 'error',
            data: error.message,
        });
        //console.warn(error.message)
    } finally {
        processing = false; // Release the lock
    }
}

let loaded = false;
async function load() {
    if (loaded) return;
    loaded = true;

    // Notify the main thread that loading has started
    self.postMessage({
        status: 'loading',
        data: '[Worker] Loading model...'
    });

    // Load the model and tokenizer
    const [tokenizer, model] = await TranslationPipeline.getInstance(x => {
        // Forward progress updates to the main thread
        self.postMessage(x);
    });

    // Notify the main thread that the model is loaded
    self.postMessage({
        status: 'loading',
        data: '[Worker] Model loaded.',
    });

    // This is a warm-up run to compile and cache the shaders
    const dummyText = '你好,世界!'; // "Hello, world!"

    const tgt_lang = 'eng_Latn';
    const tgtLangTokenId = tokenizer.encode(tgt_lang)[1];

    const { input_ids, attention_mask } = tokenizer(dummyText, { return_tensors: 'np', max_length: 512, truncation: true });

    console.log('[Model] Warm-up task: ', dummyText);

    const output = await model.generate({
        input_ids: input_ids,
        attention_mask: attention_mask,
        forced_bos_token_id: tgtLangTokenId,
        max_length: 10,
        num_beams: 5,
        early_stopping: true,
    });

    // Convert the BigInt64 output ids to regular numbers before decoding
    const bigIntArray = output[0].ort_tensor.cpuData;
    const intArray = Array.from(bigIntArray, (bigIntValue) => Number(bigIntValue));

    const finalOutput = tokenizer.decode(intArray, { skip_special_tokens: true });
    console.log('[Model] Warm-up result: ', finalOutput);

    // Notify the main thread that the model is ready
    self.postMessage({ status: 'ready', workerType: 'translation' });
}

self.addEventListener('message', async (e) => {
    const { type, data } = e.data;

    switch (type) {
        case 'load':
            load();
            break;
        case 'translate':
            translate(data);
            break;
    }
});
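
For completeness, a minimal sketch of how this worker is driven from the main thread (the worker construction, file layout, and the sample languages are assumptions, not my exact app code):

// main.js (hypothetical driver for TranslateWorker.js)
const worker = new Worker(new URL('./TranslateWorker.js', import.meta.url), { type: 'module' });

worker.addEventListener('message', (e) => {
    const { status, output, data, error } = e.data;
    if (status === 'ready') {
        // Model loaded and warmed up; request a translation
        worker.postMessage({
            type: 'translate',
            data: { text: '你好,世界!', src_lang: 'zho_Hans', tgt_lang: 'eng_Latn' },
        });
    } else if (status === 'update' || status === 'translate_complete') {
        console.log(status, output);
    } else if (status === 'error') {
        console.error(error ?? data);
    } else if (status === 'loading') {
        console.log(data);
    }
});

// Start loading the model and tokenizer
worker.postMessage({ type: 'load' });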
