281 changes: 199 additions & 82 deletions models/model_loader.py
@@ -6,14 +6,74 @@
"""

from dataclasses import dataclass
-from typing import Optional, Tuple
+from typing import Dict, Optional, Tuple

import torch
import torch.nn as nn
-from transformers import (
-    AutoModelForSequenceClassification,
-    AutoConfig,
-)
+from torch.ao.quantization import QuantStub, DeQuantStub
+from transformers import AutoConfig, AutoModelForSequenceClassification

+try:
+    # Optional: used for true INT8 on GPUs (e.g., T4) via bitsandbytes
+    from transformers import BitsAndBytesConfig
+
+    _HAS_BITSANDBYTES = True
+except Exception:  # pragma: no cover - optional dependency
+    BitsAndBytesConfig = None
+    _HAS_BITSANDBYTES = False


+class QuantDistilBertWrapper(nn.Module):
+    """
+    Wrap a DistilBERT classification model with Quant/DeQuant stubs
+    so we can use static PTQ (prepare/convert) like in Lab 3.
+    """
+    def __init__(self, base_model: nn.Module):
+        super().__init__()
+        self.quant = QuantStub()
+        self.dequant = DeQuantStub()
+        self.model = base_model  # usually AutoModelForSequenceClassification
+
+    def forward(self, input_ids, attention_mask=None, **kwargs):
+        """
+        Quantization-aware forward that **does not** modify token IDs.
+
+        Token IDs are categorical indices and must remain integers for
+        correct embedding lookup. We therefore:
+        1. Compute embeddings from the original `input_ids`
+        2. Quantize the **embedding activations**
+        3. Run the rest of the model on quantized activations
+        4. Dequantize the final logits
+        """
+        # 1) Standard embedding lookup using raw (non-quantized) token IDs
+        embeddings = self.model.distilbert.embeddings(input_ids)
+
+        # 2) Quantize the embedding activations instead of the token IDs
+        embeddings_q = self.quant(embeddings)
+
+        # 3) Run the HF classification model using quantized embeddings
+        outputs = self.model(
+            inputs_embeds=embeddings_q,
+            attention_mask=attention_mask,
+            **kwargs,
+        )
+
+        # 4) Dequantize logits before returning
+        logits = self.dequant(outputs.logits)
+
+        # Return an object with `.logits` so existing code keeps working
+        class OutputWrapper:
+            def __init__(self, logits_tensor):
+                self.logits = logits_tensor
+
+        return OutputWrapper(logits)
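
For orientation, here is a minimal usage sketch of the wrapper on its own, before any prepare/convert step (the tokenizer call and variable names are illustrative; the checkpoint is the default used by load_model below). The point is that input_ids stays an integer tensor and only the embedding activations pass through the stubs:

# Illustrative only: the stubs are pass-through until prepare/convert runs.
from transformers import AutoModelForSequenceClassification, AutoTokenizer

name = "distilbert-base-uncased-finetuned-sst-2-english"
base = AutoModelForSequenceClassification.from_pretrained(name)
tokenizer = AutoTokenizer.from_pretrained(name)

wrapper = QuantDistilBertWrapper(base)
wrapper.eval()

batch = tokenizer(["great movie", "terrible movie"], padding=True, return_tensors="pt")
with torch.no_grad():
    out = wrapper(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"])

print(out.logits.shape)  # torch.Size([2, 2]), same contract as the HF model output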



@dataclass
@@ -25,69 +85,68 @@ class ModelConfig:
    num_labels: int = 2


-def _apply_int8_quantization_cuda(model: nn.Module, verbose: bool = True) -> None:
-    """
-    Apply INT8-like quantization for CUDA by quantizing weights to INT8 range.
-
-    This uses symmetric per-tensor quantization:
-    - Quantize: Q = round(R / scale) where scale = max(abs(R)) / 127
-    - Dequantize: R' = Q * scale
-
-    The weights are stored as FP32/FP16 but quantized to INT8 precision.
-    This is not true INT8 compute (which requires special kernels) but
-    simulates the accuracy/precision loss of INT8 quantization.
-
-    Args:
-        model: Model to quantize (must be on CUDA)
-        verbose: Whether to print quantization info
-    """
-    num_quantized = 0
-
-    for name, module in model.named_modules():
-        if isinstance(module, nn.Linear):
-            # Quantize weight
-            with torch.no_grad():
-                weight = module.weight.data
-
-                # Symmetric per-tensor quantization
-                # Scale = max_abs_value / 127 (INT8 max positive value)
-                scale = weight.abs().max() / 127.0
-
-                if scale > 0:
-                    # Quantize: divide by scale, round, clamp to INT8 range
-                    weight_q = torch.round(weight / scale)
-                    weight_q = torch.clamp(weight_q, -128, 127)
-
-                    # Dequantize: multiply back by scale
-                    weight_dequant = weight_q * scale
-
-                    # Replace original weight with quantized version
-                    module.weight.data = weight_dequant
-
-                # Quantize bias if it exists
-                if module.bias is not None:
-                    bias = module.bias.data
-                    scale_bias = bias.abs().max() / 127.0
-
-                    if scale_bias > 0:
-                        bias_q = torch.round(bias / scale_bias)
-                        bias_q = torch.clamp(bias_q, -128, 127)
-                        bias_dequant = bias_q * scale_bias
-                        module.bias.data = bias_dequant
-
-            num_quantized += 1
-
-    if verbose:
-        print(f" ✓ Quantized {num_quantized} Linear layers to INT8 precision")
-        print(f" ✓ Running on CUDA (simulated INT8 compute)")

+@dataclass
+class LayerwiseQuantConfig:
+    """
+    Configuration for hybrid precision per layer.
+
+    `layer_precision` maps a substring (matching module names) to a
+    desired precision: "fp32", "fp16", or "int8".
+    - "int8" -> attach qconfig so the module is statically quantized
+    - "fp32"/"fp16" -> clear qconfig so the module stays in float
+
+    Example:
+        LayerwiseQuantConfig(
+            layer_precision={
+                "attention": "int8",
+                "ffn": "int8",
+                "classifier": "fp32",
+            }
+        )
+    """
+
+    layer_precision: Dict[str, str]
+
+
+def apply_layerwise_qconfig(
+    model: nn.Module,
+    qconfig: torch.ao.quantization.QConfig,
+    cfg: LayerwiseQuantConfig,
+) -> None:
+    """
+    Attach qconfig selectively to DistilBERT submodules.
+
+    This is a simple heuristic based on layer-name substrings. It lets
+    you implement hybrid schemes (e.g., INT8 attention + FP32 classifier).
+    """
+    for name, module in model.named_modules():
+        # Skip explicit quant/dequant stubs themselves
+        if isinstance(module, (QuantStub, DeQuantStub)):
+            continue
+
+        matched_precision: Optional[str] = None
+        for pattern, prec in cfg.layer_precision.items():
+            if pattern in name:
+                matched_precision = prec
+
+        if matched_precision is None:
+            # No explicit rule -> leave whatever global qconfig is set
+            continue
+
+        if matched_precision.lower() == "int8":
+            module.qconfig = qconfig
+        else:
+            # Any non-int8 precision means "keep this layer in float"
+            module.qconfig = None
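
A small sketch of the hybrid scheme this enables (assuming base_model is a DistilBERT sequence-classification model and the "fbgemm" backend, mirroring the defaults used in load_model below):

# Illustrative only: INT8 attention and FFN blocks, classifier head kept in FP32.
hybrid_cfg = LayerwiseQuantConfig(
    layer_precision={
        "attention": "int8",
        "ffn": "int8",
        "classifier": "fp32",  # keep the small head in float for accuracy
    }
)

wrapped = QuantDistilBertWrapper(base_model)  # base_model: assumed DistilBERT classifier
qconfig = torch.ao.quantization.get_default_qconfig("fbgemm")
apply_layerwise_qconfig(wrapped, qconfig, hybrid_cfg)

# Modules whose names contain "attention" or "ffn" now carry the qconfig;
# anything matching "classifier" has its qconfig cleared and stays in float.
prepared = torch.ao.quantization.prepare(wrapped, inplace=False)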


def load_model(
    model_name: str = "distilbert-base-uncased-finetuned-sst-2-english",
    precision: str = "fp32",
    device: str = "cuda",
    num_labels: int = 2,
-    verbose: bool = True
+    verbose: bool = True,
+    layerwise_cfg: Optional[LayerwiseQuantConfig] = None,
) -> nn.Module:
    """
    Load a transformer model in specified precision.
@@ -132,44 +191,102 @@ def load_model(
    config = AutoConfig.from_pretrained(model_name)
    config.num_labels = num_labels

-    # Load model in FP32 first
-    model = AutoModelForSequenceClassification.from_pretrained(
-        model_name,
-        config=config,
-        torch_dtype=torch.float32
-    )
-
-    # Apply precision conversion
-    if precision == "fp32":
-        model = model.to(device)
-
-    elif precision == "fp16":
-        # Convert to FP16
-        model = model.half()
-        model = model.to(device)
-
-    elif precision == "int8":
-        # For CUDA: Use manual INT8 conversion with fake quantization
-        # This simulates INT8 by quantizing weights to int8 range but keeps computation in FP32/FP16
-        if device == "cuda":
-            if verbose:
-                print(" Using CUDA-compatible INT8 (simulated quantization)")
-
-            # Move model to CUDA first
-            model = model.to(device)
-
-            # Apply simulated INT8 quantization to Linear layers
-            _apply_int8_quantization_cuda(model, verbose=verbose)
-        else:
-            # For CPU: Use PyTorch's dynamic quantization
-            if verbose:
-                print(" Using CPU dynamic quantization")
-            model = torch.quantization.quantize_dynamic(
-                model,
-                {nn.Linear},
-                dtype=torch.qint8
-            )
-            model = model.to("cpu")
+    # Special handling for true INT8 on CUDA GPUs (e.g., T4) using bitsandbytes
+    if precision == "int8" and device == "cuda":
+        if not _HAS_BITSANDBYTES:
+            raise RuntimeError(
+                "INT8 on CUDA requested, but bitsandbytes/transformers integration is not "
+                "available. Install with `pip install bitsandbytes` and a recent "
+                "`transformers` version."
+            )
+
+        if verbose:
+            print(" Using bitsandbytes INT8 on CUDA (Tensor Core friendly, e.g., T4)")
+
+        # Configure 8-bit loading; keeps weights in 8-bit and uses INT8 kernels
+        bnb_config = BitsAndBytesConfig(
+            load_in_8bit=True,
+            llm_int8_threshold=6.0,
+            llm_int8_has_fp16_weight=False,
+        )
+
+        # Device map: put everything on GPU 0; the benchmark harness already
+        # assumes a single-GPU setup.
+        model = AutoModelForSequenceClassification.from_pretrained(
+            model_name,
+            config=config,
+            quantization_config=bnb_config,
+            device_map={"": 0},
+        )
+
+    else:
+        # Load model in FP32 first, then optionally cast/quantize
+        model = AutoModelForSequenceClassification.from_pretrained(
+            model_name,
+            config=config,
+            torch_dtype=torch.float32,
+        )
+
+        # Apply precision conversion
+        if precision == "fp32":
+            model = model.to(device)
+
+        elif precision == "fp16":
+            # Convert to FP16
+            model = model.half()
+            model = model.to(device)
+
+        elif precision == "int8":
+            # Two paths for INT8:
+            #   - CPU: use static PTQ with QuantDistilBertWrapper, qconfig,
+            #     prepare + calibration + convert (like Lab 3).
+            #   - (CUDA handled above via bitsandbytes)
+            if device != "cuda":
+                # Static PTQ on CPU using wrapper + qconfig + prepare/convert
+                if verbose:
+                    print(" Using CPU static PTQ with QuantDistilBertWrapper")
+
+                # Wrap the HF model with Quant/DeQuant stubs
+                wrapped = QuantDistilBertWrapper(model)
+
+                # Default to a global qconfig if none is provided
+                qconfig = torch.ao.quantization.get_default_qconfig("fbgemm")
+
+                if layerwise_cfg is not None:
+                    # Attach qconfig selectively to attention / FFN / classifier
+                    apply_layerwise_qconfig(wrapped, qconfig, layerwise_cfg)
+                else:
+                    # Global qconfig (all eligible modules)
+                    wrapped.qconfig = qconfig
+
+                if verbose:
+                    print(" Preparing model for static quantization...")
+
+                prepared = torch.ao.quantization.prepare(wrapped, inplace=False)
+                prepared.eval()
+
+                # Lightweight calibration with random token IDs.
+                # This is intentionally simple; the main project harness
+                # still uses the high-quality FP32/FP16 paths on GPU.
+                vocab_size = getattr(config, "vocab_size", 30522)
+                seq_len = getattr(config, "max_position_embeddings", 128)
+
+                with torch.no_grad():
+                    for _ in range(10):
+                        dummy_ids = torch.randint(
+                            low=0,
+                            high=vocab_size,
+                            size=(8, seq_len),
+                            dtype=torch.long,
+                        )
+                        dummy_mask = torch.ones_like(dummy_ids)
+                        _ = prepared(input_ids=dummy_ids, attention_mask=dummy_mask)
+
+                if verbose:
+                    print(" Converting calibrated model to INT8...")
+
+                quantized = torch.ao.quantization.convert(prepared, inplace=False)
+                model = quantized.to("cpu")

    model.eval()
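
Putting it together, the updated load_model surface can be exercised roughly as follows (a sketch only; it assumes bitsandbytes is installed for the CUDA path and reuses the hybrid config from the docstring example above):

# Illustrative only: the three precision paths exposed by the new signature.
model_fp16 = load_model(precision="fp16", device="cuda")      # FP16 GPU baseline

model_int8_gpu = load_model(precision="int8", device="cuda")  # bitsandbytes INT8 kernels

cfg = LayerwiseQuantConfig(
    layer_precision={"attention": "int8", "ffn": "int8", "classifier": "fp32"}
)
model_int8_cpu = load_model(precision="int8", device="cpu", layerwise_cfg=cfg)  # static PTQ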

2,702 changes: 2,701 additions & 1 deletion notebooks/energy_measurement_harness.ipynb

Large diffs are not rendered by default.