Add quantization support #163
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open: maitrisavaliya wants to merge 20 commits into microsoft:main from maitrisavaliya:add-quantization-support
Commits (20)
45ba769  Add troubleshooting guide for common installation and usage issues
573d852  Add quantization support to reduce VRAM requirements
54b594b  Add quantization support to reduce VRAM requirements
0328c1e  Merge branch 'microsoft:main' into add-quantization-support
e3e4d69  Delete utils/quantization,py
cdde460  Update realtime_model_inference_from_file.py
62565c4  Delete TROUBLESHOOTING.md
276ad09  Update realtime_model_inference_from_file.py
15ca0ac  Update realtime_model_inference_from_file.py
c2a5bbf  Update vram_utils.py
8b0c2cf  Merge branch 'main' into add-quantization-support
188ffce  Apply suggestion from @Copilot
a0918a3  Apply suggestion from @Copilot
4d5140a  Apply suggestion from @Copilot
0bf0f0d  Apply suggestion from @Copilot
1b08105  Apply suggestion from @Copilot
ef5eaa2  Clarify help message for quantization option
a83b636  Refactor VRAM utility functions and logging
6e97a77  Enhance CUDA device handling and add VRAM info
ea332ac  Fix device_map condition for model loading
maitrisavaliya File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
utils/quantization.py (new file):

```python
"""Quantization utilities for VibeVoice models."""

import logging
from typing import Optional

import torch

logger = logging.getLogger(__name__)


def get_quantization_config(quantization: str = "fp16") -> Optional[dict]:
    """
    Get quantization configuration for model loading.

    Args:
        quantization: Quantization level ("fp16", "8bit", or "4bit")

    Returns:
        dict: Quantization config for from_pretrained, or None for fp16
    """
    if quantization == "fp16" or quantization == "full":
        return None

    if quantization == "8bit":
        try:
            import bitsandbytes as bnb  # noqa: F401 -- imported only to verify the dependency
            logger.info("Using 8-bit quantization (selective LLM only)")
            return {
                "load_in_8bit": True,
                "llm_int8_threshold": 6.0,
            }
        except ImportError:
            logger.error(
                "8-bit quantization requires bitsandbytes. "
                "Install with: pip install bitsandbytes"
            )
            raise

    elif quantization == "4bit":
        try:
            import bitsandbytes as bnb  # noqa: F401 -- imported only to verify the dependency
            from transformers import BitsAndBytesConfig

            logger.info("Using 4-bit NF4 quantization (selective LLM only)")
            return {
                "quantization_config": BitsAndBytesConfig(
                    load_in_4bit=True,
                    bnb_4bit_quant_type="nf4",
                    bnb_4bit_compute_dtype=torch.bfloat16,
                    bnb_4bit_use_double_quant=True,
                )
            }
        except ImportError:
            logger.error(
                "4-bit quantization requires bitsandbytes. "
                "Install with: pip install bitsandbytes"
            )
            raise

    else:
        raise ValueError(
            f"Invalid quantization: {quantization}. "
            f"Must be one of: fp16, 8bit, 4bit"
        )


def apply_selective_quantization(model, quantization: str):
    """
    Apply selective quantization only to safe components.

    This function identifies which modules should be quantized and which
    should remain at full precision for audio quality preservation.

    Args:
        model: The VibeVoice model
        quantization: Quantization level ("8bit" or "4bit")
    """
    if quantization == "fp16":
        return model

    logger.info("Applying selective quantization...")

    # Components to KEEP at full precision (audio-critical).
    # For the streaming model, audio-critical modules are typically exposed as
    # a prediction head and acoustic_* components. We match on these names to
    # ensure they remain at higher precision while only the LLM is quantized.
    keep_fp_components = [
        "prediction_head",
        "acoustic_",
    ]

    # Only quantize the LLM (Qwen2.5) component
    quantize_components = ["llm", "language_model"]

    for name, module in model.named_modules():
        # Check if this module should stay at full precision
        should_keep_fp = any(comp in name for comp in keep_fp_components)
        should_quantize = any(comp in name for comp in quantize_components)

        if should_keep_fp:
            # Ensure audio components stay at higher precision
            # (e.g., bfloat16 instead of 4/8-bit)
            with torch.no_grad():
                module.to(torch.bfloat16)
            logger.debug(f"Keeping {name} at full precision (audio-critical)")

        elif should_quantize:
            logger.debug(f"Quantized {name} to {quantization}")

    logger.info(f"✓ Selective {quantization} quantization applied")
    logger.info("  • LLM: Quantized")
    logger.info("  • Audio components: Full precision")

    return model
```
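A minimal sketch of how these two helpers are meant to compose at load time. The `AutoModelForCausalLM` class, the checkpoint path, and the `utils.quantization` import path are illustrative assumptions, not names confirmed by this PR; only the 4-bit branch is shown, since its return value maps directly onto the `quantization_config` keyword of `from_pretrained`.

```python
# Hypothetical wiring of the helpers above into model loading.
# The loader class and checkpoint path are placeholders; the PR's
# demo script drives the actual VibeVoice model loading.
import torch
from transformers import AutoModelForCausalLM

from utils.quantization import apply_selective_quantization, get_quantization_config

quantization = "4bit"  # one of "fp16", "8bit", "4bit"

# For "fp16" this returns None, so fall back to an empty kwargs dict.
load_kwargs = get_quantization_config(quantization) or {}

model = AutoModelForCausalLM.from_pretrained(
    "path/to/VibeVoice-checkpoint",  # placeholder path
    torch_dtype=torch.bfloat16,
    device_map="auto",
    **load_kwargs,
)

# Re-promote audio-critical modules (prediction_head, acoustic_*) to bfloat16.
model = apply_selective_quantization(model, quantization)
```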
utils/vram_utils.py (new file):

```python
"""VRAM detection and quantization recommendation utilities."""

import re

import torch
from transformers.utils import logging

logger = logging.get_logger(__name__)


def get_available_vram_gb() -> float:
    """
    Get available VRAM in GB.

    Returns:
        float: Available VRAM in GB, or 0 if no CUDA device available
    """
    if not torch.cuda.is_available():
        return 0.0

    try:
        # Get first CUDA device
        device = torch.device("cuda:0")

        # Prefer direct CUDA mem info if available (free, total in bytes)
        if hasattr(torch.cuda, "mem_get_info"):
            free_bytes, total_bytes = torch.cuda.mem_get_info(device)
            available_gb = free_bytes / (1024 ** 3)
        else:
            # Fallback: estimate free memory from total minus reserved/allocated
            props = torch.cuda.get_device_properties(device)
            total_bytes = props.total_memory
            reserved_bytes = torch.cuda.memory_reserved(device)
            allocated_bytes = torch.cuda.memory_allocated(device)
            used_bytes = max(reserved_bytes, allocated_bytes)
            free_bytes = max(total_bytes - used_bytes, 0)
            available_gb = free_bytes / (1024 ** 3)

        return available_gb
    except Exception as e:
        logger.warning(f"Could not detect VRAM: {e}")
        return 0.0


def suggest_quantization(available_vram_gb: float, model_name: str = "VibeVoice-7B") -> str:
    """
    Suggest quantization level based on available VRAM.

    Args:
        available_vram_gb: Available VRAM in GB
        model_name: Name of the model being loaded

    Returns:
        str: Suggested quantization level ("fp16", "8bit", or "4bit")
    """
    # Parse model size from name (e.g., "0.5B", "1.5B", "7B")
    size_match = re.search(r'(\d+\.?\d*)B', model_name)

    if size_match:
        size_b = float(size_match.group(1))
    else:
        # Default to 7B if size cannot be determined
        size_b = 7.0

    # Adjust thresholds based on model size
    if size_b <= 0.5:
        # 0.5B model
        if available_vram_gb >= 4:
            return "fp16"
        elif available_vram_gb >= 3:
            return "8bit"
        else:
            return "4bit"
    elif size_b <= 1.5:
        # 1.5B model
        if available_vram_gb >= 8:
            return "fp16"
        elif available_vram_gb >= 6:
            return "8bit"
        else:
            return "4bit"
    else:
        # 7B or larger model
        if available_vram_gb >= 22:
            return "fp16"
        elif available_vram_gb >= 14:
            return "8bit"
        else:
            return "4bit"


def print_vram_info(available_vram_gb: float, model_name: str, quantization: str = "fp16"):
    """
    Print VRAM information and quantization recommendation.

    Args:
        available_vram_gb: Available VRAM in GB
        model_name: Name of the model being loaded
        quantization: Current quantization setting
    """
    logger.info(f"Available VRAM: {available_vram_gb:.1f}GB")

    suggested = suggest_quantization(available_vram_gb, model_name)

    if suggested != quantization and quantization == "fp16":
        logger.warning(
            f"⚠️ Low VRAM detected ({available_vram_gb:.1f}GB). "
            f"Recommended: --quantization {suggested}"
        )
        logger.warning(
            f"  Example: python demo/realtime_model_inference_from_file.py "
            f"--model_path {model_name} --quantization {suggested} ..."
        )
```
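For context, a short sketch of how these utilities fit together at startup, mirroring the warning path in `print_vram_info`. The `utils.vram_utils` import path is an assumption, and the auto-selection at the end is an illustrative extension: the PR itself only warns rather than overriding the user's `--quantization` choice.

```python
# Sketch: detect free VRAM, log the recommendation, and (hypothetically)
# auto-select a quantization level.
from utils.vram_utils import (
    get_available_vram_gb,
    print_vram_info,
    suggest_quantization,
)

model_name = "VibeVoice-7B"
requested = "fp16"  # what the user asked for via --quantization

vram_gb = get_available_vram_gb()  # returns 0.0 when CUDA is unavailable
print_vram_info(vram_gb, model_name, requested)  # warns if fp16 looks too big

# e.g., with 10GB free and a 7B model this returns "4bit" (10 < 14GB threshold)
effective = suggest_quantization(vram_gb, model_name)
print(f"{vram_gb:.1f}GB free -> suggested quantization: {effective}")
```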