- Pre-train Llama-3.1-70B 1.5x faster with float8 training
- Recover 77% of quantized perplexity degradation on Llama-3.2-3B with QAT
- Quantize Llama-3-8B to int4 for 1.89x faster inference with 58% less memory
- [Jun 25] Our TorchAO paper was accepted to CodeML @ ICML 2025!
- [Apr 25] Float8 rowwise training yielded 1.34-1.43x training speedup at 2k H100 GPU scale
- [Apr 25] TorchAO is added as a quantization backend to vLLM (docs)!
- [Mar 25] Our 2:4 Sparsity paper was accepted to SLLM @ ICLR 2025!
- [Jan 25] Our integration with GemLite and SGLang yielded 1.1-2x faster inference with int4 and float8 quantization across different batch sizes and tensor parallel sizes
- [Jan 25] We added 1-8 bit ARM CPU kernels for linear and embedding ops
Older news
- [Nov 24] We achieved 1.43-1.51x faster pre-training on Llama-3.1-70B and 405B using float8 training
- [Oct 24] TorchAO is added as a quantization backend to HF Transformers!
- [Sep 24] We officially launched TorchAO. Check out our blog here!
- [Jul 24] QAT recovered up to 96% accuracy degradation from quantization on Llama-3-8B
- [Jun 24] Semi-structured 2:4 sparsity achieved 1.1x inference speedup and 1.3x training speedup on the SAM and ViT models respectively
- [Jun 24] Block sparsity achieved 1.46x training speedup on the ViT model with <2% drop in accuracy
TorchAO is a PyTorch-native model optimization framework leveraging quantization and sparsity to provide an end-to-end, training-to-serving workflow for AI models. TorchAO works out-of-the-box with torch.compile() and FSDP2 across most HuggingFace PyTorch models. Key features include:
- Float8 training and inference for speedups without compromising accuracy
- MX training and inference, which provides MX tensor formats based on native PyTorch MX dtypes (prototype)
- Quantization-Aware Training (QAT) for mitigating quantization degradation
- Post-Training Quantization (PTQ) for int4, int8, fp6, etc., with matching kernels targeting a variety of backends including CUDA, ARM CPU, and XNNPACK
- Sparsity, including techniques such as 2:4 sparsity and block sparsity
Check out our docs for more details!
From the team that brought you the fast series:
- 9.5x inference speedups for Image segmentation models with sam-fast
- 10x inference speedups for Language models with gpt-fast
- 3x inference speedup for Diffusion models with sd-fast
- 2.7x inference speedup for FAIR's Seamless M4T-v2 model with seamlessv2-fast
First, install TorchAO. We recommend installing the latest stable version:
pip install torchao
Other installation options
# Nightly
pip install --pre torchao --index-url https://download.pytorch.org/whl/nightly/cu126
# Different CUDA versions
pip install torchao --index-url https://download.pytorch.org/whl/cu126 # CUDA 12.6
pip install torchao --index-url https://download.pytorch.org/whl/cpu # CPU only
# For developers
USE_CUDA=1 python setup.py develop
Quantize your model weights to int4!
from torchao.quantization import Int4WeightOnlyConfig, quantize_
quantize_(model, Int4WeightOnlyConfig(group_size=32))
Compared to a torch.compiled bf16 baseline, your quantized model should be significantly smaller and faster on a single A100 GPU:
int4 model size: 1.25 MB
bfloat16 model size: 4.00 MB
compression ratio: 3.2
bf16 mean time: 30.393 ms
int4 mean time: 4.410 ms
speedup: 6.9x
For the full model setup and benchmark details, check out our quick start guide. Alternatively, try quantizing your favorite model using our HuggingFace space!
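The compression ratio reported above can be sanity-checked with simple arithmetic. The sketch below is not TorchAO code. It assumes (hypothetically) that every group of int4 weights also stores one 16-bit scale and one 16-bit zero point, which may not match the actual packed layout. Under that assumption, group_size=32 costs 5 bits per weight on average, which is consistent with the 3.2x ratio against bf16.

```python
def bits_per_weight(weight_bits, group_size, scale_bits=16, zero_point_bits=16):
    """Average storage cost per weight, counting per-group metadata.

    Assumption (not taken from TorchAO): one scale and one zero point
    are stored for each group of `group_size` weights.
    """
    return weight_bits + (scale_bits + zero_point_bits) / group_size


def compression_ratio(baseline_bits, weight_bits, group_size):
    """Size of the baseline model divided by the size of the quantized one."""
    return baseline_bits / bits_per_weight(weight_bits, group_size)


ratio_g32 = compression_ratio(16, 4, group_size=32)
ratio_g128 = compression_ratio(16, 4, group_size=128)

print(f"group_size=32:  {ratio_g32:.2f}x smaller than bf16")
print(f"group_size=128: {ratio_g128:.2f}x smaller than bf16")
```

Larger groups amortize the metadata over more weights and compress better, at the cost of coarser scales.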
TorchAO is integrated into some of the leading open-source libraries including:
- HuggingFace transformers with a builtin inference backend and low bit optimizers
- HuggingFace diffusers best practices with torch.compile and TorchAO in a standalone repo diffusers-torchao
- Mobius HQQ backend leveraged our int4 kernels to get 195 tok/s on a 4090
- TorchTune for our QLoRA, QAT, and float8 quantized fine-tuning recipes
- TorchTitan for float8 pre-training
- vLLM for LLM serving: usage, detailed docs
- SGLang for LLM serving: usage and the major PR
- Axolotl for QAT and PTQ
TorchAO delivers substantial performance gains with minimal code changes:
- Int4 weight-only: 1.89x throughput with 58.1% less memory on Llama-3-8B
- Float8 dynamic quantization: 1.54x and 1.27x speedup on Flux.1-Dev* and CogVideoX-5b respectively on H100 with preserved quality
- Int4 + 2:4 Sparsity: 2.37x throughput with 67.7% memory reduction on Llama-3-8B
Quantize any model with nn.Linear layers in just one line (Option 1), or load the quantized model directly from HuggingFace using our integration with HuggingFace transformers (Option 2):
# Option 1: quantize an existing model in place with the TorchAO API
from torchao.quantization.quant_api import quantize_, Int4WeightOnlyConfig
quantize_(model, Int4WeightOnlyConfig(group_size=128, use_hqq=True))

# Option 2: load and quantize through HuggingFace transformers
from transformers import TorchAoConfig, AutoModelForCausalLM
from torchao.quantization.quant_api import Int4WeightOnlyConfig
# Create quantization configuration
quantization_config = TorchAoConfig(quant_type=Int4WeightOnlyConfig(group_size=128, use_hqq=True))
# Load and automatically quantize
quantized_model = AutoModelForCausalLM.from_pretrained(
    "microsoft/Phi-4-mini-instruct",
    torch_dtype="auto",
    device_map="auto",
    quantization_config=quantization_config,
)

# Serve a pre-quantized checkpoint with vLLM (shell command)
vllm serve pytorch/Phi-4-mini-instruct-int4wo-hqq --tokenizer microsoft/Phi-4-mini-instruct -O3
With this quantization flow, we achieve 67% VRAM reduction and 12-20% speedup on A100 GPUs while maintaining model quality. For more detail, see this step-by-step quantization guide. We also release some pre-quantized models here.
Post-training quantization (PTQ) can result in a fast and compact model, but may also lead to accuracy degradation. We recommend exploring Quantization-Aware Training (QAT) to overcome this limitation, especially for lower bit-width dtypes such as int4. In collaboration with TorchTune, we've developed a QAT recipe that demonstrates significant accuracy improvements over traditional PTQ, recovering 96% of the accuracy degradation on hellaswag and 68% of the perplexity degradation on wikitext for Llama3. For more details, please refer to the QAT README and the original blog:
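import torch
from torchao.quantization import quantize_
from torchao.quantization.qat import FakeQuantizeConfig, IntXQuantizationAwareTrainingConfig
# int8 asymmetric per-token fake quantization for activations
activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
# int4 per-group fake quantization for weights
weight_config = FakeQuantizeConfig(torch.int4, group_size=32)
# No trailing comma here: a trailing comma would turn the config into a tuple
qat_config = IntXQuantizationAwareTrainingConfig(activation_config, weight_config)
quantize_(my_model, qat_config)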
Users can also combine LoRA + QAT to speed up training by 1.89x compared to vanilla QAT using this fine-tuning recipe.
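To make the idea of fake quantization concrete, here is a minimal pure-Python sketch of what a per-group int4 fake quantizer does to a flat list of weights during QAT: values are rounded onto the int4 grid and immediately mapped back to floats, so the model trains against the rounding error it will see after real quantization. This is an illustration only, not TorchAO's implementation. The function names are hypothetical, symmetric scaling is assumed, and the gradient handling (a straight-through estimator in typical QAT setups) is not shown.

```python
def fake_quantize_group(values, num_bits=4):
    """Quantize-then-dequantize one group with a shared symmetric scale."""
    qmin = -(2 ** (num_bits - 1))      # -8 for int4
    qmax = 2 ** (num_bits - 1) - 1     # 7 for int4
    max_abs = max(abs(v) for v in values)
    if max_abs == 0:
        return list(values)            # nothing to scale in an all-zero group
    scale = max_abs / qmax
    out = []
    for v in values:
        q = max(qmin, min(qmax, round(v / scale)))  # snap to the integer grid
        out.append(q * scale)                       # map back to float
    return out


def fake_quantize(weights, group_size=32, num_bits=4):
    """Apply fake quantization independently to each consecutive group."""
    out = []
    for start in range(0, len(weights), group_size):
        out.extend(fake_quantize_group(weights[start:start + group_size], num_bits))
    return out


weights = [i / 10 - 1.6 for i in range(64)]
fake = fake_quantize(weights, group_size=32)
worst = max(abs(w - f) for w, f in zip(weights, fake))
print(f"largest rounding error: {worst:.4f}")
```

Each group gets its own scale, so a smaller group_size tracks local weight magnitudes more closely and typically reduces the rounding error.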
torchao.float8 implements training recipes with the scaled float8 dtypes, as laid out in https://arxiv.org/abs/2209.05433. With torch.compile on, current results show throughput speedups of up to 1.5x on up to 512 GPU / 405B parameter count scale (details):
from torchao.float8 import convert_to_float8_training
convert_to_float8_training(m)
Our float8 training is integrated into TorchTitan's pre-training flows so users can easily try it out. For more details, check out these blog posts about our float8 training support:
- Accelerating Large Scale Training and Convergence with PyTorch Float8 Rowwise on Crusoe 2K H200s
- Supercharging Training using float8 and FSDP2
- Efficient Pre-training of Llama 3-like model architectures using torchtitan on Amazon SageMaker
- Float8 in PyTorch
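As a rough illustration of what "scaled" float8 means, the sketch below computes the scale factors that map a tensor's largest magnitude onto the largest finite e4m3 value (448), once for the whole tensor and once per row. It is a simplified pure-Python model, not the torchao.float8 implementation. The helper names are hypothetical, and the actual cast to float8 and the matmul are omitted.

```python
E4M3_MAX = 448.0  # largest finite value representable in float8 e4m3fn
EPS = 1e-12       # guards against division by zero for all-zero inputs


def tensorwise_scale(matrix):
    """One scale for the whole tensor, derived from its absolute maximum."""
    amax = max(abs(v) for row in matrix for v in row)
    return E4M3_MAX / max(amax, EPS)


def rowwise_scales(matrix):
    """One scale per row, so small-magnitude rows keep more precision."""
    return [E4M3_MAX / max(max(abs(v) for v in row), EPS) for row in matrix]


matrix = [
    [0.5, -2.0],    # small-magnitude row
    [100.0, 7.0],   # large-magnitude row dominates the tensorwise amax
]

print("tensorwise scale:", tensorwise_scale(matrix))
print("rowwise scales:  ", rowwise_scales(matrix))
```

With a single tensorwise scale, the small row is stretched only as far as the largest row allows. Rowwise scaling lets each row use the full float8 range on its own.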
We've added support for semi-structured 2:4 sparsity with 6% end-to-end speedups on ViT-L. Full blog here. The code change is a one-liner, with the full example available here:
from torchao.sparsity.training import SemiSparseLinear, swap_linear_with_semi_sparse_linear
swap_linear_with_semi_sparse_linear(model, {"seq.0": SemiSparseLinear})
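For readers new to the 2:4 pattern, the following sketch shows the constraint itself: in every contiguous block of four values, at most two are non-zero. It keeps the two largest-magnitude entries per block, which is one common selection rule and an assumption here. SemiSparseLinear may choose which values to keep differently, and the function name is hypothetical.

```python
def prune_2_4(row):
    """Zero out the two smallest-magnitude values in each block of four."""
    if len(row) % 4 != 0:
        raise ValueError("row length must be a multiple of 4")
    pruned = []
    for start in range(0, len(row), 4):
        block = row[start:start + 4]
        # Indices of the two largest-magnitude entries in this block
        keep = sorted(range(4), key=lambda i: abs(block[i]), reverse=True)[:2]
        pruned.extend(v if i in keep else 0.0 for i, v in enumerate(block))
    return pruned


row = [0.1, -0.9, 0.3, 0.05, 2.0, -0.2, 0.0, 1.5]
sparse_row = prune_2_4(row)
print(sparse_row)
```

Because exactly half of each block is zero, hardware with semi-structured sparsity support can store only the kept values plus a small index and skip the zeros during the matmul.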
Optimizers like Adam can consume substantial GPU memory: 2x as much as the model parameters themselves. TorchAO provides two approaches to reduce this overhead:
1. Quantized optimizers: Reduce optimizer state memory by 2-4x by quantizing to lower precision
from torchao.optim import AdamW8bit, AdamW4bit, AdamWFp8
optim = AdamW8bit(model.parameters())  # replace with AdamW4bit or AdamWFp8 for the 4-bit / fp8 versions
Our quantized optimizers are implemented in just a few hundred lines of PyTorch code and compiled for efficiency. While slightly slower than specialized kernels, they offer an excellent balance of memory savings and performance. See detailed benchmarks here.
2. CPU offloading: Move optimizer state and gradients to CPU memory
For maximum memory savings, we support single-GPU CPU offloading that efficiently moves both gradients and optimizer state to CPU memory. This approach can reduce your VRAM requirements by 60% with minimal impact on training speed:
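import torch
from torchao.optim import CPUOffloadOptimizer
optim = CPUOffloadOptimizer(model.parameters(), torch.optim.AdamW, fused=True)
# ckpt is assumed to be a previously loaded checkpoint dict
optim.load_state_dict(ckpt["optim"])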
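The memory figures in this section follow from Adam keeping two state tensors (first and second moment estimates) per parameter. The back-of-the-envelope sketch below estimates that footprint at different state precisions. It is an approximation under stated assumptions: the baseline state is taken to be 16-bit, the small per-block scale overhead of quantized states is ignored, and the 8B parameter count is a hypothetical example.

```python
def adam_state_bytes(num_params, state_bits):
    """Bytes used by Adam's two per-parameter state tensors.

    Ignores any per-block scaling metadata a quantized optimizer may store.
    """
    return 2 * num_params * state_bits / 8


num_params = 8_000_000_000  # hypothetical 8B-parameter model
baseline = adam_state_bytes(num_params, 16)

for name, bits in [("16-bit", 16), ("8-bit", 8), ("4-bit", 4)]:
    size = adam_state_bytes(num_params, bits)
    print(f"{name} state: {size / 1e9:.0f} GB ({baseline / size:.0f}x smaller than 16-bit)")
```

Under these assumptions, 8-bit and 4-bit states are 2x and 4x smaller than a 16-bit baseline. CPU offloading instead removes the state from GPU memory entirely, at the cost of host-device transfers.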
- Keynote talk at GPU MODE IRL
- Low precision dtypes at PyTorch conference
- Slaying OOMs at the Mastering LLM's course
- Advanced Quantization at CUDA MODE
- Chip Huyen's GPU Optimization Workshop
- Cohere for AI community talk
If you find the torchao library useful, please cite it in your work as below.
@software{torchao,
title={TorchAO: PyTorch-Native Training-to-Serving Model Optimization},
author={torchao},
url={https://github.com/pytorch/torchao},
license={BSD-3-Clause},
month={oct},
year={2024}
}