[GenAI] Use BitsAndBytes for 4bit quantization. #7406

Open · wants to merge 8 commits into main
@@ -34,7 +34,7 @@ public static async Task RunLlama(string weightFolder, string checkPointName = "
var stopWatch = System.Diagnostics.Stopwatch.StartNew();
stopWatch.Start();
var tokenizer = LlamaTokenizerHelper.FromPretrained(originalWeightFolder);
var model = LlamaForCausalLM.FromPretrained(weightFolder, configName, checkPointName: checkPointName, layersOnTargetDevice: 26, quantizeToInt8: true);
var model = LlamaForCausalLM.FromPretrained(weightFolder, configName, checkPointName: checkPointName, layersOnTargetDevice: -1, quantizeTo4Bit: true);

var pipeline = new CausalLMPipeline<TiktokenTokenizer, LlamaForCausalLM>(tokenizer, model, device);

@@ -16,10 +16,11 @@
</ItemGroup>

<ItemGroup>
<PackageReference Include="TorchSharp-cuda-windows" Version="0.102.5" Condition="$([MSBuild]::IsOSPlatform('Windows'))" />
<PackageReference Include="TorchSharp-cuda-windows" Version="0.105.0" Condition="$([MSBuild]::IsOSPlatform('Windows'))" />
<PackageReference Include="Microsoft.SemanticKernel" Version="$(SemanticKernelVersion)" />
<PackageReference Include="AutoGen.SourceGenerator" Version="$(AutoGenVersion)" />
<PackageReference Include="Microsoft.Extensions.Logging.Console" Version="8.0.0" />
<PackageReference Include="LittleLittleCloud.TorchSharp.BitsAndBytes" Version="$(TorchSharpBitsAndBytesVersion)" />
</ItemGroup>

</Project>
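(For reference, the same dependency can be added with "dotnet add package LittleLittleCloud.TorchSharp.BitsAndBytes"; the $(TorchSharpBitsAndBytesVersion) property it pins resolves to 0.0.4 in eng/Versions.props below.)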
4 changes: 2 additions & 2 deletions docs/samples/Microsoft.ML.GenAI.Samples/Program.cs
@@ -1,6 +1,6 @@
// See https://aka.ms/new-console-template for more information
using Microsoft.ML.GenAI.Samples.Llama;
using Microsoft.ML.GenAI.Samples.MEAI;

await SFT_Llama_3_2_1B.Train(@"C:\Users\xiaoyuz\source\repos\Llama-3.2-1B-Instruct", checkPointName: "model.safetensors");
await LlamaSample.RunLlama(@"C:\Users\xiaoyuz\source\repos\Llama-3.2-3B-Instruct");
//await SFT_Llama_3_2_1B.Train(@"C:\Users\xiaoyuz\source\repos\Llama-3.2-1B-Instruct", checkPointName: "model.safetensors");
//await Phi3.RunAsync(@"C:\Users\xiaoyuz\source\repos\Phi-3-mini-4k-instruct");
5 changes: 3 additions & 2 deletions eng/Versions.props
@@ -73,8 +73,9 @@
<TorchSharpPyBridgeVersion>1.4.1</TorchSharpPyBridgeVersion>
<AutoGenVersion>0.1.0</AutoGenVersion>
<SemanticKernelVersion>1.15.0</SemanticKernelVersion>
<TorchSharpVersion>0.102.7</TorchSharpVersion>
<LibTorchVersion>2.2.1.1</LibTorchVersion>
<TorchSharpVersion>0.105.0</TorchSharpVersion>
<LibTorchVersion>2.5.1</LibTorchVersion>
<TorchSharpBitsAndBytesVersion>0.0.4</TorchSharpBitsAndBytesVersion>
<!-- Build/infrastructure Dependencies -->
<CodecovVersion>1.12.4</CodecovVersion>
<CoverletCollectorVersion>6.0.2</CoverletCollectorVersion>
15 changes: 10 additions & 5 deletions src/Microsoft.ML.GenAI.Core/Extension/ModuleExtension.cs
@@ -90,13 +90,18 @@ public static void ToInt8QuantizeModule<T>(
/// </summary>
/// <typeparam name="T"></typeparam>
/// <param name="model"></param>
public static void ToInt4QuantizeModule<T>(
this T model)
/// <param name="quantizedDType">Quantized data type, can be "fp4" or "nf4". "fp4" means 4-bits floating point (1-bit sign, 2-bit exponent and 1-bit mantissa) and "nf4" means normalized 4-bits floating point, which uses a specialized non-uniform quantization aimed at neural network weight distributions (ranged from -1 to 1).</param>
/// <param name="blockSize">Block size for quantization, can be [64, 128, 256, 512, 1024]. The larger the size, the faster the speed and the lower the precision.</param>
public static void ToQuantize4BitModule<T>(
this T model,
string quantizedDType = "fp4",
int blockSize = 64)
where T : nn.Module
{
var config = new Quantize4BitConfig(quantizedDType, blockSize);
if (model is IQuantizeModule quantized)
{
quantized.Int4();
quantized.Quantize4Bit(config);

return;
}
@@ -105,11 +110,11 @@ public static void ToInt4QuantizeModule<T>(
{
if (value is IQuantizeModule quantizeModule)
{
quantizeModule.Int4();
quantizeModule.Quantize4Bit(config);
}
else
{
value.ToInt4QuantizeModule();
value.ToQuantize4BitModule(quantizedDType, blockSize);
}
}
}
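As a usage sketch of the new extension method (the weight folder is a placeholder, and "nf4" with block size 128 is just one supported combination):

// Illustrative only: quantize every IQuantizeModule inside the model to 4 bits.
// "nf4" targets normally distributed weights; larger block sizes trade
// precision for speed, per the parameter docs above.
var model = LlamaForCausalLM.FromPretrained(@"path\to\Llama-3.2-3B-Instruct");
model.ToQuantize4BitModule(quantizedDType: "nf4", blockSize: 128);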
2 changes: 2 additions & 0 deletions src/Microsoft.ML.GenAI.Core/Microsoft.ML.GenAI.Core.csproj
@@ -5,6 +5,7 @@
<IsPackable>true</IsPackable>
<Nullable>enable</Nullable>
<LangVersion>preview</LangVersion>
<NoWarn>$(NoWarn);CS8002</NoWarn>
</PropertyGroup>

<PropertyGroup Condition="'$(TargetFramework)' == 'net8.0'">
@@ -17,6 +18,7 @@
<PackageReference Include="Microsoft.SemanticKernel.Abstractions" Version="$(SemanticKernelVersion)" />
<PackageReference Include="System.Memory" Version="$(SystemMemoryVersion)" />
<PackageReference Include="TorchSharp" Version="$(TorchSharpVersion)" />
<PackageReference Include="LittleLittleCloud.TorchSharp.BitsAndBytes" Version="$(TorchSharpBitsAndBytesVersion)" />
</ItemGroup>

<ItemGroup>
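(CS8002, "referenced assembly does not have a strong name", is suppressed here; presumably the LittleLittleCloud.TorchSharp.BitsAndBytes assembly is not strong-name signed, while Microsoft.ML.GenAI.Core is.)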
30 changes: 28 additions & 2 deletions src/Microsoft.ML.GenAI.Core/Module/IQuantizeModule.cs
@@ -1,4 +1,4 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

@@ -8,5 +8,31 @@ public interface IQuantizeModule
{
public void Int8();

public void Int4();
/// <summary>
/// Quantize to 4 bits using BitsAndBytes ("fp4" or "nf4").
/// </summary>
/// <param name="config"><see cref="Quantize4BitConfig"/></param>
public void Quantize4Bit(Quantize4BitConfig config);
}

/// <summary>
/// Quantize configuration for 4-bit quantization.
/// </summary>
public record Quantize4BitConfig
{
public Quantize4BitConfig(string quantizedDType = "fp4", int blockSize = 64)
{
QuantizedDType = quantizedDType;
BlockSize = blockSize;
}

/// <summary>
/// Quantized data type, can be "fp4" or "nf4".
/// </summary>
public string QuantizedDType { get; init; }

/// <summary>
/// Block size for quantization; one of 64, 128, 256, 512, or 1024. Larger blocks quantize faster but with lower precision.
/// </summary>
public int BlockSize { get; init; }
}
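A short usage sketch (someQuantizeModule is a placeholder for any IQuantizeModule implementation):

// Defaults: "fp4" with block size 64.
var config = new Quantize4BitConfig();

// The init-only properties support non-destructive mutation via `with`.
var nf4Config = config with { QuantizedDType = "nf4", BlockSize = 256 };

someQuantizeModule.Quantize4Bit(nf4Config);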
152 changes: 66 additions & 86 deletions src/Microsoft.ML.GenAI.Core/Module/QuantizedLinear.cs
@@ -4,14 +4,24 @@
using System;
using Microsoft.ML.GenAI.Core;
using TorchSharp;
using TorchSharp.BitsAndBytes;
using TorchSharp.Modules;
using static TorchSharp.torch;
namespace Microsoft.ML.GenAI.Core;

internal class QuantizedLinear : GenAILinear, IQuantizeModule
{
private Tensor? _quantizedTensor = null;
private Tensor? _absMax = null;
private int _blockSize;
private int _n;
private string? _quantizedDType = null;
private readonly long[] _weightShape;

public QuantizedLinear(int inFeatures, int outFeatures, bool hasBias = true, ScalarType dtype = ScalarType.Float32, string? device = null)
: base(inFeatures, outFeatures, hasBias, dtype, device)
{
_weightShape = [outFeatures, inFeatures];
}

public void Int8()
@@ -79,6 +89,7 @@ public void Int8()
public override Tensor forward(Tensor input)
#pragma warning restore MSML_GeneralName // This name should be PascalCased
{
var inputShape = input.shape;
if (this._internal_buffers.ContainsKey("weight"))
{
return base.forward(input);
@@ -87,9 +98,9 @@ public override Tensor forward(Tensor input)
{
// 8bit quantization
using var dispose = torch.NewDisposeScope();
var weight = this.get_buffer("8bit_weight").to(ScalarType.Float32);
var zeroPoint = this.get_buffer("zeroPoint").to(ScalarType.Float32);
var scale = this.get_buffer("scale").to(ScalarType.Float32);
var weight = this.get_buffer("8bit_weight")!.to(ScalarType.Float32);
var zeroPoint = this.get_buffer("zeroPoint")!.to(ScalarType.Float32);
var scale = this.get_buffer("scale")!.to(ScalarType.Float32);
var restoreWeight = (weight - zeroPoint.view(-1, 1)) / scale.view(-1, 1);
// use float32
var result = torch.matmul(input.to(ScalarType.Float32), restoreWeight.T);
@@ -102,108 +113,77 @@ public override Tensor forward(Tensor input)
//result.Peek("result");
return result.to_type(input.dtype).MoveToOuterDisposeScope();
}
else if (this._internal_buffers.ContainsKey("4bit_weight"))
else if ((_quantizedDType == "fp4" || _quantizedDType == "nf4") && _quantizedTensor is not null && _absMax is not null)
{
using var dispose = torch.NewDisposeScope();
var weight = this.get_buffer("4bit_weight");
var weightLower = weight % 16;
var weightUpper = weight / 16;
weight = torch.cat([weightUpper, weightLower], 0).to(ScalarType.Float32);
weight = weight.view(this._outFeatures, this._inFeatures);
weight -= 8;
var zeroPoint = this.get_buffer("zeroPoint");
var zeroPointLower = zeroPoint % 16;
var zeroPointUpper = zeroPoint / 16;
zeroPoint = torch.cat([zeroPointUpper, zeroPointLower], 0).to(ScalarType.Float32);
zeroPoint -= 8;
var scale = this.get_buffer("scale").to(ScalarType.Float32);
var restoreWeight = (weight - zeroPoint.view(-1, 1)) / scale.view(-1, 1);
// use float32
var result = torch.matmul(input.to(ScalarType.Float32), restoreWeight.T);

if (this.bias is not null)
if (input.shape.Length >= 3 && input.shape[1] != 1)
{
result = result + this.bias.to_type(ScalarType.Float32);
// dequantize quantizedWeight to float32 and use torch.matmul
var dequantizedWeight = BitsAndByteUtils.Dequantize4Bit(
tensor: this._quantizedTensor,
originalDType: input.dtype,
originalShape: this._weightShape,
blockSize: _blockSize,
n: this._n,
absMax: this._absMax!,
quantizedDType: _quantizedDType);

var output = torch.matmul(input, dequantizedWeight.T);

if (this.bias is not null)
{
output = output.add_(this.bias.to_type(output.dtype));
}

return output.MoveToOuterDisposeScope();
}
else
{
var output = BitsAndByteUtils.Gemv4Bit(
input: input,
quantizedWeight: this._quantizedTensor,
originalWeightShape: _weightShape,
absMax: this._absMax!,
quantizedDType: _quantizedDType,
blockSize: _blockSize);

if (this.bias is not null)
{
output = output.add_(this.bias.to_type(output.dtype));
}

return output.MoveToOuterDisposeScope();
}

//result.Peek("result");
return result.to_type(input.dtype).MoveToOuterDisposeScope();
}
else
{
throw new Exception("Quantization is not done yet");
}
}

public void Int4()
public void Quantize4Bit(Quantize4BitConfig config)
{
if (this.weight is null)
{
throw new Exception("Weight is not initialized");
}
var placeHolderDim = this._outFeatures / 2 + this._outFeatures % 2;
var fourBitWeightDim = this.weight.size(0) * this.weight.size(1);
var fourBitWeightPlaceHolderDim = Convert.ToInt32(fourBitWeightDim / 2 + fourBitWeightDim % 2);
if (this.weight.device_type != DeviceType.META)
{
using var scope = NewDisposeScope();
var timer = new System.Diagnostics.Stopwatch();
timer.Start();
// scale and zero point on vector-wise
// scale = 15 / max(weight, axis=1) - min(weight, axis=1)
var scale = 15 / (torch.max(this.weight, 1).values - torch.min(this.weight, 1).values);

// zero point = - scale * min(weight, axis=1) - 8
var zeroPoint = -scale * torch.min(this.weight, 1).values - 8;
// round zero point to nearest integer
zeroPoint = torch.round(zeroPoint);
var fourBitWeight = torch.round(this.weight * scale.view(-1, 1) + zeroPoint.view(-1, 1)).to(torch.int8);

zeroPoint = (zeroPoint + 8).to(torch.uint8);
fourBitWeight = (fourBitWeight + 8).view(-1).to(torch.uint8);

// torch doesn't provide int4, so we use int8 as placeholder
// and foreach int8, we save two int4, e.g. 0b1010 -> 0b10, 0b10
var zpPlaceHolder = zeroPoint[..placeHolderDim];
zpPlaceHolder = zpPlaceHolder * 16 + zeroPoint[placeHolderDim..];

// assert zero point is in range [-128, 127]
//if (torch.any(this.zeroPoint < -128).item<bool>() || torch.any(this.zeroPoint > 127).item<bool>())
//{
// throw new Exception("Zero point is out of range [-128, 127]");
//}

// quantize weight
var fourBitWeightPlaceHolder = fourBitWeight[..fourBitWeightPlaceHolderDim];
fourBitWeightPlaceHolder = fourBitWeightPlaceHolder * 16 + fourBitWeight[fourBitWeightPlaceHolderDim..];
if (this.weight.device_type == DeviceType.META)
{
return;
}
using var dispose = torch.NewDisposeScope();

// assert weight is in range [-128, 127]
//if (torch.any(this._8bitWeight < -128).item<bool>() || torch.any(this._8bitWeight > 127).item<bool>())
//{
// throw new Exception("Weight is out of range [-128, 127]");
//}
_quantizedDType = config.QuantizedDType; // Available options: "fp4", "nf4"
_blockSize = config.BlockSize; // can be [64, 128, 256, 512, 1024]

// dispose float32 weight
this.weight.Dispose();
// Quantize to 4Bit
(_quantizedTensor, _absMax, _blockSize, _n) = BitsAndByteUtils.Quantize4Bit(this.weight.cuda(), _quantizedDType, _blockSize);

this._internal_buffers.Remove("weight");
this.register_buffer("4bit_weight", fourBitWeightPlaceHolder.MoveToOuterDisposeScope());
this.register_buffer("zeroPoint", zpPlaceHolder.MoveToOuterDisposeScope());
this.register_buffer("scale", scale.MoveToOuterDisposeScope());
timer.Stop();
}
else
{
// if weight is on meta device, then we just need to create the placeholder for 8bit_weight, zeroPoint and scale
var fourBitWeight = torch.zeros(fourBitWeightPlaceHolderDim, dtype: torch.int8);
var zeroPoint = torch.zeros(placeHolderDim, dtype: torch.int8);
var scale = torch.zeros(this.weight.shape[0], dtype: torch.float32);

this._internal_buffers.Remove("weight");
this.weight = null;
this.register_buffer("4bit_weight", fourBitWeight);
this.register_buffer("zeroPoint", zeroPoint);
this.register_buffer("scale", scale);
}
this.weight.Dispose();
this.weight = null;
this._internal_buffers.Remove("weight");
_quantizedTensor.MoveToOuterDisposeScope();
_absMax.MoveToOuterDisposeScope();
}
}
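To make the BitsAndBytes calls above concrete: 4-bit block quantization stores one absolute-maximum scale per block of blockSize weights, plus a 4-bit code per weight. The sketch below is illustrative only; it uses a uniform signed grid rather than the real fp4/nf4 codebooks, and it keeps one code per byte where the real kernels pack two (assumes using System):

// Illustrative block-wise absmax quantization over a flat weight array.
static (byte[] Codes, float[] AbsMax) QuantizeBlockwise(float[] w, int blockSize)
{
    var codes = new byte[w.Length];
    var absMax = new float[(w.Length + blockSize - 1) / blockSize];
    for (int b = 0; b * blockSize < w.Length; b++)
    {
        int start = b * blockSize;
        int end = Math.Min(start + blockSize, w.Length);
        float m = 1e-8f;
        for (int i = start; i < end; i++) m = Math.Max(m, Math.Abs(w[i]));
        absMax[b] = m;
        for (int i = start; i < end; i++)
        {
            // Map [-m, m] onto the 16 signed levels {-8, ..., 7}, stored as 0..15.
            int q = (int)Math.Round(w[i] / m * 7f);
            codes[i] = (byte)(Math.Clamp(q, -8, 7) + 8);
        }
    }
    return (codes, absMax);
}

// Dequantization inverts the mapping: w[i] ≈ (codes[i] - 8) / 7f * absMax[i / blockSize].

This also explains the two branches in forward above: for a single decoded token (sequence length 1) the quantized codes can be consumed directly by a fused gemv kernel (Gemv4Bit), while longer inputs are served by dequantizing once and falling back to a regular matmul.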
2 changes: 1 addition & 1 deletion src/Microsoft.ML.GenAI.Core/Module/RotaryEmbedding.cs
@@ -109,7 +109,7 @@ public override RotaryEmbeddingOutput forward(RotaryEmbeddingInput input)
var seqLen = input.SeqLen;
// TODO
// can be calculated once and cached
var invFreq = this.get_buffer("inv_freq").to(x.device);
var invFreq = this.get_buffer("inv_freq")!.to(x.device);
var invFreqExpanded = invFreq.unsqueeze(0).unsqueeze(-1);
invFreqExpanded = invFreqExpanded.expand(new long[] { positionIds.shape[0], -1, 1 });
var positionIdsExpanded = positionIds.unsqueeze(1).to(torch.float32);
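(The null-forgiving operators added here and in QuantizedLinear.forward above track the TorchSharp 0.105 update, where get_buffer is presumably annotated as returning a nullable Tensor.)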
12 changes: 6 additions & 6 deletions src/Microsoft.ML.GenAI.LLaMA/LlamaForCausalLM.cs
@@ -120,12 +120,12 @@ public static LlamaForCausalLM FromPretrained(
string configName = "config.json",
string checkPointName = "model.safetensors.index.json",
bool quantizeToInt8 = false,
bool quantizeToInt4 = false,
bool quantizeTo4Bit = false,
int layersOnTargetDevice = -1,
ScalarType torchDtype = ScalarType.BFloat16,
string targetDevice = "cuda")
{
if (layersOnTargetDevice == -1 && quantizeToInt4 == false && quantizeToInt8 == false)
if (layersOnTargetDevice == -1 && quantizeTo4Bit == false && quantizeToInt8 == false)
{
return FromPretrained(modelFolder, configName, checkPointName, torchDtype, targetDevice);
}
@@ -141,9 +141,9 @@ public static LlamaForCausalLM FromPretrained(
{
model.ToInt8QuantizeModule();
}
else if (quantizeToInt4)
else if (quantizeTo4Bit)
{
model.ToInt4QuantizeModule();
model.ToQuantize4BitModule();
}

var deviceMap = model.InferDeviceMapForEachLayer(
@@ -161,9 +161,9 @@
{
model.ToInt8QuantizeModule();
}
else if (quantizeToInt4)
else if (quantizeTo4Bit)
{
model.ToInt4QuantizeModule();
model.ToQuantize4BitModule();
}

model = model.ToDynamicLoadingModel(deviceMap, targetDevice);
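Putting the pieces together, a sketch of the new loading path (the weight folder is a placeholder; judging by the sample change above, layersOnTargetDevice: -1 combined with quantizeTo4Bit keeps all layers on the target device):

// Load Llama weights and quantize the linear layers to 4 bits via BitsAndBytes.
var model = LlamaForCausalLM.FromPretrained(
    @"path\to\Llama-3.2-3B-Instruct",
    checkPointName: "model.safetensors.index.json",
    quantizeTo4Bit: true,
    layersOnTargetDevice: -1);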