20 commits
45ba769
Add troubleshooting guide for common installation and usage issues
Dec 9, 2025
573d852
Add quantization support to reduce VRAM requirements
Dec 10, 2025
54b594b
Add quantization support to reduce VRAM requirements
Dec 10, 2025
0328c1e
Merge branch 'microsoft:main' into add-quantization-support
maitrisavaliya Dec 10, 2025
e3e4d69
Delete utils/quantization,py
maitrisavaliya Dec 10, 2025
cdde460
Update realtime_model_inference_from_file.py
maitrisavaliya Dec 10, 2025
62565c4
Delete TROUBLESHOOTING.md
maitrisavaliya Dec 10, 2025
276ad09
Update realtime_model_inference_from_file.py
maitrisavaliya Dec 10, 2025
15ca0ac
Update realtime_model_inference_from_file.py
maitrisavaliya Dec 10, 2025
c2a5bbf
Update vram_utils.py
maitrisavaliya Dec 10, 2025
8b0c2cf
Merge branch 'main' into add-quantization-support
maitrisavaliya Dec 17, 2025
188ffce
Apply suggestion from @Copilot
maitrisavaliya Feb 13, 2026
a0918a3
Apply suggestion from @Copilot
maitrisavaliya Feb 13, 2026
4d5140a
Apply suggestion from @Copilot
maitrisavaliya Feb 13, 2026
0bf0f0d
Apply suggestion from @Copilot
maitrisavaliya Feb 13, 2026
1b08105
Apply suggestion from @Copilot
maitrisavaliya Feb 13, 2026
ef5eaa2
Clarify help message for quantization option
maitrisavaliya Feb 13, 2026
a83b636
Refactor VRAM utility functions and logging
maitrisavaliya Feb 13, 2026
6e97a77
Enhance CUDA device handling and add VRAM info
maitrisavaliya Feb 13, 2026
ea332ac
Fix device_map condition for model loading
maitrisavaliya Feb 13, 2026
50 changes: 45 additions & 5 deletions demo/realtime_model_inference_from_file.py
@@ -12,6 +12,9 @@
from vibevoice.processor.vibevoice_streaming_processor import VibeVoiceStreamingProcessor
from transformers.utils import logging

from utils.vram_utils import get_available_vram_gb, print_vram_info
from utils.quantization import get_quantization_config, apply_selective_quantization

logging.set_verbosity_info()
logger = logging.get_logger(__name__)

@@ -121,6 +124,13 @@ def parse_args():
default=1.5,
help="CFG (Classifier-Free Guidance) scale for generation (default: 1.5)",
)
parser.add_argument(
"--quantization",
type=str,
default="fp16",
choices=["fp16", "8bit", "4bit"],
help="Quantization level: fp16 (no quantization, ~20GB), 8bit (~12GB), or 4bit (~7GB)"
)

return parser.parse_args()

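With the new flag in place, a run on a low-VRAM GPU would look roughly like the following; the model path and any remaining arguments are placeholders, not taken from this PR:

python demo/realtime_model_inference_from_file.py --model_path <path-to-model> --quantization 4bit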
@@ -138,6 +148,14 @@ def main():
args.device = "cpu"

print(f"Using device: {args.device}")

# VRAM Detection and Quantization Info (NEW)
if args.device.startswith("cuda"):
available_vram = get_available_vram_gb()
print_vram_info(available_vram, args.model_path, args.quantization)
elif args.quantization != "fp16":
print(f"Warning: Quantization ({args.quantization}) only works with CUDA. Using full precision.")
args.quantization = "fp16"

# Initialize voice mapper
voice_mapper = VoiceMapper()
@@ -172,6 +190,15 @@ def main():
load_dtype = torch.float32
attn_impl_primary = "sdpa"
print(f"Using device: {args.device}, torch_dtype: {load_dtype}, attn_implementation: {attn_impl_primary}")

# Get quantization configuration (NEW)
quant_config = get_quantization_config(args.quantization)

if quant_config:
print(f"Using {args.quantization} quantization...")
else:
print("Using full precision (no quantization)...")

# Load model with device-specific logic
try:
if args.device == "mps":
@@ -182,13 +209,26 @@
device_map=None, # load then move
)
model.to("mps")
elif args.device == "cuda":
elif args.device.startswith("cuda"):
# MODIFIED SECTION - Add quantization support
model_kwargs = {
"torch_dtype": load_dtype,
"device_map": "cuda",
"attn_implementation": attn_impl_primary,
}

# Add quantization config if specified
if quant_config:
model_kwargs.update(quant_config)

model = VibeVoiceStreamingForConditionalGenerationInference.from_pretrained(
args.model_path,
**model_kwargs
)

# Apply selective quantization if needed (NEW)
if args.quantization in ["8bit", "4bit"]:
model = apply_selective_quantization(model, args.quantization)
else: # cpu
model = VibeVoiceStreamingForConditionalGenerationInference.from_pretrained(
args.model_path,
@@ -204,7 +244,7 @@ def main():
model = VibeVoiceStreamingForConditionalGenerationInference.from_pretrained(
args.model_path,
torch_dtype=load_dtype,
device_map=(args.device if args.device.startswith("cuda") or args.device == "cpu" else None),
attn_implementation='sdpa'
)
if args.device == "mps":
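One quick way to confirm what the selective scheme actually quantized after loading is to count bitsandbytes layers versus the audio-critical modules. This is only a review-time sketch (it assumes bitsandbytes is installed and is not part of the PR):

import bitsandbytes as bnb

quantized = sum(
    isinstance(m, (bnb.nn.Linear8bitLt, bnb.nn.Linear4bit))
    for m in model.modules()
)
audio_fp = sum(
    1 for name, _ in model.named_modules()
    if "prediction_head" in name or "acoustic_" in name
)
print(f"quantized linear layers: {quantized}, audio-critical modules: {audio_fp}")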
Empty file added utils/__init__.py
113 changes: 113 additions & 0 deletions utils/quantization.py
@@ -0,0 +1,113 @@
"""Quantization utilities for VibeVoice models."""

import logging
from typing import Optional
import torch

logger = logging.getLogger(__name__)


def get_quantization_config(quantization: str = "fp16") -> Optional[dict]:
"""
Get quantization configuration for model loading.

Args:
quantization: Quantization level ("fp16", "8bit", or "4bit")

Returns:
dict: Quantization config for from_pretrained, or None for fp16
"""
if quantization == "fp16" or quantization == "full":
return None

if quantization == "8bit":
try:
import bitsandbytes as bnb
logger.info("Using 8-bit quantization (selective LLM only)")
return {
"load_in_8bit": True,
"llm_int8_threshold": 6.0,
}
except ImportError:
logger.error(
"8-bit quantization requires bitsandbytes. "
"Install with: pip install bitsandbytes"
)
raise

elif quantization == "4bit":
try:
import bitsandbytes as bnb
from transformers import BitsAndBytesConfig

logger.info("Using 4-bit NF4 quantization (selective LLM only)")
return {
"quantization_config": BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16,
bnb_4bit_use_double_quant=True,
)
}
except ImportError:
logger.error(
"4-bit quantization requires bitsandbytes. "
"Install with: pip install bitsandbytes"
)
raise

else:
raise ValueError(
f"Invalid quantization: {quantization}. "
f"Must be one of: fp16, 8bit, 4bit"
)


def apply_selective_quantization(model, quantization: str):
"""
Apply selective quantization only to safe components.

This function identifies which modules should be quantized and which
should remain at full precision for audio quality preservation.

Args:
model: The VibeVoice model
quantization: Quantization level ("8bit" or "4bit")

Returns:
The model, with audio-critical modules kept at higher precision.
"""
if quantization == "fp16":
return model

logger.info("Applying selective quantization...")

# Components to KEEP at full precision (audio-critical)
# For the streaming model, audio-critical modules are typically exposed as
# a prediction head and acoustic_* components. We match on these names to
# ensure they remain at higher precision while only the LLM is quantized.
keep_fp_components = [
"prediction_head",
"acoustic_",
]

# Only quantize the LLM (Qwen2.5) component
quantize_components = ["llm", "language_model"]

for name, module in model.named_modules():
# Check if this module should stay at full precision
should_keep_fp = any(comp in name for comp in keep_fp_components)
should_quantize = any(comp in name for comp in quantize_components)

if should_keep_fp:
# Ensure audio components stay at higher precision (e.g., bfloat16 instead of 4/8-bit)
with torch.no_grad():
module.to(torch.bfloat16)
logger.debug(f"Keeping {name} at full precision (audio-critical)")

elif should_quantize:
# The actual 8-/4-bit conversion happens at load time via bitsandbytes; this log is informational only
logger.debug(f"{name} quantized to {quantization} at load time")

logger.info(f"✓ Selective {quantization} quantization applied")
logger.info(" • LLM: Quantized")
logger.info(" • Audio components: Full precision")

return model
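For reference, a minimal sketch of how these two helpers are meant to be combined at load time; the model class and path below are placeholders, and the real wiring lives in demo/realtime_model_inference_from_file.py above:

import torch
from transformers import AutoModelForCausalLM  # placeholder; the demo uses VibeVoiceStreamingForConditionalGenerationInference
from utils.quantization import get_quantization_config, apply_selective_quantization

quantization = "4bit"
model_kwargs = {"torch_dtype": torch.bfloat16, "device_map": "cuda"}

quant_config = get_quantization_config(quantization)  # None for "fp16"
if quant_config:
    model_kwargs.update(quant_config)  # injects the BitsAndBytesConfig

model = AutoModelForCausalLM.from_pretrained("path/to/model", **model_kwargs)
if quantization in ("8bit", "4bit"):
    model = apply_selective_quantization(model, quantization)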
112 changes: 112 additions & 0 deletions utils/vram_utils.py
@@ -0,0 +1,112 @@
"""VRAM detection and quantization recommendation utilities."""

import torch
from transformers.utils import logging

logger = logging.get_logger(__name__)


def get_available_vram_gb() -> float:
"""
Get available VRAM in GB.

Returns:
float: Available VRAM in GB, or 0 if no CUDA device available
"""
if not torch.cuda.is_available():
return 0.0

try:
# Get first CUDA device
device = torch.device("cuda:0")

# Prefer direct CUDA mem info if available (free, total in bytes)
if hasattr(torch.cuda, "mem_get_info"):
free_bytes, total_bytes = torch.cuda.mem_get_info(device)
available_gb = free_bytes / (1024 ** 3)
else:
# Fallback: estimate free memory from total minus reserved/allocated
props = torch.cuda.get_device_properties(device)
total_bytes = props.total_memory
reserved_bytes = torch.cuda.memory_reserved(device)
allocated_bytes = torch.cuda.memory_allocated(device)
used_bytes = max(reserved_bytes, allocated_bytes)
free_bytes = max(total_bytes - used_bytes, 0)
available_gb = free_bytes / (1024 ** 3)

return available_gb
except Exception as e:
logger.warning(f"Could not detect VRAM: {e}")
return 0.0


def suggest_quantization(available_vram_gb: float, model_name: str = "VibeVoice-7B") -> str:
"""
Suggest quantization level based on available VRAM.

Args:
available_vram_gb: Available VRAM in GB
model_name: Name of the model being loaded

Returns:
str: Suggested quantization level ("fp16", "8bit", or "4bit")
"""
# Parse model size from name (e.g., "0.5B", "1.5B", "7B")
import re
size_match = re.search(r'(\d+\.?\d*)B', model_name)

if size_match:
size_b = float(size_match.group(1))
else:
# Default to 7B if size cannot be determined
size_b = 7.0

# Adjust thresholds based on model size
if size_b <= 0.5:
# 0.5B model
if available_vram_gb >= 4:
return "fp16"
elif available_vram_gb >= 3:
return "8bit"
else:
return "4bit"
elif size_b <= 1.5:
# 1.5B model
if available_vram_gb >= 8:
return "fp16"
elif available_vram_gb >= 6:
return "8bit"
else:
return "4bit"
else:
# 7B or larger model
if available_vram_gb >= 22:
return "fp16"
elif available_vram_gb >= 14:
return "8bit"
else:
return "4bit"


def print_vram_info(available_vram_gb: float, model_name: str, quantization: str = "fp16"):
"""
Print VRAM information and quantization recommendation.

Args:
available_vram_gb: Available VRAM in GB
model_name: Name of the model being loaded
quantization: Current quantization setting
"""
logger.info(f"Available VRAM: {available_vram_gb:.1f}GB")

suggested = suggest_quantization(available_vram_gb, model_name)

if suggested != quantization and quantization == "fp16":
logger.warning(
f"⚠️ Low VRAM detected ({available_vram_gb:.1f}GB). "
f"Recommended: --quantization {suggested}"
)
logger.warning(
f" Example: python demo/realtime_model_inference_from_file.py "
f"--model_path {model_name} --quantization {suggested} ..."
)
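As a quick sanity check of the thresholds above, the suggestion logic behaves as follows for a few illustrative VRAM values; the model names only need to contain a size such as 7B or 1.5B for the regex to pick it up:

from utils.vram_utils import suggest_quantization

assert suggest_quantization(10.0, "VibeVoice-7B") == "4bit"   # below the 14 GB 8-bit cutoff for a 7B model
assert suggest_quantization(16.0, "VibeVoice-7B") == "8bit"   # between 14 GB and 22 GB
assert suggest_quantization(6.0, "VibeVoice-1.5B") == "8bit"  # between 6 GB and 8 GB for a 1.5B model
assert suggest_quantization(3.0, "VibeVoice-0.5B") == "8bit"  # between 3 GB and 4 GB for a 0.5B model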