diff --git a/demo/realtime_model_inference_from_file.py b/demo/realtime_model_inference_from_file.py index a321f6c4..4dbf0fb6 100644 --- a/demo/realtime_model_inference_from_file.py +++ b/demo/realtime_model_inference_from_file.py @@ -12,6 +12,9 @@ from vibevoice.processor.vibevoice_streaming_processor import VibeVoiceStreamingProcessor from transformers.utils import logging +from utils.vram_utils import get_available_vram_gb, print_vram_info +from utils.quantization import get_quantization_config, apply_selective_quantization + logging.set_verbosity_info() logger = logging.get_logger(__name__) @@ -121,6 +124,13 @@ def parse_args(): default=1.5, help="CFG (Classifier-Free Guidance) scale for generation (default: 1.5)", ) + parser.add_argument( + "--quantization", + type=str, + default="fp16", + choices=["fp16", "8bit", "4bit"], + help="Quantization level: fp16 (no quantization, ~20GB), 8bit (~12GB), or 4bit (~7GB)" + ) return parser.parse_args() @@ -138,6 +148,14 @@ def main(): args.device = "cpu" print(f"Using device: {args.device}") + + # VRAM Detection and Quantization Info (NEW) + if args.device.startswith("cuda"): + available_vram = get_available_vram_gb() + print_vram_info(available_vram, args.model_path, args.quantization) + elif args.quantization != "fp16": + print(f"Warning: Quantization ({args.quantization}) only works with CUDA. 
Using full precision.") + args.quantization = "fp16" # Initialize voice mapper voice_mapper = VoiceMapper() @@ -172,6 +190,15 @@ def main(): load_dtype = torch.float32 attn_impl_primary = "sdpa" print(f"Using device: {args.device}, torch_dtype: {load_dtype}, attn_implementation: {attn_impl_primary}") + + # Get quantization configuration (NEW) + quant_config = get_quantization_config(args.quantization) + + if quant_config: + print(f"Using {args.quantization} quantization...") + else: + print("Using full precision (no quantization)...") + # Load model with device-specific logic try: if args.device == "mps": @@ -182,13 +209,26 @@ def main(): device_map=None, # load then move ) model.to("mps") - elif args.device == "cuda": + elif args.device.startswith("cuda"): + # MODIFIED SECTION - Add quantization support + model_kwargs = { + "torch_dtype": load_dtype, + "device_map": "cuda", + "attn_implementation": attn_impl_primary, + } + + # Add quantization config if specified + if quant_config: + model_kwargs.update(quant_config) + model = VibeVoiceStreamingForConditionalGenerationInference.from_pretrained( args.model_path, - torch_dtype=load_dtype, - device_map="cuda", - attn_implementation=attn_impl_primary, + **model_kwargs ) + + # Apply selective quantization if needed (NEW) + if args.quantization in ["8bit", "4bit"]: + model = apply_selective_quantization(model, args.quantization) else: # cpu model = VibeVoiceStreamingForConditionalGenerationInference.from_pretrained( args.model_path, @@ -204,7 +244,7 @@ def main(): model = VibeVoiceStreamingForConditionalGenerationInference.from_pretrained( args.model_path, torch_dtype=load_dtype, - device_map=(args.device if args.device in ("cuda", "cpu") else None), + device_map=(args.device if args.device.startswith("cuda") or args.device == "cpu" else None), attn_implementation='sdpa' ) if args.device == "mps": diff --git a/utils/__init__.py b/utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git 
"""Quantization and VRAM helpers for the VibeVoice demos.

This block holds the contents of ``utils/quantization.py``
(``get_quantization_config``, ``apply_selective_quantization``) and
``utils/vram_utils.py`` (``get_available_vram_gb``, ``suggest_quantization``,
``print_vram_info``).
"""

import logging
import re
from typing import Optional, Union

import torch

# NOTE(review): utils/vram_utils.py originally created its logger with
# transformers.utils.logging.get_logger(__name__). For a module name outside
# the "transformers" namespace that presumably resolves to the same stdlib
# logger object -- confirm before relying on transformers verbosity settings.
logger = logging.getLogger(__name__)

# Levels that mean "load the checkpoint as-is, without bitsandbytes".
_UNQUANTIZED_LEVELS = ("fp16", "full")

# Levels that are implemented through bitsandbytes.
_BNB_LEVELS = ("8bit", "4bit")
_BNB_LABELS = {"8bit": "8-bit", "4bit": "4-bit"}

