diff --git a/.ci/docker/ci_commit_pins/optimum-executorch.txt b/.ci/docker/ci_commit_pins/optimum-executorch.txt index 574ccb745d0..e0af9344588 100644 --- a/.ci/docker/ci_commit_pins/optimum-executorch.txt +++ b/.ci/docker/ci_commit_pins/optimum-executorch.txt @@ -1 +1 @@ -467660923a5a25e4718e1d6697b93ff1bab4e807 +4361747abfc55e40e929396ed986efe775d745f9 diff --git a/.ci/scripts/export_model_cuda_artifact.sh b/.ci/scripts/export_model_cuda_artifact.sh new file mode 100755 index 00000000000..85e34ae5b80 --- /dev/null +++ b/.ci/scripts/export_model_cuda_artifact.sh @@ -0,0 +1,150 @@ +#!/bin/bash +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# Export model to CUDA format with optional quantization + +show_help() { + cat << EOF +Usage: export_model_cuda_artifact.sh [quant_name] [output_dir] + +Export a HuggingFace model to CUDA format with optional quantization. + +Arguments: + hf_model HuggingFace model ID (required) + Supported models: + - mistralai/Voxtral-Mini-3B-2507 + - openai/whisper-small + - google/gemma-3-4b-it + + quant_name Quantization type (optional, default: non-quantized) + Options: + - non-quantized + - quantized-int4-tile-packed + - quantized-int4-weight-only + + output_dir Output directory for artifacts (optional, default: current directory) + +Examples: + export_model_cuda_artifact.sh "openai/whisper-small" + export_model_cuda_artifact.sh "mistralai/Voxtral-Mini-3B-2507" "quantized-int4-tile-packed" + export_model_cuda_artifact.sh "google/gemma-3-4b-it" "non-quantized" "./output" +EOF +} + +if [ "${1:-}" = "-h" ] || [ "${1:-}" = "--help" ]; then + show_help + exit 0 +fi + +if [ -z "${1:-}" ]; then + echo "Error: hf_model argument is required" + echo "Run with -h or --help for usage information" + exit 1 +fi + +set -eux + +HF_MODEL="$1" +QUANT_NAME="${2:-non-quantized}" +OUTPUT_DIR="${3:-.}" + +# Determine model configuration based on HF model ID +case "$HF_MODEL" in + mistralai/Voxtral-Mini-3B-2507) + MODEL_NAME="voxtral" + TASK="multimodal-text-to-text" + MAX_SEQ_LEN="1024" + EXTRA_PIP="mistral-common librosa" + PREPROCESSOR_FEATURE_SIZE="128" + PREPROCESSOR_OUTPUT="voxtral_preprocessor.pte" + ;; + openai/whisper-small) + MODEL_NAME="whisper" + TASK="automatic-speech-recognition" + MAX_SEQ_LEN="" + EXTRA_PIP="librosa" + PREPROCESSOR_FEATURE_SIZE="80" + PREPROCESSOR_OUTPUT="whisper_preprocessor.pte" + ;; + google/gemma-3-4b-it) + MODEL_NAME="gemma3" + TASK="multimodal-text-to-text" + MAX_SEQ_LEN="64" + EXTRA_PIP="" + PREPROCESSOR_FEATURE_SIZE="" + PREPROCESSOR_OUTPUT="" + ;; + *) + echo "Error: Unsupported model '$HF_MODEL'" + echo "Supported models: mistralai/Voxtral-Mini-3B-2507, openai/whisper-small, google/gemma-3-4b-it" + exit 1 + ;; +esac + +# Determine quantization args based on quant name +case "$QUANT_NAME" in + non-quantized) + EXTRA_ARGS="" + ;; + quantized-int4-tile-packed) + EXTRA_ARGS="--qlinear 4w --qlinear_encoder 4w --qlinear_packing_format tile_packed_to_4d --qlinear_encoder_packing_format tile_packed_to_4d" + ;; + quantized-int4-weight-only) + EXTRA_ARGS="--qlinear_encoder 4w" + ;; + *) + echo "Error: Unsupported quantization '$QUANT_NAME'" + echo "Supported quantizations: non-quantized, quantized-int4-tile-packed, quantized-int4-weight-only" + exit 1 + ;; +esac + +echo "::group::Export $MODEL_NAME" + +if [ -n "$EXTRA_PIP" ]; then + pip install $EXTRA_PIP +fi +pip list + 
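+# Only pass --max_seq_len to optimum-cli when the model entry above defines one
+# (whisper leaves MAX_SEQ_LEN empty, so the flag is dropped entirely).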
+MAX_SEQ_LEN_ARG="" +if [ -n "$MAX_SEQ_LEN" ]; then + MAX_SEQ_LEN_ARG="--max_seq_len $MAX_SEQ_LEN" +fi +optimum-cli export executorch \ + --model "$HF_MODEL" \ + --task "$TASK" \ + --recipe "cuda" \ + --dtype bfloat16 \ + --device cuda \ + ${MAX_SEQ_LEN_ARG} \ + ${EXTRA_ARGS} \ + --output_dir ./ + +if [ -n "$PREPROCESSOR_OUTPUT" ]; then + python -m executorch.extension.audio.mel_spectrogram \ + --feature_size $PREPROCESSOR_FEATURE_SIZE \ + --stack_output \ + --max_audio_len 300 \ + --output_file $PREPROCESSOR_OUTPUT +fi + +test -f model.pte +test -f aoti_cuda_blob.ptd +if [ -n "$PREPROCESSOR_OUTPUT" ]; then + test -f $PREPROCESSOR_OUTPUT +fi +echo "::endgroup::" + +echo "::group::Store $MODEL_NAME Artifacts" +mkdir -p "${OUTPUT_DIR}" +cp model.pte "${OUTPUT_DIR}/" +cp aoti_cuda_blob.ptd "${OUTPUT_DIR}/" +if [ -n "$PREPROCESSOR_OUTPUT" ]; then + cp $PREPROCESSOR_OUTPUT "${OUTPUT_DIR}/" +fi +ls -al "${OUTPUT_DIR}" +echo "::endgroup::" diff --git a/.ci/scripts/test_model_cuda_e2e.sh b/.ci/scripts/test_model_cuda_e2e.sh new file mode 100755 index 00000000000..02845bf4b96 --- /dev/null +++ b/.ci/scripts/test_model_cuda_e2e.sh @@ -0,0 +1,207 @@ +#!/bin/bash +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# Test CUDA model end-to-end, need to run .ci/scripts/export_model_cuda_artifact.sh first + +show_help() { + cat << EOF +Usage: test_model_cuda_e2e.sh [model_dir] + +Build and run end-to-end tests for CUDA models. + +Arguments: + hf_model HuggingFace model ID (required) + Supported models: + - mistralai/Voxtral-Mini-3B-2507 + - openai/whisper-small + - google/gemma-3-4b-it + + quant_name Quantization type (required) + Options: + - non-quantized + - quantized-int4-tile-packed + - quantized-int4-weight-only + + model_dir Directory containing model artifacts (optional, default: current directory) + Expected files: model.pte, aoti_cuda_blob.ptd + Tokenizers and test files will be downloaded to this directory + +Examples: + test_model_cuda_e2e.sh "openai/whisper-small" "non-quantized" + test_model_cuda_e2e.sh "mistralai/Voxtral-Mini-3B-2507" "quantized-int4-tile-packed" "./model_output" +EOF +} + +if [ "${1:-}" = "-h" ] || [ "${1:-}" = "--help" ]; then + show_help + exit 0 +fi + +if [ -z "${1:-}" ]; then + echo "Error: hf_model argument is required" + echo "Run with -h or --help for usage information" + exit 1 +fi + +if [ -z "${2:-}" ]; then + echo "Error: quant_name argument is required" + echo "Run with -h or --help for usage information" + exit 1 +fi + +set -eux + +HF_MODEL="$1" +QUANT_NAME="$2" +# Download tokenizers, audio, and image files to this directory +MODEL_DIR="${3:-.}" + +echo "Testing model: $HF_MODEL (quantization: $QUANT_NAME)" + +# Make sure model.pte and aoti_cuda_blob.ptd exist +if [ ! -f "$MODEL_DIR/model.pte" ]; then + echo "Error: model.pte not found in $MODEL_DIR" + exit 1 +fi +if [ ! -f "$MODEL_DIR/aoti_cuda_blob.ptd" ]; then + echo "Error: aoti_cuda_blob.ptd not found in $MODEL_DIR" + exit 1 +fi +# Locate EXECUTORCH_ROOT from the directory of this script +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +EXECUTORCH_ROOT="$(cd "$SCRIPT_DIR/../.." 
&& pwd)" + +pushd "$EXECUTORCH_ROOT" + +# Determine model configuration based on HF model ID +case "$HF_MODEL" in + mistralai/Voxtral-Mini-3B-2507) + MODEL_NAME="voxtral" + RUNNER_TARGET="voxtral_runner" + RUNNER_PATH="voxtral" + EXPECTED_OUTPUT="poem" + PREPROCESSOR="voxtral_preprocessor.pte" + TOKENIZER_URL="https://huggingface.co/mistralai/Voxtral-Mini-3B-2507/resolve/main" # @lint-ignore + TOKENIZER_FILE="tekken.json" + AUDIO_URL="https://github.com/voxserv/audio_quality_testing_samples/raw/refs/heads/master/testaudio/16000/test01_20s.wav" + AUDIO_FILE="poem.wav" + IMAGE_PATH="" + ;; + openai/whisper-small) + MODEL_NAME="whisper" + RUNNER_TARGET="whisper_runner" + RUNNER_PATH="whisper" + EXPECTED_OUTPUT="Mr. Quilter is the apostle of the middle classes" + PREPROCESSOR="whisper_preprocessor.pte" + TOKENIZER_URL="https://huggingface.co/openai/whisper-small/resolve/main" # @lint-ignore + TOKENIZER_FILE="" + AUDIO_URL="" + AUDIO_FILE="output.wav" + IMAGE_PATH="" + ;; + google/gemma-3-4b-it) + MODEL_NAME="gemma3" + RUNNER_TARGET="gemma3_e2e_runner" + RUNNER_PATH="gemma3" + EXPECTED_OUTPUT="chip" + PREPROCESSOR="" + TOKENIZER_URL="https://huggingface.co/unsloth/gemma-3-4b-it/resolve/main" # @lint-ignore + TOKENIZER_FILE="" + AUDIO_URL="" + AUDIO_FILE="" + IMAGE_PATH="docs/source/_static/img/et-logo.png" + ;; + *) + echo "Error: Unsupported model '$HF_MODEL'" + echo "Supported models: mistralai/Voxtral-Mini-3B-2507, openai/whisper-small, google/gemma-3-4b-it" + exit 1 + ;; +esac + +echo "::group::Setup ExecuTorch Requirements" +./install_requirements.sh +pip list +echo "::endgroup::" + +echo "::group::Prepare $MODEL_NAME Artifacts" + + +# Download tokenizer files +if [ "$TOKENIZER_FILE" != "" ]; then + curl -L $TOKENIZER_URL/$TOKENIZER_FILE -o $MODEL_DIR/$TOKENIZER_FILE +else + curl -L $TOKENIZER_URL/tokenizer.json -o $MODEL_DIR/tokenizer.json + curl -L $TOKENIZER_URL/tokenizer_config.json -o $MODEL_DIR/tokenizer_config.json + curl -L $TOKENIZER_URL/special_tokens_map.json -o $MODEL_DIR/special_tokens_map.json +fi + +# Download test files +if [ "$AUDIO_URL" != "" ]; then + curl -L $AUDIO_URL -o ${MODEL_DIR}/$AUDIO_FILE +elif [ "$MODEL_NAME" = "whisper" ]; then + conda install -y -c conda-forge "ffmpeg<8" + pip install datasets soundfile torchcodec + python -c "from datasets import load_dataset;import soundfile as sf;sample = load_dataset('distil-whisper/librispeech_long', 'clean', split='validation')[0]['audio'];sf.write('${MODEL_DIR}/$AUDIO_FILE', sample['array'][:sample['sampling_rate']*30], sample['sampling_rate'])" +fi + +ls -al +echo "::endgroup::" + +echo "::group::Build $MODEL_NAME Runner" +cmake --preset llm \ + -DEXECUTORCH_BUILD_CUDA=ON \ + -DCMAKE_INSTALL_PREFIX=cmake-out \ + -DCMAKE_BUILD_TYPE=Release \ + -Bcmake-out -S. 
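+# Stage the core ExecuTorch libraries into cmake-out first; the model-specific
+# runner configured below links against this install tree.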
+cmake --build cmake-out -j$(nproc) --target install --config Release + +cmake -DEXECUTORCH_BUILD_CUDA=ON \ + -DCMAKE_BUILD_TYPE=Release \ + -Sexamples/models/$RUNNER_PATH \ + -Bcmake-out/examples/models/$RUNNER_PATH/ +cmake --build cmake-out/examples/models/$RUNNER_PATH --target $RUNNER_TARGET --config Release +echo "::endgroup::" + +echo "::group::Run $MODEL_NAME Runner" +set +e +export LD_LIBRARY_PATH=/opt/conda/lib:$LD_LIBRARY_PATH + +# Build runner command with common arguments +RUNNER_BIN="cmake-out/examples/models/$RUNNER_PATH/$RUNNER_TARGET" +RUNNER_ARGS="--model_path ${MODEL_DIR}/model.pte --data_path ${MODEL_DIR}/aoti_cuda_blob.ptd --temperature 0" + +# Add model-specific arguments +case "$MODEL_NAME" in + voxtral) + RUNNER_ARGS="$RUNNER_ARGS --tokenizer_path ${MODEL_DIR}/$TOKENIZER_FILE --audio_path ${MODEL_DIR}/$AUDIO_FILE --processor_path ${MODEL_DIR}/$PREPROCESSOR" + ;; + whisper) + RUNNER_ARGS="$RUNNER_ARGS --tokenizer_path ${MODEL_DIR}/ --audio_path ${MODEL_DIR}/$AUDIO_FILE --processor_path ${MODEL_DIR}/$PREPROCESSOR" + ;; + gemma3) + RUNNER_ARGS="$RUNNER_ARGS --tokenizer_path ${MODEL_DIR}/ --image_path $IMAGE_PATH" + ;; +esac + +OUTPUT=$($RUNNER_BIN $RUNNER_ARGS 2>&1) +EXIT_CODE=$? +set -e + +if ! echo "$OUTPUT" | grep -iq "$EXPECTED_OUTPUT"; then + echo "Expected output '$EXPECTED_OUTPUT' not found in output" + exit 1 +else + echo "Success: '$EXPECTED_OUTPUT' found in output" +fi + +if [ $EXIT_CODE -ne 0 ]; then + echo "Unexpected exit code: $EXIT_CODE" + exit $EXIT_CODE +fi +echo "::endgroup::" + +popd diff --git a/.github/workflows/cuda.yml b/.github/workflows/cuda.yml index 0b2424f3bf0..356180772c4 100644 --- a/.github/workflows/cuda.yml +++ b/.github/workflows/cuda.yml @@ -87,8 +87,8 @@ jobs: export LD_LIBRARY_PATH=/opt/conda/lib:$LD_LIBRARY_PATH PYTHON_EXECUTABLE=python source .ci/scripts/test_model.sh "${{ matrix.model }}" cmake cuda - export-voxtral-cuda-artifact: - name: export-voxtral-cuda-${{ matrix.quant.name }} + export-model-cuda-artifact: + name: export-model-cuda-artifact # Skip this job if the pull request is from a fork (HuggingFace secrets are not available) if: github.event.pull_request.head.repo.full_name == github.repository || github.event_name != 'pull_request' uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main @@ -99,17 +99,23 @@ jobs: strategy: fail-fast: false matrix: + model: + - repo: "mistralai" + name: "Voxtral-Mini-3B-2507" + - repo: "openai" + name: "whisper-small" + - repo: "google" + name: "gemma-3-4b-it" quant: - - name: "non-quantized" - artifact: "voxtral-cuda-export" - extra_args: "" - - name: "quantized-int4-tile-packed" - artifact: "voxtral-cuda-quantized-int4-tile-packed" - extra_args: "--qlinear 4w --qlinear_encoder 4w --qlinear_packing_format tile_packed_to_4d --qlinear_encoder_packing_format tile_packed_to_4d" - - name: "quantized-int4-weight-only" - artifact: "voxtral-cuda-quantized-int4-weight-only" - # TODO: adding "--qlinear 4w" produces invalid results. Need further investigation. - extra_args: "--qlinear_encoder 4w" + - "non-quantized" + - "quantized-int4-tile-packed" + - "quantized-int4-weight-only" + exclude: + # TODO: enable int4-weight-only on gemma3. 
+ - model: + repo: "google" + name: "gemma-3-4b-it" + quant: "quantized-int4-weight-only" with: timeout: 90 secrets-env: EXECUTORCH_HF_TOKEN @@ -118,7 +124,7 @@ jobs: gpu-arch-version: 12.6 use-custom-docker-registry: false submodules: recursive - upload-artifact: ${{ matrix.quant.artifact }} + upload-artifact: ${{ matrix.model.repo }}-${{ matrix.model.name }}-cuda-${{ matrix.quant }} ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} script: | set -eux @@ -132,128 +138,43 @@ jobs: huggingface-cli login --token $SECRET_EXECUTORCH_HF_TOKEN OPTIMUM_ET_VERSION=$(cat .ci/docker/ci_commit_pins/optimum-executorch.txt) pip install git+https://github.com/huggingface/optimum-executorch.git@${OPTIMUM_ET_VERSION} - pip install mistral-common librosa - pip list - echo "::endgroup::" - - echo "::group::Export Voxtral (${{ matrix.quant.name }})" - EXTRA_ARGS="${{ matrix.quant.extra_args }}" - optimum-cli export executorch \ - --model "mistralai/Voxtral-Mini-3B-2507" \ - --task "multimodal-text-to-text" \ - --recipe "cuda" \ - --dtype bfloat16 \ - --device cuda \ - --max_seq_len 1024 \ - ${EXTRA_ARGS} \ - --output_dir ./ - python -m executorch.extension.audio.mel_spectrogram \ - --feature_size 128 \ - --stack_output \ - --max_audio_len 300 \ - --output_file voxtral_preprocessor.pte - - test -f model.pte - test -f aoti_cuda_blob.ptd - test -f voxtral_preprocessor.pte echo "::endgroup::" - echo "::group::Store Voxtral Artifacts (${{ matrix.quant.name }})" - mkdir -p "${RUNNER_ARTIFACT_DIR}" - cp model.pte "${RUNNER_ARTIFACT_DIR}/" - cp aoti_cuda_blob.ptd "${RUNNER_ARTIFACT_DIR}/" - cp voxtral_preprocessor.pte "${RUNNER_ARTIFACT_DIR}/" - ls -al "${RUNNER_ARTIFACT_DIR}" - echo "::endgroup::" + source .ci/scripts/export_model_cuda_artifact.sh "${{ matrix.model.repo }}/${{ matrix.model.name }}" "${{ matrix.quant }}" "${RUNNER_ARTIFACT_DIR}" - export-gemma3-cuda-artifact: - name: export-gemma3-cuda-${{ matrix.quant.name }} - # Skip this job if the pull request is from a fork (HuggingFace secrets are not available) - if: github.event.pull_request.head.repo.full_name == github.repository || github.event_name != 'pull_request' + benchmark-model-cuda: + name: benchmark-model-cuda + needs: export-model-cuda-artifact uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main permissions: id-token: write contents: read - secrets: inherit strategy: fail-fast: false matrix: + model: + - repo: "mistralai" + name: "Voxtral-Mini-3B-2507" + - repo: "google" + name: "gemma-3-4b-it" quant: - - name: "non-quantized" - artifact: "gemma3-cuda-export" - extra_args: "" - - name: "quantized-int4-tile-packed" - artifact: "gemma3-cuda-quantized-int4-tile-packed" - extra_args: "--qlinear 4w --qlinear_encoder 4w --qlinear_packing_format tile_packed_to_4d --qlinear_encoder_packing_format tile_packed_to_4d" + - "non-quantized" + - "quantized-int4-tile-packed" + - "quantized-int4-weight-only" + exclude: # TODO: enable int4-weight-only on gemma3. - # - name: "quantized-int4-weight-only" - # artifact: "voxtral-cuda-quantized-int4-weight-only" - # # TODO: adding "--qlinear 4w" produces invalid results. Need further investigation. 
- # extra_args: "--qlinear_encoder 4w" + - model: + repo: "google" + name: "gemma-3-4b-it" + quant: "quantized-int4-weight-only" with: timeout: 90 - secrets-env: EXECUTORCH_HF_TOKEN runner: linux.g5.4xlarge.nvidia.gpu gpu-arch-type: cuda gpu-arch-version: 12.6 use-custom-docker-registry: false submodules: recursive - upload-artifact: ${{ matrix.quant.artifact }} - ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} - script: | - set -eux - - echo "::group::Setup ExecuTorch" - ./install_executorch.sh - echo "::endgroup::" - - echo "::group::Setup Huggingface" - pip install -U "huggingface_hub[cli]<1.0" accelerate - huggingface-cli login --token $SECRET_EXECUTORCH_HF_TOKEN - OPTIMUM_ET_VERSION=$(cat .ci/docker/ci_commit_pins/optimum-executorch.txt) - pip install git+https://github.com/huggingface/optimum-executorch.git@${OPTIMUM_ET_VERSION} - pip list - echo "::endgroup::" - - echo "::group::Export Gemma3 (${{ matrix.quant.name }})" - EXTRA_ARGS="${{ matrix.quant.extra_args }}" - optimum-cli export executorch \ - --model "google/gemma-3-4b-it" \ - --task "multimodal-text-to-text" \ - --recipe "cuda" \ - --dtype bfloat16 \ - --device cuda \ - --max_seq_len 64 \ - --output_dir ./ - - test -f model.pte - test -f aoti_cuda_blob.ptd - echo "::endgroup::" - - echo "::group::Store Gemma3 Artifacts (${{ matrix.quant.name }})" - mkdir -p "${RUNNER_ARTIFACT_DIR}/" - cp model.pte "${RUNNER_ARTIFACT_DIR}/" - cp aoti_cuda_blob.ptd "${RUNNER_ARTIFACT_DIR}/" - ls -al "${RUNNER_ARTIFACT_DIR}/" - echo "::endgroup::" - - benchmark-voxtral-cuda: - name: benchmark-voxtral-cuda - needs: export-voxtral-cuda-artifact - uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main - permissions: - id-token: write - contents: read - strategy: - fail-fast: false - with: - timeout: 90 - runner: linux.g5.4xlarge.nvidia.gpu - gpu-arch-type: cuda - gpu-arch-version: 12.6 - use-custom-docker-registry: false - submodules: recursive - download-artifact: voxtral-cuda-export + download-artifact: ${{ matrix.model.repo }}-${{ matrix.model.name }}-cuda-${{ matrix.quant }} ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} script: | set -eux @@ -263,13 +184,13 @@ jobs: pip list echo "::endgroup::" - echo "::group::Prepare Voxtral Artifacts" + echo "::group::Prepare ${{ matrix.model }} Artifacts" cp "${RUNNER_ARTIFACT_DIR}/model.pte" . cp "${RUNNER_ARTIFACT_DIR}/aoti_cuda_blob.ptd" . ls -al model.pte aoti_cuda_blob.ptd echo "::endgroup::" - echo "::group::Build Voxtral Benchmark" + echo "::group::Build ${{ matrix.model }} Benchmark" cmake -DCMAKE_BUILD_TYPE=Release \ -DEXECUTORCH_BUILD_CUDA=ON \ -DEXECUTORCH_BUILD_EXTENSION_TENSOR=ON \ @@ -277,158 +198,19 @@ jobs: -DEXECUTORCH_BUILD_EXTENSION_NAMED_DATA_MAP=ON \ -DEXECUTORCH_BUILD_TESTS=ON \ -Bcmake-out . 
- cmake --build cmake-out -j$(( $(nproc) - 1 )) --target multimodal_benchmark + cmake --build cmake-out -j$(nproc) --target multimodal_benchmark echo "::endgroup::" - echo "::group::Run Voxtral Benchmark" + echo "::group::Run ${{ matrix.model.name }} Benchmark" export LD_LIBRARY_PATH=/opt/conda/lib:$LD_LIBRARY_PATH - cmake-out/backends/cuda/multimodal_benchmark voxtral model.pte aoti_cuda_blob.ptd + cmake-out/backends/cuda/multimodal_benchmark ${{ matrix.model.name }} model.pte aoti_cuda_blob.ptd echo "::endgroup::" - benchmark-gemma3-cuda: - name: benchmark-gemma3-cuda - needs: export-gemma3-cuda-artifact - uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main - permissions: - id-token: write - contents: read - strategy: - fail-fast: false - with: - timeout: 90 - runner: linux.g5.4xlarge.nvidia.gpu - gpu-arch-type: cuda - gpu-arch-version: 12.6 - use-custom-docker-registry: false - submodules: recursive - download-artifact: gemma3-cuda-export - ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} - script: | - set -eux - - echo "::group::Setup ExecuTorch Requirements" - ./install_requirements.sh - pip list - echo "::endgroup::" - - echo "::group::Prepare Gemma3 Artifacts" - cp "${RUNNER_ARTIFACT_DIR}/model.pte" . - cp "${RUNNER_ARTIFACT_DIR}/aoti_cuda_blob.ptd" . - ls -al model.pte aoti_cuda_blob.ptd - echo "::endgroup::" - - echo "::group::Build Gemma3 Benchmark" - cmake -DCMAKE_BUILD_TYPE=Release \ - -DEXECUTORCH_BUILD_CUDA=ON \ - -DEXECUTORCH_BUILD_EXTENSION_TENSOR=ON \ - -DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \ - -DEXECUTORCH_BUILD_EXTENSION_NAMED_DATA_MAP=ON \ - -DEXECUTORCH_BUILD_TESTS=ON \ - -Bcmake-out . - cmake --build cmake-out -j$(( $(nproc) - 1 )) --target multimodal_benchmark - echo "::endgroup::" - - echo "::group::Run Gemma3 Benchmark" - - export LD_LIBRARY_PATH=/opt/conda/lib:$LD_LIBRARY_PATH - cmake-out/backends/cuda/multimodal_benchmark gemma3 model.pte aoti_cuda_blob.ptd - - echo "::endgroup::" - - test-voxtral-cuda-e2e: - name: test-voxtral-cuda-e2e-${{ matrix.format.name }} - needs: export-voxtral-cuda-artifact - uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main - permissions: - id-token: write - contents: read - strategy: - fail-fast: false - matrix: - format: - - name: "non-quantized" - artifact: "voxtral-cuda-export" - - name: "quantized-int4-tile-packed" - artifact: "voxtral-cuda-quantized-int4-tile-packed" - - name: "quantized-int4-weight-only" - artifact: "voxtral-cuda-quantized-int4-weight-only" - with: - timeout: 90 - runner: linux.g5.4xlarge.nvidia.gpu - gpu-arch-type: cuda - gpu-arch-version: 12.6 - use-custom-docker-registry: false - submodules: recursive - download-artifact: ${{ matrix.format.artifact }} - ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} - script: | - set -eux - - echo "::group::Setup ExecuTorch Requirements" - ./install_requirements.sh - pip list - echo "::endgroup::" - - echo "::group::Prepare Voxtral Artifacts (${{ matrix.format.name }})" - cp "${RUNNER_ARTIFACT_DIR}/model.pte" . - cp "${RUNNER_ARTIFACT_DIR}/aoti_cuda_blob.ptd" . - cp "${RUNNER_ARTIFACT_DIR}/voxtral_preprocessor.pte" . 
- TOKENIZER_URL="https://huggingface.co/mistralai/Voxtral-Mini-3B-2507/resolve/main/tekken.json" - curl -L $TOKENIZER_URL -o tekken.json - ls -al model.pte aoti_cuda_blob.ptd voxtral_preprocessor.pte tekken.json - echo "::endgroup::" - - echo "::group::Download Test Audio File" - AUDIO_URL="https://github.com/voxserv/audio_quality_testing_samples/raw/refs/heads/master/testaudio/16000/test01_20s.wav" - curl -L $AUDIO_URL -o poem.wav - echo "::endgroup::" - - echo "::group::Build Voxtral Runner" - cmake --preset llm \ - -DEXECUTORCH_BUILD_CUDA=ON \ - -DCMAKE_INSTALL_PREFIX=cmake-out \ - -DCMAKE_BUILD_TYPE=Release \ - -Bcmake-out -S. - cmake --build cmake-out -j$(( $(nproc) - 1 )) --target install --config Release - - cmake -DEXECUTORCH_BUILD_CUDA=ON \ - -DCMAKE_BUILD_TYPE=Release \ - -Sexamples/models/voxtral \ - -Bcmake-out/examples/models/voxtral/ - cmake --build cmake-out/examples/models/voxtral --target voxtral_runner --config Release - echo "::endgroup::" - - echo "::group::Run Voxtral Runner (${{ matrix.format.name }})" - set +e - export LD_LIBRARY_PATH=/opt/conda/lib:$LD_LIBRARY_PATH - OUTPUT=$(cmake-out/examples/models/voxtral/voxtral_runner \ - --model_path model.pte \ - --data_path aoti_cuda_blob.ptd \ - --tokenizer_path tekken.json \ - --audio_path poem.wav \ - --processor_path voxtral_preprocessor.pte \ - --temperature 0 2>&1) - EXIT_CODE=$? - set -e - - echo "$OUTPUT" - - if ! echo "$OUTPUT" | grep -iq "poem"; then - echo "Expected output 'poem' not found in output" - exit 1 - fi - - if [ $EXIT_CODE -ne 0 ]; then - echo "Unexpected exit code: $EXIT_CODE" - exit $EXIT_CODE - fi - echo "::endgroup::" - - test-gemma3-cuda-e2e: - name: test-gemma3-cuda-e2e-${{ matrix.format.name }} - needs: export-gemma3-cuda-artifact + test-model-cuda-e2e: + name: test-model-cuda-e2e + needs: export-model-cuda-artifact uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main permissions: id-token: write @@ -436,14 +218,23 @@ jobs: strategy: fail-fast: false matrix: - format: - - name: "non-quantized" - artifact: "gemma3-cuda-export" - - name: "quantized-int4-tile-packed" - artifact: "gemma3-cuda-quantized-int4-tile-packed" + model: + - repo: "mistralai" + name: "Voxtral-Mini-3B-2507" + - repo: "openai" + name: "whisper-small" + - repo: "google" + name: "gemma-3-4b-it" + quant: + - "non-quantized" + - "quantized-int4-tile-packed" + - "quantized-int4-weight-only" + exclude: # TODO: enable int4-weight-only on gemma3. - # - name: "quantized-int4-weight-only" - # artifact: "gemma3-cuda-quantized-int4-weight-only" + - model: + repo: "google" + name: "gemma-3-4b-it" + quant: "quantized-int4-weight-only" with: timeout: 90 runner: linux.g5.4xlarge.nvidia.gpu @@ -451,61 +242,7 @@ jobs: gpu-arch-version: 12.6 use-custom-docker-registry: false submodules: recursive - download-artifact: ${{ matrix.format.artifact }} + download-artifact: ${{ matrix.model.repo }}-${{ matrix.model.name }}-cuda-${{ matrix.quant }} ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} script: | - set -eux - - echo "::group::Setup ExecuTorch Requirements" - ./install_requirements.sh - pip list - echo "::endgroup::" - - echo "::group::Prepare Gemma3 Artifacts (${{ matrix.format.name }})" - cp "${RUNNER_ARTIFACT_DIR}/model.pte" . - cp "${RUNNER_ARTIFACT_DIR}/aoti_cuda_blob.ptd" . 
- TOKENIZER_URL="https://huggingface.co/unsloth/gemma-3-1b-it/resolve/main/tokenizer.json" - curl -L $TOKENIZER_URL -o tokenizer.json - ls -al model.pte aoti_cuda_blob.ptd tokenizer.json - IMAGE_PATH="docs/source/_static/img/et-logo.png" - echo "::endgroup::" - - echo "::group::Build Gemma3 Runner" - cmake --preset llm \ - -DEXECUTORCH_BUILD_CUDA=ON \ - -DCMAKE_INSTALL_PREFIX=cmake-out \ - -DCMAKE_BUILD_TYPE=Release \ - -Bcmake-out -S. - cmake --build cmake-out -j$(( $(nproc) - 1 )) --target install --config Release - - cmake -DEXECUTORCH_BUILD_CUDA=ON \ - -DCMAKE_BUILD_TYPE=Release \ - -Sexamples/models/gemma3 \ - -Bcmake-out/examples/models/gemma3/ - cmake --build cmake-out/examples/models/gemma3 --target gemma3_e2e_runner --config Release - echo "::endgroup::" - - echo "::group::Run Gemma3 Runner (${{ matrix.format.name }})" - set +e - export LD_LIBRARY_PATH=/opt/conda/lib:$LD_LIBRARY_PATH - OUTPUT=$(cmake-out/examples/models/gemma3/gemma3_e2e_runner \ - --model_path model.pte \ - --data_path aoti_cuda_blob.ptd \ - --tokenizer_path tokenizer.json \ - --image_path $IMAGE_PATH \ - --temperature 0 2>&1) - EXIT_CODE=$? - set -e - - echo "$OUTPUT" - - if ! echo "$OUTPUT" | grep -iq "chip"; then - echo "Expected output 'chip' not found in output" - exit 1 - fi - - if [ $EXIT_CODE -ne 0 ]; then - echo "Unexpected exit code: $EXIT_CODE" - exit $EXIT_CODE - fi - echo "::endgroup::" + source .ci/scripts/test_model_cuda_e2e.sh "${{ matrix.model.repo }}/${{ matrix.model.name }}" "${{ matrix.quant }}" "${RUNNER_ARTIFACT_DIR}" diff --git a/CMakeLists.txt b/CMakeLists.txt index 1b96c12fbf3..c6d6f26b41f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -926,6 +926,11 @@ if(EXECUTORCH_BUILD_EXTENSION_LLM_RUNNER) list(APPEND _executorch_extensions extension_llm_runner) endif() +if(EXECUTORCH_BUILD_EXTENSION_ASR_RUNNER) + add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/extension/asr/runner) + list(APPEND _executorch_extensions extension_asr_runner) +endif() + if(EXECUTORCH_BUILD_EXTENSION_LLM_APPLE) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/extension/llm/apple) endif() diff --git a/backends/cuda/tests/multimodal_benchmark.cpp b/backends/cuda/tests/multimodal_benchmark.cpp index 679db889b71..7365d0b7ba8 100644 --- a/backends/cuda/tests/multimodal_benchmark.cpp +++ b/backends/cuda/tests/multimodal_benchmark.cpp @@ -60,7 +60,8 @@ ModelType parse_model_type(const std::string& model_name) { lower_name.begin(), [](unsigned char c) { return std::tolower(c); }); - if (lower_name.find("gemma3") != std::string::npos) { + if (lower_name.find("gemma3") != std::string::npos || + lower_name.find("gemma-3") != std::string::npos) { return ModelType::GEMMA3; } else if (lower_name.find("voxtral") != std::string::npos) { return ModelType::VOXTRAL; diff --git a/examples/models/whisper/CMakeLists.txt b/examples/models/whisper/CMakeLists.txt new file mode 100644 index 00000000000..70f5892baa7 --- /dev/null +++ b/examples/models/whisper/CMakeLists.txt @@ -0,0 +1,89 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +cmake_minimum_required(VERSION 3.29) +project(whisper_runner) + +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD_REQUIRED ON) + +set(EXECUTORCH_ROOT "${CMAKE_CURRENT_SOURCE_DIR}/../../..") +include(${EXECUTORCH_ROOT}/tools/cmake/Utils.cmake) + +# Let files say "include " +set(_common_include_directories ${EXECUTORCH_ROOT}/..) 
+ +# Need this for gflags for some reason +set(gflags_DIR ${CMAKE_CURRENT_BINARY_DIR}/../../../third-party/gflags) +find_package(gflags REQUIRED) + +list(APPEND CMAKE_FIND_ROOT_PATH ${CMAKE_CURRENT_BINARY_DIR}/../../..) +find_package(executorch CONFIG REQUIRED FIND_ROOT_PATH_BOTH) + +set(_link_libraries executorch gflags) +set(_srcs multimodal.cpp) + +list( + APPEND + _link_libraries + optimized_native_cpu_ops_lib + quantized_ops_lib + custom_ops + cpublas + eigen_blas +) + +# XNNPACK +if(TARGET xnnpack_backend) + list(APPEND _link_libraries xnnpack_backend) +endif() + +# Add LLM runner and extension module +if(NOT TARGET extension_asr_runner) + message( + FATAL_ERROR + "ExecuTorch must be installed with EXECUTORCH_BUILD_EXTENSION_ASR_RUNNER enabled." + ) +endif() + +# Needed for cpuinfo where it uses android specific log lib +if(ANDROID) + list(APPEND _link_libraries log) +endif() + +# Add the required ExecuTorch extensions for multimodal LLM runner +list( + APPEND + _link_libraries + extension_asr_runner + extension_llm_runner # Needed for load_tokenizer() + extension_module + extension_data_loader + extension_tensor + extension_flat_tensor +) + +# Link CUDA backend +if(EXECUTORCH_BUILD_CUDA) + find_package(CUDAToolkit REQUIRED) + list(APPEND _link_libraries aoti_cuda) + executorch_target_link_options_shared_lib(aoti_cuda) +endif() + +if(EXECUTORCH_BUILD_METAL) + list(APPEND _link_libraries metal_backend) + executorch_target_link_options_shared_lib(metal_backend) +endif() + +# Add tokenizers +list(APPEND _link_libraries tokenizers::tokenizers) + +add_executable(whisper_runner main.cpp) + +target_include_directories(whisper_runner PUBLIC ${_common_include_directories}) + +target_link_libraries(whisper_runner PUBLIC ${_link_libraries}) +target_compile_options(whisper_runner PUBLIC ${_common_compile_options}) diff --git a/examples/models/whisper/README.md b/examples/models/whisper/README.md new file mode 100644 index 00000000000..a4025441f7e --- /dev/null +++ b/examples/models/whisper/README.md @@ -0,0 +1,126 @@ +# Whisper Runner + +This directory hosts a lightweight C++ helper that drives Whisper models +exported to ExecuTorch. The `AsrRunner` owns the `Module` instance that +wraps a bundled `.pte` program and optional `.ptd` weight file, loads the +`encoder` and `text_decoder` methods, and exposes a `transcribe()` loop that +streams decoded text pieces through a callback. + +The runner assumes: +- `model.pte` contains both Whisper encoder and decoder entry points named + `encoder` and `text_decoder`. +- (Optional) Depending on export configurations, model weights can be optionally stored in a companion + `model.ptd`. The runner will load the file if present. +- A tokenizer JSON compatible with the ExecuTorch tokenizers shim is available. + +Audio preprocessing is not part of the runner itself. To transform raw audio +into the mel features expected by the encoder, reuse the pattern in +`examples/models/voxtral/multimodal.cpp`, which loads a `preprocessor.pte` +module to generate the spectrogram tensor. + +## Build + +Currently we have CUDA build support only. CPU and Metal backend builds are WIP. + +```bash +# Install ExecuTorch libraries: +cmake --preset llm -DEXECUTORCH_BUILD_CUDA=ON -DCMAKE_INSTALL_PREFIX=cmake-out -DCMAKE_BUILD_TYPE=Release . 
-Bcmake-out +cmake --build cmake-out -j$(nproc) --target install --config Release + +# Build the runner: +cmake \ + -B cmake-out/examples/models/whisper \ + -S examples/models/whisper +cmake --build cmake-out/examples/models/whisper -j$(nproc) +``` + +The first cmake command build produces a static library named `extension_asr_runner`. The second cmake command links it into your +application together with the standard ExecuTorch runtime libraries and the +tokenizer target (`tokenizers::tokenizers`). + +## Usage + +### Export Whisper Model + +Use [Optimum-ExecuTorch](https://github.com/huggingface/optimum-executorch) to export a Whisper model from Hugging Face: + +```bash +optimum-cli export executorch \ + --model openai/whisper-small \ + --task automatic-speech-recognition \ + --recipe cuda \ + --dtype bfloat16 \ + --device cuda \ + --output_dir ./ +``` + +This command generates: +- `model.pte` — Compiled Whisper model +- `aoti_cuda_blob.ptd` — Weight data file for CUDA backend + +Export a preprocessor to convert raw audio to mel-spectrograms: + +```bash +python -m executorch.extension.audio.mel_spectrogram \ + --feature_size 80 \ + --stack_output \ + --max_audio_len 300 \ + --output_file whisper_preprocessor.pte +``` + +### Quantization + +Export quantized models to reduce size and improve performance: + +```bash +# 4-bit tile packed quantization for encoder +optimum-cli export executorch \ + --model openai/whisper-small \ + --task automatic-speech-recognition \ + --recipe cuda \ + --dtype bfloat16 \ + --device cuda \ + --qlinear 4w \ + --qlinear_encoder 4w \ + --qlinear_packing_format tile_packed_to_4d \ + --qlinear_encoder_packing_format tile_packed_to_4d \ + --output_dir ./ +``` + + +### Download Tokenizer + +Download the tokenizer files required for inference: + +```bash +curl -L https://huggingface.co/openai/whisper-small/resolve/main/tokenizer.json -o tokenizer.json +curl -L https://huggingface.co/openai/whisper-small/resolve/main/tokenizer_config.json -o tokenizer_config.json +curl -L https://huggingface.co/openai/whisper-small/resolve/main/special_tokens_map.json -o special_tokens_map.json +``` + +### Prepare Audio + +Generate test audio or use an existing WAV file. The model expects 16kHz mono audio. + +```bash +# Generate sample audio using librispeech dataset +python -c "from datasets import load_dataset; import soundfile as sf; sample = load_dataset('distil-whisper/librispeech_long', 'clean', split='validation')[0]['audio']; sf.write('output.wav', sample['array'][:sample['sampling_rate']*30], sample['sampling_rate'])" +``` + +### Run Inference + +After building the runner (see [Build](#build) section), execute it with the exported model and audio: + +```bash +# Set library path for CUDA dependencies +export LD_LIBRARY_PATH=/opt/conda/lib:$LD_LIBRARY_PATH + +# Run the Whisper runner +cmake-out/examples/models/whisper/whisper_runner \ + --model_path model.pte \ + --data_path aoti_cuda_blob.ptd \ + --tokenizer_path ./ \ + --audio_path output.wav \ + --processor_path whisper_preprocessor.pte \ + --temperature 0 +``` diff --git a/examples/models/whisper/main.cpp b/examples/models/whisper/main.cpp new file mode 100644 index 00000000000..b4462e2c39a --- /dev/null +++ b/examples/models/whisper/main.cpp @@ -0,0 +1,126 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include +#include +#include +#include +#include +#include +#include + +DEFINE_string(model_path, "model.pte", "Path to Whisper model (.pte)."); +DEFINE_string(data_path, "", "Optional path to Whisper weights (.ptd)."); +DEFINE_string( + tokenizer_path, + ".", + "Path to tokenizer directory containing tokenizer.json, tokenizer_config.json, and special_tokens_map.json."); +DEFINE_string( + processor_path, + "", + "Path to preprocessor .pte for converting raw audio."); +DEFINE_string( + audio_path, + "", + "Path to input audio file. Accepts .wav or raw float .bin."); +DEFINE_double( + temperature, + 0.0, + "Sampling temperature. 0.0 performs greedy decoding."); +DEFINE_int32(max_new_tokens, 128, "Maximum number of tokens to generate."); + +using ::executorch::extension::from_blob; +using ::executorch::extension::Module; + +int main(int argc, char** argv) { + gflags::ParseCommandLineFlags(&argc, &argv, true); + ::executorch::extension::TensorPtr features; + std::vector audio_data; + std::unique_ptr processor; + + if (FLAGS_audio_path.empty()) { + ET_LOG(Error, "audio_path flag must be provided."); + return 1; + } + + audio_data = + executorch::extension::llm::load_wav_audio_data(FLAGS_audio_path); + ET_LOG( + Info, + "First 2 values of audio data: %f, %f", + audio_data[0], + audio_data[1]); + + processor = + std::make_unique(FLAGS_processor_path, Module::LoadMode::Mmap); + auto load_error = processor->load(); + if (load_error != ::executorch::runtime::Error::Ok) { + ET_LOG(Error, "Failed to load preprocessor module."); + return 1; + } + + auto audio_tensor = from_blob( + audio_data.data(), + {static_cast<::executorch::aten::SizesType>(audio_data.size())}, + ::executorch::aten::ScalarType::Float); + + auto processed_result = processor->execute("forward", audio_tensor); + if (processed_result.error() != ::executorch::runtime::Error::Ok) { + ET_LOG(Error, "Audio preprocessing failed."); + return 1; + } + auto outputs = std::move(processed_result.get()); + if (outputs.empty() || !outputs[0].isTensor()) { + ET_LOG(Error, "Preprocessor returned unexpected outputs."); + return 1; + } + auto tensor = outputs[0].toTensor(); + ET_LOG( + Info, + "Result scalar_type: %s, first value %f", + ::executorch::runtime::toString(tensor.scalar_type()), + tensor.mutable_data_ptr()[0]); + features = std::make_shared<::executorch::aten::Tensor>(std::move(tensor)); + + executorch::extension::asr::AsrRunner runner( + FLAGS_model_path, FLAGS_data_path, FLAGS_tokenizer_path); + auto load_err = runner.load(); + if (load_err != ::executorch::runtime::Error::Ok) { + ET_LOG(Error, "Failed to load Whisper model."); + return 1; + } + + executorch::extension::asr::AsrTranscribeConfig config; + config.max_new_tokens = FLAGS_max_new_tokens; + config.temperature = static_cast(FLAGS_temperature); + config.decoder_start_token_id = 50257; + + auto result = + runner.transcribe(features, config, [&](const std::string& piece) { + ::executorch::extension::llm::safe_printf(piece.c_str()); + fflush(stdout); + }); + + if (!result.ok()) { + ET_LOG(Error, "Transcription failed."); + return 1; + } + + return 0; +} diff --git a/extension/asr/runner/CMakeLists.txt b/extension/asr/runner/CMakeLists.txt new file mode 100644 index 00000000000..cc9ba01596a --- /dev/null +++ b/extension/asr/runner/CMakeLists.txt @@ -0,0 +1,51 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# +# ASR runner for models like Whisper +# +# ### Editing this file ### +# +# This file should be formatted with +# ~~~ +# cmake-format -i CMakeLists.txt +# ~~~ +# It should also be cmake-lint clean. +# + +if(NOT EXECUTORCH_ROOT) + set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../..) +endif() + +include(${EXECUTORCH_ROOT}/tools/cmake/Utils.cmake) + +set(runner_deps executorch_core extension_module extension_tensor + tokenizers::tokenizers +) + +# Define runner library +add_library(extension_asr_runner STATIC runner.cpp) +target_include_directories( + extension_asr_runner INTERFACE ${_common_include_directories} +) +target_link_libraries(extension_asr_runner PUBLIC ${runner_deps}) +set_target_properties( + extension_asr_runner PROPERTIES POSITION_INDEPENDENT_CODE ON +) + +install( + TARGETS extension_asr_runner + EXPORT ExecuTorchTargets + DESTINATION ${CMAKE_INSTALL_LIBDIR} + INCLUDES + DESTINATION ${_common_include_directories} +) + +install( + DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/ + DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/executorch/extension/asr/runner + FILES_MATCHING + PATTERN "*.h" +) diff --git a/extension/asr/runner/runner.cpp b/extension/asr/runner/runner.cpp new file mode 100644 index 00000000000..6bbb44e4faa --- /dev/null +++ b/extension/asr/runner/runner.cpp @@ -0,0 +1,321 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +namespace executorch::extension::asr { +namespace { + +constexpr const char* kEncoderMethodName = "encoder"; +constexpr const char* kDecoderMethodName = "text_decoder"; + +} // namespace + +AsrRunner::AsrRunner( + const std::string& module_path, + std::optional data_path, + const std::string& tokenizer_path) + : module_path_(module_path), + data_path_(data_path.value_or("")), + tokenizer_path_(tokenizer_path) { + if (data_path_.empty()) { + module_ = std::make_unique(module_path_, Module::LoadMode::Mmap); + } else { + module_ = std::make_unique( + module_path_, data_path_, Module::LoadMode::Mmap); + } +} + +bool AsrRunner::is_loaded() const { + return module_ && encoder_method_loaded_ && decoder_method_loaded_ && + tokenizer_ && tokenizer_->is_loaded() && !eos_token_ids_.empty(); +} + +Error AsrRunner::load_tokenizer() { + if (tokenizer_ && tokenizer_->is_loaded()) { + return Error::Ok; + } + + auto tokenizer = + ::executorch::extension::llm::load_tokenizer(tokenizer_path_); + ET_CHECK_OR_RETURN_ERROR( + tokenizer, + Internal, + "Failed to create tokenizer from %s", + tokenizer_path_.c_str()); + + tokenizer_ = std::move(tokenizer); + if (!tokenizer_->is_loaded()) { + ET_LOG( + Error, + "Tokenizer reported unloaded state after load: %s", + tokenizer_path_.c_str()); + return Error::Internal; + } + + eos_token_ids_.clear(); + eos_token_ids_.insert(static_cast(tokenizer_->eos_tok())); + return Error::Ok; +} + +Error AsrRunner::load() { + if (is_loaded()) { + return Error::Ok; + } + + stats_.model_load_start_ms = ::executorch::extension::llm::time_in_ms(); + + ET_CHECK_OR_RETURN_ERROR( + module_ != nullptr, + InvalidArgument, + "Module handle is null for path %s", + module_path_.c_str()); + + 
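+  // Load the program, then verify it actually exposes both the "encoder" and
+  // "text_decoder" methods before loading each one.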
ET_CHECK_OK_OR_RETURN_ERROR(module_->load()); + + auto method_names_result = module_->method_names(); + ET_CHECK_OK_OR_RETURN_ERROR(method_names_result.error()); + const auto& method_names = method_names_result.get(); + + ET_CHECK_OR_RETURN_ERROR( + method_names.count(kEncoderMethodName) && + method_names.count(kDecoderMethodName), + InvalidArgument, + "Required methods not found. encoder=%d, text_decoder=%d", + static_cast(method_names.count(kEncoderMethodName)), + static_cast(method_names.count(kDecoderMethodName))); + + ET_CHECK_OK_OR_RETURN_ERROR(module_->load_method(kEncoderMethodName)); + encoder_method_loaded_ = true; + + ET_CHECK_OK_OR_RETURN_ERROR(module_->load_method(kDecoderMethodName)); + decoder_method_loaded_ = true; + + ET_CHECK_OK_OR_RETURN_ERROR(load_tokenizer()); + auto eos_ids = get_eos_ids(tokenizer_.get(), module_.get()); + if (!eos_ids.empty()) { + eos_token_ids_.clear(); + for (uint64_t eos_id : eos_ids) { + eos_token_ids_.insert(static_cast(eos_id)); + } + } + + stats_.model_load_end_ms = ::executorch::extension::llm::time_in_ms(); + + return Error::Ok; +} + +Result> AsrRunner::transcribe( + ::executorch::extension::TensorPtr preprocessed_features, + AsrTranscribeConfig config, + std::function token_callback) { + ET_CHECK_OR_RETURN_ERROR( + config.max_new_tokens > 0, + InvalidArgument, + "max_new_tokens must be positive, got %" PRId64, + config.max_new_tokens); + + ET_LOG( + Info, + "Preprocessed features shape: [%zu, %zu, %zu]", + static_cast(preprocessed_features->size(0)), + static_cast(preprocessed_features->size(1)), + static_cast(preprocessed_features->size(2))); + + if (!is_loaded()) { + ET_CHECK_OK_OR_RETURN_ERROR(load()); + } + + ET_LOG( + Info, + "RSS after loading model: %f MiB (0 if unsupported)", + ::executorch::extension::llm::get_rss_bytes() / 1024.0 / 1024.0); + + // Reset internal state and start inference + stats_.inference_start_ms = ::executorch::extension::llm::time_in_ms(); + + const std::unordered_set* eos_tokens = &eos_token_ids(); + if (!config.eos_token_ids.empty()) { + eos_tokens = &config.eos_token_ids; + } + ET_CHECK_OR_RETURN_ERROR( + !eos_tokens->empty(), + InvalidArgument, + "EOS token set must not be empty."); + ::executorch::extension::llm::Sampler sampler( + tokenizer_->vocab_size(), config.temperature); + + // Check expected dtype for encoder input + auto encoder_method_meta_result = module_->method_meta(kEncoderMethodName); + ET_CHECK_OK_OR_RETURN_ERROR(encoder_method_meta_result.error()); + auto encoder_method_meta = encoder_method_meta_result.get(); + + ::executorch::aten::ScalarType expected_dtype = + ::executorch::aten::ScalarType::Float; + if (encoder_method_meta.num_inputs() > 0) { + auto input_meta_result = encoder_method_meta.input_tensor_meta(0); + if (input_meta_result.error() == ::executorch::runtime::Error::Ok) { + expected_dtype = input_meta_result.get().scalar_type(); + } + } + + // Convert preprocessed_features to expected dtype if needed + if (preprocessed_features->scalar_type() != expected_dtype) { + if (expected_dtype == ::executorch::aten::ScalarType::BFloat16) { + ET_LOG( + Info, + "Converting audio features from %s to BFloat16. 
Before converting, first value = %f", + ::executorch::runtime::toString(preprocessed_features->scalar_type()), + preprocessed_features->mutable_data_ptr()[0]); + auto convert_result = ::executorch::extension::llm::convert_to_bfloat16( + preprocessed_features); + ET_CHECK_OK_OR_RETURN_ERROR(convert_result.error()); + preprocessed_features = convert_result.get(); + ET_LOG( + Info, + "Conversion complete, first value = %f", + static_cast( + preprocessed_features + ->mutable_data_ptr<::executorch::aten::BFloat16>()[0])); + } + } + + auto encoder_result = + module_->execute(kEncoderMethodName, preprocessed_features); + ET_CHECK_OK_OR_RETURN_ERROR(encoder_result.error()); + + stats_.prompt_eval_end_ms = ::executorch::extension::llm::time_in_ms(); + stats_.num_prompt_tokens = 0; + + auto encoder_outputs = std::move(*encoder_result); + ET_CHECK_OR_RETURN_ERROR( + encoder_outputs.size() == 1 && encoder_outputs[0].isTensor(), + Internal, + "Encoder returned %zu outputs; expected a single tensor.", + encoder_outputs.size()); + + ::executorch::aten::Tensor encoder_output_tensor = + std::move(encoder_outputs[0]).toTensor(); + + ET_LOG( + Info, + "Encoder output shape: [%zu, %zu, %zu]", + static_cast(encoder_output_tensor.size(0)), + static_cast(encoder_output_tensor.size(1)), + static_cast(encoder_output_tensor.size(2))); + ET_LOG( + Info, + "Encoder first value: %f", + static_cast( + encoder_output_tensor + .mutable_data_ptr<::executorch::aten::BFloat16>()[0])); + + auto encoder_output_ptr = std::make_shared<::executorch::aten::Tensor>( + std::move(encoder_output_tensor)); + + std::vector tokens = {config.decoder_start_token_id}; + + int64_t input_id = config.decoder_start_token_id; + int64_t cache_position = 0; + int64_t generated_tokens = 0; + bool first_token_generated = false; + auto decoder_input_ptr = ::executorch::extension::from_blob( + &input_id, + {static_cast<::executorch::aten::SizesType>(1), + static_cast<::executorch::aten::SizesType>(1)}, + ::executorch::aten::ScalarType::Long); + + auto cache_position_ptr = ::executorch::extension::from_blob( + &cache_position, + {static_cast<::executorch::aten::SizesType>(1)}, + ::executorch::aten::ScalarType::Long); + + std::vector<::executorch::runtime::EValue> decoder_inputs; + decoder_inputs.reserve(3); + decoder_inputs.emplace_back(decoder_input_ptr); + decoder_inputs.emplace_back(encoder_output_ptr); + decoder_inputs.emplace_back(cache_position_ptr); + // Add some green coloring for the first generated token + // token_callback("\033[1;32m"); + while (generated_tokens < config.max_new_tokens) { + input_id = tokens.back(); + auto decoder_result = module_->execute(kDecoderMethodName, decoder_inputs); + ET_CHECK_OK_OR_RETURN_ERROR(decoder_result.error()); + + auto decoder_outputs = std::move(*decoder_result); + ET_CHECK_OR_RETURN_ERROR( + decoder_outputs.size() == 1 && decoder_outputs[0].isTensor(), + Internal, + "Decoder returned %zu outputs; expected a single tensor.", + decoder_outputs.size()); + + ::executorch::aten::Tensor logits_tensor = + std::move(decoder_outputs[0]).toTensor(); + const int64_t vocab_size = logits_tensor.numel(); + ET_CHECK_OR_RETURN_ERROR( + vocab_size > 0, Internal, "Decoder logits tensor is empty."); + + const int64_t next_token = + static_cast(::executorch::extension::llm::logits_to_token( + logits_tensor, config.temperature)); + + if (!first_token_generated) { + stats_.first_token_ms = ::executorch::extension::llm::time_in_ms(); + first_token_generated = true; + } + + const int64_t prev_token = input_id; + 
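+    // Keep the previous token around so the tokenizer can decode the
+    // (prev_token, next_token) pair into a printable piece further down.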
tokens.push_back(next_token); + ++generated_tokens; + ++cache_position; + input_id = next_token; + + if (token_callback) { + auto piece_result = tokenizer_->decode( + static_cast(prev_token), static_cast(next_token)); + if (piece_result.ok()) { + token_callback(piece_result.get()); + } else { + ET_LOG( + Error, + "Tokenizer failed to decode token pair (%" PRId64 ", %" PRId64 + ") with error %d", + prev_token, + next_token, + static_cast(piece_result.error())); + } + } + + if (eos_tokens->count(next_token) > 0) { + break; + } + } + // Reset coloring + // token_callback("\033[0m"); + // Update stats and print report + stats_.num_generated_tokens = generated_tokens; + stats_.inference_end_ms = ::executorch::extension::llm::time_in_ms(); + printf("\n"); + print_report(stats_); + + return tokens; +} + +} // namespace executorch::extension::asr diff --git a/extension/asr/runner/runner.h b/extension/asr/runner/runner.h new file mode 100644 index 00000000000..a9f8ce3edda --- /dev/null +++ b/extension/asr/runner/runner.h @@ -0,0 +1,115 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace executorch::extension::asr { + +using ::executorch::extension::Module; +using ::executorch::extension::llm::get_eos_ids; +using ::executorch::extension::llm::load_tokenizer; +using ::executorch::extension::llm::print_report; +using ::executorch::extension::llm::Sampler; +using ::executorch::extension::llm::Stats; +using ::executorch::runtime::Error; +using ::executorch::runtime::Result; + +/** + * Configuration for the ASR transcription loop. + * + * max_new_tokens controls the number of tokens generated after the prompt. + * Temperature controls the randomness of the output. + */ +struct ET_EXPERIMENTAL AsrTranscribeConfig { + int64_t max_new_tokens = 128; + std::unordered_set eos_token_ids = {}; + float temperature = 0.0f; + int64_t decoder_start_token_id = 0; +}; + +/** + * Runner that owns a ASR model encoder + decoder pair exported as a single + * ExecuTorch module. A good example is Whisper + * (https://huggingface.co/openai/whisper-small) + * + * The module is expected to expose two callable methods: + * - "encoder": processes precomputed audio features into encoder states. + * - "text_decoder": consumes the decoder input ids, encoder output and cache + * positions to autoregressively generate logits. + */ +class ET_EXPERIMENTAL AsrRunner { + public: + AsrRunner( + const std::string& module_path, + std::optional data_path, + const std::string& tokenizer_path); + + /** + * Returns true when the module and tokenizer are ready for inference. + */ + bool is_loaded() const; + + /** + * Loads the module, validates required methods and initialises tokenizer. + */ + ::executorch::runtime::Error load(); + + /** + * Executes an end-to-end transcription cycle. + * + * @param preprocessed_features Audio features tensor of shape [batch, time, + * features] already processed by a preprocessor module. Typically produced + * by an audio feature extractor (e.g., mel-spectrogram computation). + * @param config Controls generation length and termination criteria. + * @param token_callback Optional functor invoked for each decoded piece of + * text emitted during generation. 
+ * + * @returns Result containing the final decoder token ids (including the seed + * prompt and generated tokens), or an error. + */ + ::executorch::runtime::Result> transcribe( + ::executorch::extension::TensorPtr preprocessed_features, + AsrTranscribeConfig config = {}, + std::function token_callback = {}); + + private: + ::executorch::runtime::Error load_tokenizer(); + inline const std::unordered_set& eos_token_ids() const { + return eos_token_ids_; + } + + std::string module_path_; + std::string data_path_; + std::string tokenizer_path_; + + std::unique_ptr module_; + std::unique_ptr<::tokenizers::Tokenizer> tokenizer_; + std::unordered_set eos_token_ids_; + + bool encoder_method_loaded_ = false; + bool decoder_method_loaded_ = false; + + Stats stats_; +}; + +} // namespace executorch::extension::asr diff --git a/extension/llm/runner/text_decoder_runner.h b/extension/llm/runner/text_decoder_runner.h index 9a090de50d6..720000185c9 100644 --- a/extension/llm/runner/text_decoder_runner.h +++ b/extension/llm/runner/text_decoder_runner.h @@ -12,6 +12,7 @@ #include #include +#include namespace executorch { namespace extension { @@ -64,39 +65,8 @@ class ET_EXPERIMENTAL TextDecoderRunner { inline int32_t logits_to_token( const executorch::aten::Tensor& logits_tensor, const float temperature = 0.0f) { - int32_t result = 0; - - // Create a minimal context for error handling in ET_SWITCH - struct { - [[noreturn]] void fail(torch::executor::Error /* error */) { - ET_CHECK_MSG(false, "Unsupported dtype in logits_to_token"); - } - } ctx; - - ET_SWITCH_FOUR_TYPES( - Float, - Half, - BFloat16, - UInt16, - logits_tensor.scalar_type(), - ctx, - "logits_to_token", - CTYPE, - [&]() { - // If the logit_tensor rank is 3, the shape is [batch, seq_length, - // vocab_size], get the last logits, sample and return. Else the model - // outputs the last logit, directly sample and return. - auto* logits = logits_tensor.mutable_data_ptr(); - ssize_t vocab_size = logits_tensor.size(logits_tensor.dim() - 1); - if (logits_tensor.dim() == 3) { - auto num_tokens = logits_tensor.size(1); - logits += (num_tokens - 1) * vocab_size; - } - // @lint-ignore CLANGTIDY facebook-hte-Deprecated - Sampler sampler(vocab_size, temperature); - result = sampler.sample(logits); - }); - return result; + return ::executorch::extension::llm::logits_to_token( + logits_tensor, temperature); } protected: diff --git a/extension/llm/sampler/targets.bzl b/extension/llm/sampler/targets.bzl index 9b7751c19e7..b76bfcd6133 100644 --- a/extension/llm/sampler/targets.bzl +++ b/extension/llm/sampler/targets.bzl @@ -8,6 +8,7 @@ def define_common_targets(): name = "sampler" + aten_suffix, exported_headers = [ "sampler.h", + "util.h", ], preprocessor_flags = [ "-DUSE_ATEN_LIB", diff --git a/extension/llm/sampler/util.h b/extension/llm/sampler/util.h new file mode 100644 index 00000000000..6a3a06355ca --- /dev/null +++ b/extension/llm/sampler/util.h @@ -0,0 +1,65 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include + +namespace executorch { +namespace extension { +namespace llm { + +/** + * Sample the next token from the logits tensor. + * @param logits_tensor The logits tensor. + * @param temperature The temperature parameter used to control randomness in + * sampling. + * @return The next token. 
+ */ +inline int32_t logits_to_token( + const executorch::aten::Tensor& logits_tensor, + const float temperature = 0.0f) { + int32_t result = 0; + + // Create a minimal context for error handling in ET_SWITCH + struct { + [[noreturn]] void fail(torch::executor::Error /* error */) { + ET_CHECK_MSG(false, "Unsupported dtype in logits_to_token"); + } + } ctx; + + ET_SWITCH_FOUR_TYPES( + Float, + Half, + BFloat16, + UInt16, + logits_tensor.scalar_type(), + ctx, + "logits_to_token", + CTYPE, + [&]() { + // If the logit_tensor rank is 3, the shape is [batch, seq_length, + // vocab_size], get the last logits, sample and return. Else the model + // outputs the last logit, directly sample and return. + auto* logits = logits_tensor.mutable_data_ptr(); + ssize_t vocab_size = logits_tensor.size(logits_tensor.dim() - 1); + if (logits_tensor.dim() == 3) { + auto num_tokens = logits_tensor.size(1); + logits += (num_tokens - 1) * vocab_size; + } + // @lint-ignore CLANGTIDY facebook-hte-Deprecated + Sampler sampler(vocab_size, temperature); + result = sampler.sample(logits); + }); + return result; +} + +} // namespace llm +} // namespace extension +} // namespace executorch diff --git a/tools/cmake/executorch-config.cmake b/tools/cmake/executorch-config.cmake index 3f97db77ccc..d6f8ded668b 100644 --- a/tools/cmake/executorch-config.cmake +++ b/tools/cmake/executorch-config.cmake @@ -68,6 +68,7 @@ set(optional_lib_list qnn_executorch_backend portable_ops_lib custom_ops + extension_asr_runner extension_evalue_util extension_llm_runner extension_module diff --git a/tools/cmake/preset/default.cmake b/tools/cmake/preset/default.cmake index 861e41e4a63..0dcec0df531 100644 --- a/tools/cmake/preset/default.cmake +++ b/tools/cmake/preset/default.cmake @@ -65,6 +65,10 @@ define_overridable_option( EXECUTORCH_BUILD_KERNELS_QUANTIZED_AOT "Build the optimized ops library for AOT export usage" BOOL OFF ) +define_overridable_option( + EXECUTORCH_BUILD_EXTENSION_ASR_RUNNER "Build the ASR runner extension" BOOL + OFF +) define_overridable_option( EXECUTORCH_BUILD_EXTENSION_DATA_LOADER "Build the Data Loader extension" BOOL ON # Required by executor_runner diff --git a/tools/cmake/preset/llm.cmake b/tools/cmake/preset/llm.cmake index 6cd2482f717..f3d3ab8ef8f 100644 --- a/tools/cmake/preset/llm.cmake +++ b/tools/cmake/preset/llm.cmake @@ -5,6 +5,7 @@ # LICENSE file in the root directory of this source tree. # keep sorted +set_overridable_option(EXECUTORCH_BUILD_EXTENSION_ASR_RUNNER ON) set_overridable_option(EXECUTORCH_BUILD_EXTENSION_DATA_LOADER ON) set_overridable_option(EXECUTORCH_BUILD_EXTENSION_FLAT_TENSOR ON) set_overridable_option(EXECUTORCH_BUILD_EXTENSION_LLM ON)
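
For reference, a minimal sketch of how the new `AsrRunner` API added in this PR is meant to be driven, condensed from `examples/models/whisper/main.cpp` and `extension/asr/runner/runner.h` above. The include paths and the template arguments (token vector element type, callback signature) are assumptions, since the `<...>` portions are stripped in the diff text.

```cpp
// Sketch only, not part of the diff. Include paths and template arguments
// are assumptions; see the hedge in the lead-in above.
#include <executorch/extension/asr/runner/runner.h>
#include <executorch/extension/tensor/tensor.h>

#include <cstdio>
#include <string>

using ::executorch::extension::TensorPtr;
using ::executorch::extension::asr::AsrRunner;
using ::executorch::extension::asr::AsrTranscribeConfig;

// mel_features: [batch, time, features] tensor produced by a preprocessor
// module (e.g. whisper_preprocessor.pte), as described in the README.
int run_whisper(TensorPtr mel_features) {
  // model.pte bundles the "encoder" and "text_decoder" methods;
  // aoti_cuda_blob.ptd carries the CUDA weights; "." holds tokenizer.json etc.
  AsrRunner runner("model.pte", std::string("aoti_cuda_blob.ptd"), ".");
  if (runner.load() != ::executorch::runtime::Error::Ok) {
    return 1;
  }

  AsrTranscribeConfig config;
  config.max_new_tokens = 128;
  config.temperature = 0.0f;
  config.decoder_start_token_id = 50257;  // same start id main.cpp uses

  auto tokens = runner.transcribe(
      mel_features, config, [](const std::string& piece) {
        fputs(piece.c_str(), stdout);  // stream each decoded piece as it arrives
      });
  return tokens.ok() ? 0 : 1;
}
```

The preprocessed `mel_features` tensor is expected to come from a separate preprocessor module such as `whisper_preprocessor.pte`; audio-to-mel conversion is intentionally outside the runner, as the README notes.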