# Substrings of module names that must not be treated as "the LLM".
_KEEP_FP_COMPONENTS = ("prediction_head", "acoustic_")

# Substrings of module names that identify the language-model backbone.
_QUANTIZE_COMPONENTS = ("llm", "language_model")

_BYTES_PER_GB = 1024 ** 3

# Preferred size pattern: "<number>B" delimited on both sides, any case
# (matches "VibeVoice-1.5B" and "vibevoice-1.5b"; ignores "4bit" and hex
# digests such as "e27b9").
_STRICT_SIZE_RE = re.compile(
    r"(?<![A-Za-z0-9.])(\d+(?:\.\d+)?)B(?![A-Za-z0-9])", re.IGNORECASE
)
# Original, looser pattern; kept as a fallback for names such as "Model7B".
_LOOSE_SIZE_RE = re.compile(r"(\d+\.?\d*)B")
_DEFAULT_MODEL_SIZE_B = 7.0

# (largest model size in billions of parameters,
#  minimum free GB for fp16, minimum free GB for 8bit)
_VRAM_TIERS = (
    (0.5, 4, 3),
    (1.5, 8, 6),
    (float("inf"), 22, 14),
)

# Relative memory appetite of each level; higher needs more VRAM.
_PRECISION_RANK = {"4bit": 0, "8bit": 1, "fp16": 2, "full": 2}


def get_quantization_config(quantization: str = "fp16") -> Optional[dict]:
    """
    Get quantization configuration for model loading.

    Both bitsandbytes levels are returned in the same shape, a
    ``{"quantization_config": BitsAndBytesConfig(...)}`` dict, so the caller
    can ``update()`` its ``from_pretrained`` kwargs with it.

    Args:
        quantization: Quantization level ("fp16"/"full", "8bit", or "4bit")

    Returns:
        dict: Keyword arguments for from_pretrained, or None for fp16/full

    Raises:
        ValueError: If ``quantization`` is not a known level.
        ImportError: If a bitsandbytes level is requested but bitsandbytes
            (or transformers) cannot be imported.
    """
    if quantization in _UNQUANTIZED_LEVELS:
        return None

    # Validate before touching optional dependencies so that a typo is always
    # reported as a ValueError.
    if quantization not in _BNB_LEVELS:
        raise ValueError(
            f"Invalid quantization: {quantization}. "
            f"Must be one of: fp16, 8bit, 4bit"
        )

    try:
        # Imported only to check availability; transformers drives it.
        import bitsandbytes  # noqa: F401
    except ImportError:
        logger.error(
            "%s quantization requires bitsandbytes. "
            "Install with: pip install bitsandbytes",
            _BNB_LABELS[quantization],
        )
        raise

    from transformers import BitsAndBytesConfig

    # NOTE(review): neither config lists modules to skip, so the
    # "(selective LLM only)" wording in the log lines is an intention rather
    # than something this config enforces. BitsAndBytesConfig accepts
    # ``llm_int8_skip_modules`` for that purpose -- confirm the real module
    # names of the streaming model before adding it.
    if quantization == "8bit":
        logger.info("Using 8-bit quantization (selective LLM only)")
        config = BitsAndBytesConfig(
            load_in_8bit=True,
            llm_int8_threshold=6.0,
        )
    else:
        logger.info("Using 4-bit NF4 quantization (selective LLM only)")
        config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_use_double_quant=True,
        )

    return {"quantization_config": config}


def apply_selective_quantization(model, quantization: str):
    """
    Post-process a loaded model so audio-critical parts stay in bfloat16.

    The function walks ``model.named_modules()``. Every module whose name
    contains one of ``_KEEP_FP_COMPONENTS`` is cast with
    ``module.to(torch.bfloat16)``; modules whose name contains one of
    ``_QUANTIZE_COMPONENTS`` are only counted and logged. The function does
    not quantize anything itself.

    Args:
        model: The VibeVoice model (any ``torch.nn.Module``)
        quantization: Quantization level ("8bit" or "4bit"); "fp16"/"full"
            return the model untouched

    Returns:
        The same model object that was passed in.
    """
    if quantization in _UNQUANTIZED_LEVELS:
        return model

    logger.info("Applying selective quantization...")

    kept_roots = []
    llm_modules = 0

    for name, module in model.named_modules():
        if any(part in name for part in _KEEP_FP_COMPONENTS):
            # named_modules() yields parents before children and
            # Module.to() recurses, so descendants of a module that was
            # already cast need no second pass.
            if any(name.startswith(root + ".") for root in kept_roots):
                continue
            module.to(torch.bfloat16)
            kept_roots.append(name)
            logger.debug("Keeping %s at bfloat16 (audio-critical)", name)
        elif any(part in name for part in _QUANTIZE_COMPONENTS):
            llm_modules += 1
            logger.debug(
                "%s belongs to the LLM (%s is applied at load time)",
                name,
                quantization,
            )

    logger.info("✓ Selective %s quantization applied", quantization)
    logger.info("  • LLM modules matched: %d", llm_modules)
    logger.info(
        "  • Audio-critical module trees cast to bfloat16: %d", len(kept_roots)
    )

    return model


def get_available_vram_gb(
    device: Union[str, int, torch.device] = "cuda:0",
) -> float:
    """
    Get available VRAM in GB.

    Args:
        device: CUDA device to query. Defaults to "cuda:0", the device the
            original implementation always used.

    Returns:
        float: Available VRAM in GB, or 0 if no CUDA device is available or
        the query fails (a warning is logged in the latter case)
    """
    if not torch.cuda.is_available():
        return 0.0

    # Deliberately best-effort: any failure while probing the driver is
    # downgraded to a warning so that model loading can still be attempted.
    try:
        cuda_device = torch.device(device)

        if hasattr(torch.cuda, "mem_get_info"):
            # (free, total) in bytes.
            free_bytes, _total_bytes = torch.cuda.mem_get_info(cuda_device)
        else:
            # Fallback: estimate free memory from total minus what this
            # process has reserved/allocated.
            props = torch.cuda.get_device_properties(cuda_device)
            used_bytes = max(
                torch.cuda.memory_reserved(cuda_device),
                torch.cuda.memory_allocated(cuda_device),
            )
            free_bytes = max(props.total_memory - used_bytes, 0)

        return free_bytes / _BYTES_PER_GB
    except Exception as e:
        logger.warning("Could not detect VRAM: %s", e)
        return 0.0


def _parse_model_size_b(model_name: str) -> float:
    """Return the parameter count (in billions) encoded in a model name/path."""
    strict_hits = _STRICT_SIZE_RE.findall(model_name)
    if strict_hits:
        # The last hit is closest to the final path component.
        return float(strict_hits[-1])

    loose_hit = _LOOSE_SIZE_RE.search(model_name)
    if loose_hit:
        return float(loose_hit.group(1))

    return _DEFAULT_MODEL_SIZE_B


def suggest_quantization(available_vram_gb: float, model_name: str = "VibeVoice-7B") -> str:
    """
    Suggest quantization level based on available VRAM.

    The model size is parsed from ``model_name`` (e.g. "0.5B", "1.5B", "7B",
    in any letter case; a filesystem path or hub id is accepted). When no
    size can be found the model is assumed to be 7B.

    Args:
        available_vram_gb: Available VRAM in GB
        model_name: Name or path of the model being loaded

    Returns:
        str: Suggested quantization level ("fp16", "8bit", or "4bit")
    """
    size_b = _parse_model_size_b(model_name)

    # The last tier is unbounded, so a match always exists.
    _, fp16_min_gb, int8_min_gb = next(
        tier for tier in _VRAM_TIERS if size_b <= tier[0]
    )

    if available_vram_gb >= fp16_min_gb:
        return "fp16"
    if available_vram_gb >= int8_min_gb:
        return "8bit"
    return "4bit"


def print_vram_info(available_vram_gb: float, model_name: str, quantization: str = "fp16"):
    """
    Log VRAM information and a quantization recommendation.

    A warning is emitted whenever the requested level needs more memory than
    the suggested one (fp16 -> 8bit/4bit, or 8bit -> 4bit).

    Args:
        available_vram_gb: Available VRAM in GB
        model_name: Name of the model being loaded
        quantization: Current quantization setting
    """
    logger.info("Available VRAM: %.1fGB", available_vram_gb)

    suggested = suggest_quantization(available_vram_gb, model_name)

    requested_rank = _PRECISION_RANK.get(quantization)
    if requested_rank is None or requested_rank <= _PRECISION_RANK[suggested]:
        return

    logger.warning(
        "⚠️  Low VRAM detected (%.1fGB). Recommended: --quantization %s",
        available_vram_gb,
        suggested,
    )
    logger.warning(
        "   Example: python demo/realtime_model_inference_from_file.py "
        "--model_path %s --quantization %s ...",
        model_name,
        suggested,
    )