From a67fc3a5b8c0cdde4f2c32e9f69ace9073889567 Mon Sep 17 00:00:00 2001 From: XiaoYun Zhang Date: Sat, 1 Mar 2025 16:37:07 -0800 Subject: [PATCH 1/7] use bitsandbytes for 4bit quantize --- .../Llama/LlamaSample.cs | 2 +- .../Microsoft.ML.GenAI.Samples.csproj | 3 +- .../Microsoft.ML.GenAI.Samples/Program.cs | 4 +- eng/Versions.props | 4 +- .../Extension/ModuleExtension.cs | 4 +- .../Microsoft.ML.GenAI.Core.csproj | 2 + .../Module/IQuantizeModule.cs | 5 +- .../Module/QuantizedLinear.cs | 152 ++++++++---------- .../Module/RotaryEmbedding.cs | 2 +- .../Module/Phi2RotaryEmbedding.cs | 2 +- .../Phi2/Phi2ForCasualLM.cs | 1 + .../QuantizedLinearTests.cs | 2 +- 12 files changed, 85 insertions(+), 98 deletions(-) diff --git a/docs/samples/Microsoft.ML.GenAI.Samples/Llama/LlamaSample.cs b/docs/samples/Microsoft.ML.GenAI.Samples/Llama/LlamaSample.cs index 97248ed272..bdc81fc768 100644 --- a/docs/samples/Microsoft.ML.GenAI.Samples/Llama/LlamaSample.cs +++ b/docs/samples/Microsoft.ML.GenAI.Samples/Llama/LlamaSample.cs @@ -34,7 +34,7 @@ public static async Task RunLlama(string weightFolder, string checkPointName = " var stopWatch = System.Diagnostics.Stopwatch.StartNew(); stopWatch.Start(); var tokenizer = LlamaTokenizerHelper.FromPretrained(originalWeightFolder); - var model = LlamaForCausalLM.FromPretrained(weightFolder, configName, checkPointName: checkPointName, layersOnTargetDevice: 26, quantizeToInt8: true); + var model = LlamaForCausalLM.FromPretrained(weightFolder, configName, checkPointName: checkPointName, layersOnTargetDevice: -1, quantizeToInt4: true); var pipeline = new CausalLMPipeline(tokenizer, model, device); diff --git a/docs/samples/Microsoft.ML.GenAI.Samples/Microsoft.ML.GenAI.Samples.csproj b/docs/samples/Microsoft.ML.GenAI.Samples/Microsoft.ML.GenAI.Samples.csproj index c8cee633ac..6996640f19 100644 --- a/docs/samples/Microsoft.ML.GenAI.Samples/Microsoft.ML.GenAI.Samples.csproj +++ b/docs/samples/Microsoft.ML.GenAI.Samples/Microsoft.ML.GenAI.Samples.csproj @@ -16,10 +16,11 @@ - + + diff --git a/docs/samples/Microsoft.ML.GenAI.Samples/Program.cs b/docs/samples/Microsoft.ML.GenAI.Samples/Program.cs index 6f4d809948..b435100d3e 100644 --- a/docs/samples/Microsoft.ML.GenAI.Samples/Program.cs +++ b/docs/samples/Microsoft.ML.GenAI.Samples/Program.cs @@ -1,6 +1,6 @@ // See https://aka.ms/new-console-template for more information using Microsoft.ML.GenAI.Samples.Llama; using Microsoft.ML.GenAI.Samples.MEAI; - -await SFT_Llama_3_2_1B.Train(@"C:\Users\xiaoyuz\source\repos\Llama-3.2-1B-Instruct", checkPointName: "model.safetensors"); +await LlamaSample.RunLlama(@"C:\Users\xiaoyuz\source\repos\Llama-3.2-3B-Instruct"); +//await SFT_Llama_3_2_1B.Train(@"C:\Users\xiaoyuz\source\repos\Llama-3.2-1B-Instruct", checkPointName: "model.safetensors"); //await Phi3.RunAsync(@"C:\Users\xiaoyuz\source\repos\Phi-3-mini-4k-instruct"); diff --git a/eng/Versions.props b/eng/Versions.props index 522e5f3489..4a738f610f 100644 --- a/eng/Versions.props +++ b/eng/Versions.props @@ -73,8 +73,8 @@ 1.4.1 0.1.0 1.15.0 - 0.102.7 - 2.2.1.1 + 0.105.0 + 2.5.1 1.12.4 6.0.2 diff --git a/src/Microsoft.ML.GenAI.Core/Extension/ModuleExtension.cs b/src/Microsoft.ML.GenAI.Core/Extension/ModuleExtension.cs index a904c394b9..0598a8ce07 100644 --- a/src/Microsoft.ML.GenAI.Core/Extension/ModuleExtension.cs +++ b/src/Microsoft.ML.GenAI.Core/Extension/ModuleExtension.cs @@ -96,7 +96,7 @@ public static void ToInt4QuantizeModule( { if (model is IQuantizeModule quantized) { - quantized.Int4(); + quantized.FP4(); return; } @@ -105,7 +105,7 
@@ public static void ToInt4QuantizeModule( { if (value is IQuantizeModule quantizeModule) { - quantizeModule.Int4(); + quantizeModule.FP4(); } else { diff --git a/src/Microsoft.ML.GenAI.Core/Microsoft.ML.GenAI.Core.csproj b/src/Microsoft.ML.GenAI.Core/Microsoft.ML.GenAI.Core.csproj index 9ed5a2702e..233507bcb2 100644 --- a/src/Microsoft.ML.GenAI.Core/Microsoft.ML.GenAI.Core.csproj +++ b/src/Microsoft.ML.GenAI.Core/Microsoft.ML.GenAI.Core.csproj @@ -5,6 +5,7 @@ true enable preview + $(NoWarn);CS8002 @@ -17,6 +18,7 @@ + diff --git a/src/Microsoft.ML.GenAI.Core/Module/IQuantizeModule.cs b/src/Microsoft.ML.GenAI.Core/Module/IQuantizeModule.cs index 57c0b7620f..b2a8936779 100644 --- a/src/Microsoft.ML.GenAI.Core/Module/IQuantizeModule.cs +++ b/src/Microsoft.ML.GenAI.Core/Module/IQuantizeModule.cs @@ -8,5 +8,8 @@ public interface IQuantizeModule { public void Int8(); - public void Int4(); + /// + /// Quantize using BitsAndBytes.FP4 + /// + public void FP4(); } diff --git a/src/Microsoft.ML.GenAI.Core/Module/QuantizedLinear.cs b/src/Microsoft.ML.GenAI.Core/Module/QuantizedLinear.cs index f399efe324..67b24f339f 100644 --- a/src/Microsoft.ML.GenAI.Core/Module/QuantizedLinear.cs +++ b/src/Microsoft.ML.GenAI.Core/Module/QuantizedLinear.cs @@ -4,14 +4,24 @@ using System; using Microsoft.ML.GenAI.Core; using TorchSharp; +using TorchSharp.BitsAndBytes; +using TorchSharp.Modules; using static TorchSharp.torch; namespace Microsoft.ML.GenAI.Core; internal class QuantizedLinear : GenAILinear, IQuantizeModule { + private Tensor? _quantizedTensor = null; + private Tensor? _absMax = null; + private int _blockSize; + private int _n; + private string? _quantizedDType = null; + private readonly long[] _weightShape; + public QuantizedLinear(int inFeatures, int outFeatures, bool hasBias = true, ScalarType dtype = ScalarType.Float32, string? 
device = null) : base(inFeatures, outFeatures, hasBias, dtype, device) { + _weightShape = [outFeatures, inFeatures]; } public void Int8() @@ -79,6 +89,7 @@ public void Int8() public override Tensor forward(Tensor input) #pragma warning restore MSML_GeneralName // This name should be PascalCased { + var inputShape = input.shape; if (this._internal_buffers.ContainsKey("weight")) { return base.forward(input); @@ -87,9 +98,9 @@ public override Tensor forward(Tensor input) { // 8bit quantization using var dispose = torch.NewDisposeScope(); - var weight = this.get_buffer("8bit_weight").to(ScalarType.Float32); - var zeroPoint = this.get_buffer("zeroPoint").to(ScalarType.Float32); - var scale = this.get_buffer("scale").to(ScalarType.Float32); + var weight = this.get_buffer("8bit_weight")!.to(ScalarType.Float32); + var zeroPoint = this.get_buffer("zeroPoint")!.to(ScalarType.Float32); + var scale = this.get_buffer("scale")!.to(ScalarType.Float32); var restoreWeight = (weight - zeroPoint.view(-1, 1)) / scale.view(-1, 1); // use float32 var result = torch.matmul(input.to(ScalarType.Float32), restoreWeight.T); @@ -102,32 +113,47 @@ public override Tensor forward(Tensor input) //result.Peek("result"); return result.to_type(input.dtype).MoveToOuterDisposeScope(); } - else if (this._internal_buffers.ContainsKey("4bit_weight")) + else if ((_quantizedDType == "fp4" || _quantizedDType == "nf4") && _quantizedTensor is not null && _absMax is not null) { using var dispose = torch.NewDisposeScope(); - var weight = this.get_buffer("4bit_weight"); - var weightLower = weight % 16; - var weightUpper = weight / 16; - weight = torch.cat([weightUpper, weightLower], 0).to(ScalarType.Float32); - weight = weight.view(this._outFeatures, this._inFeatures); - weight -= 8; - var zeroPoint = this.get_buffer("zeroPoint"); - var zeroPointLower = zeroPoint % 16; - var zeroPointUpper = zeroPoint / 16; - zeroPoint = torch.cat([zeroPointUpper, zeroPointLower], 0).to(ScalarType.Float32); - zeroPoint -= 8; - var scale = this.get_buffer("scale").to(ScalarType.Float32); - var restoreWeight = (weight - zeroPoint.view(-1, 1)) / scale.view(-1, 1); - // use float32 - var result = torch.matmul(input.to(ScalarType.Float32), restoreWeight.T); - - if (this.bias is not null) + if (input.shape.Length >= 3 && input.shape[1] != 1) { - result = result + this.bias.to_type(ScalarType.Float32); + // dequantize quantizedWeight to float32 and use torch.matmul + var dequantizedWeight = BitsAndByteUtils.Dequantize4Bit( + tensor: this._quantizedTensor, + originalDType: input.dtype, + originalShape: this._weightShape, + blockSize: _blockSize, + n: this._n, + absMax: this._absMax!, + quantizedDType: _quantizedDType); + + var output = torch.matmul(input, dequantizedWeight.T); + + if (this.bias is not null) + { + output = output.add_(this.bias.to_type(output.dtype)); + } + + return output.MoveToOuterDisposeScope(); + } + else + { + var output = BitsAndByteUtils.Gemv4Bit( + input: input, + quantizedWeight: this._quantizedTensor, + originalWeightShape: _weightShape, + absMax: this._absMax!, + quantizedDType: _quantizedDType, + blockSize: _blockSize); + + if (this.bias is not null) + { + output = output.add_(this.bias.to_type(output.dtype)); + } + + return output.MoveToOuterDisposeScope(); } - - //result.Peek("result"); - return result.to_type(input.dtype).MoveToOuterDisposeScope(); } else { @@ -135,75 +161,29 @@ public override Tensor forward(Tensor input) } } - public void Int4() + public void FP4() { if (this.weight is null) { throw new Exception("Weight is not 
initialized"); } - var placeHolderDim = this._outFeatures / 2 + this._outFeatures % 2; - var fourBitWeightDim = this.weight.size(0) * this.weight.size(1); - var fourBitWeightPlaceHolderDim = Convert.ToInt32(fourBitWeightDim / 2 + fourBitWeightDim % 2); - if (this.weight.device_type != DeviceType.META) - { - using var scope = NewDisposeScope(); - var timer = new System.Diagnostics.Stopwatch(); - timer.Start(); - // scale and zero point on vector-wise - // scale = 15 / max(weight, axis=1) - min(weight, axis=1) - var scale = 15 / (torch.max(this.weight, 1).values - torch.min(this.weight, 1).values); - - // zero point = - scale * min(weight, axis=1) - 8 - var zeroPoint = -scale * torch.min(this.weight, 1).values - 8; - // round zero point to nearest integer - zeroPoint = torch.round(zeroPoint); - var fourBitWeight = torch.round(this.weight * scale.view(-1, 1) + zeroPoint.view(-1, 1)).to(torch.int8); - - zeroPoint = (zeroPoint + 8).to(torch.uint8); - fourBitWeight = (fourBitWeight + 8).view(-1).to(torch.uint8); - - // torch doesn't provide int4, so we use int8 as placeholder - // and foreach int8, we save two int4, e.g. 0b1010 -> 0b10, 0b10 - var zpPlaceHolder = zeroPoint[..placeHolderDim]; - zpPlaceHolder = zpPlaceHolder * 16 + zeroPoint[placeHolderDim..]; - - // assert zero point is in range [-128, 127] - //if (torch.any(this.zeroPoint < -128).item() || torch.any(this.zeroPoint > 127).item()) - //{ - // throw new Exception("Zero point is out of range [-128, 127]"); - //} - // quantize weight - var fourBitWeightPlaceHolder = fourBitWeight[..fourBitWeightPlaceHolderDim]; - fourBitWeightPlaceHolder = fourBitWeightPlaceHolder * 16 + fourBitWeight[fourBitWeightPlaceHolderDim..]; + if (this.weight.device_type == DeviceType.META) + { + return; + } + using var dispose = torch.NewDisposeScope(); - // assert weight is in range [-128, 127] - //if (torch.any(this._8bitWeight < -128).item() || torch.any(this._8bitWeight > 127).item()) - //{ - // throw new Exception("Weight is out of range [-128, 127]"); - //} + _quantizedDType = "fp4"; // Available options: "fp4", "nf4" + _blockSize = 64; // can be [64, 128, 256, 512, 1024] - // dispose float32 weight - this.weight.Dispose(); + // Quantize to 4Bit + (_quantizedTensor, _absMax, _blockSize, _n) = BitsAndByteUtils.Quantize4Bit(this.weight.cuda(), _quantizedDType, _blockSize); - this._internal_buffers.Remove("weight"); - this.register_buffer("4bit_weight", fourBitWeightPlaceHolder.MoveToOuterDisposeScope()); - this.register_buffer("zeroPoint", zpPlaceHolder.MoveToOuterDisposeScope()); - this.register_buffer("scale", scale.MoveToOuterDisposeScope()); - timer.Stop(); - } - else - { - // if weight is on meta device, then we just need to create the placeholder for 8bit_weight, zeroPoint and scale - var fourBitWeight = torch.zeros(fourBitWeightPlaceHolderDim, dtype: torch.int8); - var zeroPoint = torch.zeros(placeHolderDim, dtype: torch.int8); - var scale = torch.zeros(this.weight.shape[0], dtype: torch.float32); - - this._internal_buffers.Remove("weight"); - this.weight = null; - this.register_buffer("4bit_weight", fourBitWeight); - this.register_buffer("zeroPoint", zeroPoint); - this.register_buffer("scale", scale); - } + this.weight.Dispose(); + this.weight = null; + this._internal_buffers.Remove("weight"); + _quantizedTensor.MoveToOuterDisposeScope(); + _absMax.MoveToOuterDisposeScope(); } } diff --git a/src/Microsoft.ML.GenAI.Core/Module/RotaryEmbedding.cs b/src/Microsoft.ML.GenAI.Core/Module/RotaryEmbedding.cs index 8e06c838d5..fa86164a6a 100644 --- 
a/src/Microsoft.ML.GenAI.Core/Module/RotaryEmbedding.cs +++ b/src/Microsoft.ML.GenAI.Core/Module/RotaryEmbedding.cs @@ -109,7 +109,7 @@ public override RotaryEmbeddingOutput forward(RotaryEmbeddingInput input) var seqLen = input.SeqLen; // TODO // can be calculated once and cached - var invFreq = this.get_buffer("inv_freq").to(x.device); + var invFreq = this.get_buffer("inv_freq")!.to(x.device); var invFreqExpanded = invFreq.unsqueeze(0).unsqueeze(-1); invFreqExpanded = invFreqExpanded.expand(new long[] { positionIds.shape[0], -1, 1 }); var positionIdsExpanded = positionIds.unsqueeze(1).to(torch.float32); diff --git a/src/Microsoft.ML.GenAI.Phi/Module/Phi2RotaryEmbedding.cs b/src/Microsoft.ML.GenAI.Phi/Module/Phi2RotaryEmbedding.cs index a21ed4959e..69a9115eb1 100644 --- a/src/Microsoft.ML.GenAI.Phi/Module/Phi2RotaryEmbedding.cs +++ b/src/Microsoft.ML.GenAI.Phi/Module/Phi2RotaryEmbedding.cs @@ -36,7 +36,7 @@ public override (Tensor, Tensor) forward(Tensor x, int seqLen) { // TODO // can be calculated once and cached - var invFreq = this.get_buffer("inv_freq").to(x.device); + var invFreq = this.get_buffer("inv_freq")!.to(x.device); var t = torch.arange(seqLen, dtype: invFreq.dtype, device: invFreq.device); var freqs = torch.outer(t, invFreq).to(torch.float32); var emb = torch.cat([freqs, freqs], dim: -1); diff --git a/src/Microsoft.ML.GenAI.Phi/Phi2/Phi2ForCasualLM.cs b/src/Microsoft.ML.GenAI.Phi/Phi2/Phi2ForCasualLM.cs index 1d49375565..2f2b037655 100644 --- a/src/Microsoft.ML.GenAI.Phi/Phi2/Phi2ForCasualLM.cs +++ b/src/Microsoft.ML.GenAI.Phi/Phi2/Phi2ForCasualLM.cs @@ -55,6 +55,7 @@ public static Phi2ForCasualLM FromPretrained( bool useTqdm = false, string? device = null) { + device ??= torch.get_default_device().ToString(); var config = Path.Join(modelFolder, configName); var modelConfig = JsonSerializer.Deserialize(File.ReadAllText(config)) ?? 
throw new ArgumentNullException(nameof(config)); modelConfig.Dtype = torchDtype; diff --git a/test/Microsoft.ML.GenAI.Core.Tests/QuantizedLinearTests.cs b/test/Microsoft.ML.GenAI.Core.Tests/QuantizedLinearTests.cs index d1653721ba..98ec3f697a 100644 --- a/test/Microsoft.ML.GenAI.Core.Tests/QuantizedLinearTests.cs +++ b/test/Microsoft.ML.GenAI.Core.Tests/QuantizedLinearTests.cs @@ -27,7 +27,7 @@ public void Int4QuantizeSizeTests() sizeInGigaBytes.Should().Be(38); // to int4 - model.Int4(); + model.FP4(); var sizeInBytesAfterInt8 = model.GetSizeInBytes(); var sizeInGigaBytesAfterInt8 = sizeInBytesAfterInt8 / 1024 / 1024; sizeInGigaBytesAfterInt8.Should().Be(4); // 38 // 8 = 4 From 389e11877b70552a8a153cb161662a88e7be2921 Mon Sep 17 00:00:00 2001 From: XiaoYun Zhang Date: Sat, 1 Mar 2025 17:00:14 -0800 Subject: [PATCH 2/7] update --- .../Extension/ModuleExtension.cs | 15 +++-- .../Module/IQuantizeModule.cs | 25 +++++++- .../Module/QuantizedLinear.cs | 6 +- .../LlamaForCausalLM.cs | 4 +- .../MistralForCausalLM.cs | 4 +- .../Phi3/Phi3ForCasualLM.cs | 2 +- .../QuantizedLinearTests.cs | 63 ------------------- .../Microsoft.ML.GenAI.Phi.Tests/Phi3Tests.cs | 14 ----- 8 files changed, 42 insertions(+), 91 deletions(-) diff --git a/src/Microsoft.ML.GenAI.Core/Extension/ModuleExtension.cs b/src/Microsoft.ML.GenAI.Core/Extension/ModuleExtension.cs index 0598a8ce07..a9785b74b6 100644 --- a/src/Microsoft.ML.GenAI.Core/Extension/ModuleExtension.cs +++ b/src/Microsoft.ML.GenAI.Core/Extension/ModuleExtension.cs @@ -90,13 +90,18 @@ public static void ToInt8QuantizeModule( /// /// /// - public static void ToInt4QuantizeModule( - this T model) + /// Quantized data type, can be "fp4" or "nf4". + /// Block size for quantization, can be [64, 128, 256, 512, 1024]. The larger the size, the faster the speed and the lower the precision. + public static void ToQuantize4BitModule( + this T model, + string quantizedDType = "nf4", + int blockSize = 512) where T : nn.Module { + var config = new Quantize4BitConfig(quantizedDType, blockSize); if (model is IQuantizeModule quantized) { - quantized.FP4(); + quantized.Quantize4Bit(config); return; } @@ -105,11 +110,11 @@ public static void ToInt4QuantizeModule( { if (value is IQuantizeModule quantizeModule) { - quantizeModule.FP4(); + quantizeModule.Quantize4Bit(config); } else { - value.ToInt4QuantizeModule(); + value.ToQuantize4BitModule(quantizedDType, blockSize); } } } diff --git a/src/Microsoft.ML.GenAI.Core/Module/IQuantizeModule.cs b/src/Microsoft.ML.GenAI.Core/Module/IQuantizeModule.cs index b2a8936779..6bdf6f89df 100644 --- a/src/Microsoft.ML.GenAI.Core/Module/IQuantizeModule.cs +++ b/src/Microsoft.ML.GenAI.Core/Module/IQuantizeModule.cs @@ -11,5 +11,28 @@ public interface IQuantizeModule /// /// Quantize using BitsAndBytes.FP4 /// - public void FP4(); + /// "/> + public void Quantize4Bit(Quantize4BitConfig config); +} + +/// +/// Quantize configuration for 4-bit quantization. +/// +public record Quantize4BitConfig +{ + public Quantize4BitConfig(string quantizedDType = "fp4", int blockSize = 1024) + { + QuantizedDType = quantizedDType; + BlockSize = blockSize; + } + + /// + /// Quantized data type, can be "fp4" or "nf4". + /// + public string QuantizedDType { get; init; } + + /// + /// Block size for quantization, can be [64, 128, 256, 512, 1024]. The larger the size, the faster the speed and the lower the precision. 
+ /// + public int BlockSize { get; init; } } diff --git a/src/Microsoft.ML.GenAI.Core/Module/QuantizedLinear.cs b/src/Microsoft.ML.GenAI.Core/Module/QuantizedLinear.cs index 67b24f339f..e2dd95d529 100644 --- a/src/Microsoft.ML.GenAI.Core/Module/QuantizedLinear.cs +++ b/src/Microsoft.ML.GenAI.Core/Module/QuantizedLinear.cs @@ -161,7 +161,7 @@ public override Tensor forward(Tensor input) } } - public void FP4() + public void Quantize4Bit(Quantize4BitConfig config) { if (this.weight is null) { @@ -174,8 +174,8 @@ public void FP4() } using var dispose = torch.NewDisposeScope(); - _quantizedDType = "fp4"; // Available options: "fp4", "nf4" - _blockSize = 64; // can be [64, 128, 256, 512, 1024] + _quantizedDType = config.QuantizedDType; // Available options: "fp4", "nf4" + _blockSize = config.BlockSize; // can be [64, 128, 256, 512, 1024] // Quantize to 4Bit (_quantizedTensor, _absMax, _blockSize, _n) = BitsAndByteUtils.Quantize4Bit(this.weight.cuda(), _quantizedDType, _blockSize); diff --git a/src/Microsoft.ML.GenAI.LLaMA/LlamaForCausalLM.cs b/src/Microsoft.ML.GenAI.LLaMA/LlamaForCausalLM.cs index 0a6cdc8498..dcf43d3629 100644 --- a/src/Microsoft.ML.GenAI.LLaMA/LlamaForCausalLM.cs +++ b/src/Microsoft.ML.GenAI.LLaMA/LlamaForCausalLM.cs @@ -143,7 +143,7 @@ public static LlamaForCausalLM FromPretrained( } else if (quantizeToInt4) { - model.ToInt4QuantizeModule(); + model.ToQuantize4BitModule(); } var deviceMap = model.InferDeviceMapForEachLayer( @@ -163,7 +163,7 @@ public static LlamaForCausalLM FromPretrained( } else if (quantizeToInt4) { - model.ToInt4QuantizeModule(); + model.ToQuantize4BitModule(); } model = model.ToDynamicLoadingModel(deviceMap, targetDevice); diff --git a/src/Microsoft.ML.GenAI.Mistral/MistralForCausalLM.cs b/src/Microsoft.ML.GenAI.Mistral/MistralForCausalLM.cs index 18d43e5317..d86757c3a5 100644 --- a/src/Microsoft.ML.GenAI.Mistral/MistralForCausalLM.cs +++ b/src/Microsoft.ML.GenAI.Mistral/MistralForCausalLM.cs @@ -93,7 +93,7 @@ public static MistralForCausalLM FromPretrained( } else if (quantizeToInt4) { - model.ToInt4QuantizeModule(); + model.ToQuantize4BitModule(); } var deviceMap = model.InferDeviceMapForEachLayer( @@ -113,7 +113,7 @@ public static MistralForCausalLM FromPretrained( } else if (quantizeToInt4) { - model.ToInt4QuantizeModule(); + model.ToQuantize4BitModule(); } model = model.ToDynamicLoadingModel(deviceMap, targetDevice); diff --git a/src/Microsoft.ML.GenAI.Phi/Phi3/Phi3ForCasualLM.cs b/src/Microsoft.ML.GenAI.Phi/Phi3/Phi3ForCasualLM.cs index a5840b242a..cb08145ed6 100644 --- a/src/Microsoft.ML.GenAI.Phi/Phi3/Phi3ForCasualLM.cs +++ b/src/Microsoft.ML.GenAI.Phi/Phi3/Phi3ForCasualLM.cs @@ -95,7 +95,7 @@ public static Phi3ForCasualLM FromPretrained( } else if (quantizeToInt4) { - model.ToInt4QuantizeModule(); + model.ToQuantize4BitModule(); } var deviceMap = model.InferDeviceMapForEachLayer( diff --git a/test/Microsoft.ML.GenAI.Core.Tests/QuantizedLinearTests.cs b/test/Microsoft.ML.GenAI.Core.Tests/QuantizedLinearTests.cs index 98ec3f697a..ac74d63719 100644 --- a/test/Microsoft.ML.GenAI.Core.Tests/QuantizedLinearTests.cs +++ b/test/Microsoft.ML.GenAI.Core.Tests/QuantizedLinearTests.cs @@ -12,26 +12,6 @@ namespace Microsoft.ML.GenAI.Core.Tests; public class QuantizedLinearTests { - [Fact] - public void Int4QuantizeSizeTests() - { - // meta is critical for the test - // as the size of the model to test is 372 GB - // and can't be loaded in real device like cpu or cuda - var device = "meta"; - var model = new QuantizedLinear(100000, 100, device: device); - - 
var sizeInBytes = model.GetSizeInBytes(); - - var sizeInGigaBytes = sizeInBytes / 1024 / 1024; - sizeInGigaBytes.Should().Be(38); - - // to int4 - model.FP4(); - var sizeInBytesAfterInt8 = model.GetSizeInBytes(); - var sizeInGigaBytesAfterInt8 = sizeInBytesAfterInt8 / 1024 / 1024; - sizeInGigaBytesAfterInt8.Should().Be(4); // 38 // 8 = 4 - } [Fact] public void Int8QuantizeSizeTests() @@ -54,49 +34,6 @@ public void Int8QuantizeSizeTests() sizeInGigaBytesAfterInt8.Should().Be(9); // 38 // 4 = 9 } - [Fact] - public void Int4QuantizeForwardTest() - { - var device = "cpu"; - var model = new QuantizedLinear(123, 10, device: device); - - // set both weight and bias to rand int8 values - // and compare the result before and after ToInt8 - var input = torch.ones([10, 2200, 123], device: device); - var weight = torch.ones([10, 123], device: device, dtype: ScalarType.Int64) * -1; - var bias = torch.ones([10], device: device) * 2; - - var weightStr = weight.Peek("weight").ToString(); - - weight = (weight + 8).view(-1).to(torch.uint8); - var weightPlaceHolderDim = (int)weight.size(0); - weightPlaceHolderDim = weightPlaceHolderDim / 2 + weightPlaceHolderDim % 2; - var weightPlaceHolder = weight[..weightPlaceHolderDim]; - weightPlaceHolder = weightPlaceHolder * 16 + weight[weightPlaceHolderDim..]; - - var high4Bit = weightPlaceHolder / 16; - var low4Bit = weightPlaceHolder % 16; - weight = torch.cat(new Tensor[] { high4Bit, low4Bit }).view(10, 123); - weight = weight.to(torch.int64); - weight -= 8; - weight.Peek("weight").Should().Be(weightStr); - - model.load_state_dict(new Dictionary - { - ["weight"] = weight, - ["bias"] = bias - }); - - var resultBeforeInt4 = model.forward(input); - - model.ToInt4QuantizeModule(); - - var resultAfterInt4 = model.forward(input); - - // compare the result - resultBeforeInt4.Peek("result").Should().Be(resultAfterInt4.Peek("result")); - } - [Fact] public void Int8QuantizeForwardTest() { diff --git a/test/Microsoft.ML.GenAI.Phi.Tests/Phi3Tests.cs b/test/Microsoft.ML.GenAI.Phi.Tests/Phi3Tests.cs index 1200d79f9d..7a5a763262 100644 --- a/test/Microsoft.ML.GenAI.Phi.Tests/Phi3Tests.cs +++ b/test/Microsoft.ML.GenAI.Phi.Tests/Phi3Tests.cs @@ -51,20 +51,6 @@ public void Phi3Mini4KInt8QuantizeShapeTest() Approvals.Verify(stateDictStr); } - [Fact] - [UseReporter(typeof(DiffReporter))] - [UseApprovalSubdirectory("Approvals")] - public void Phi3Mini4KInt4QuantizeShapeTest() - { - var model = new Phi3ForCasualLM(Phi3Config.Phi3Mini4kInstruct); - model.ToInt4QuantizeModule(); - var size = model.GetSizeInBytes(); - var stateDictStr = model.PeekShape(); - var sizeInGB = size / 1024 / 1024 / 1024; - sizeInGB.Should().Be(2); - Approvals.Verify(stateDictStr); - } - [Fact] [UseReporter(typeof(DiffReporter))] [UseApprovalSubdirectory("Approvals")] From f4f15ea2ebc9a483dac495821591548adbddb2c5 Mon Sep 17 00:00:00 2001 From: XiaoYun Zhang Date: Sat, 1 Mar 2025 17:00:43 -0800 Subject: [PATCH 3/7] update default configuration --- src/Microsoft.ML.GenAI.Core/Extension/ModuleExtension.cs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/Microsoft.ML.GenAI.Core/Extension/ModuleExtension.cs b/src/Microsoft.ML.GenAI.Core/Extension/ModuleExtension.cs index a9785b74b6..d52bc9e66c 100644 --- a/src/Microsoft.ML.GenAI.Core/Extension/ModuleExtension.cs +++ b/src/Microsoft.ML.GenAI.Core/Extension/ModuleExtension.cs @@ -94,8 +94,8 @@ public static void ToInt8QuantizeModule( /// Block size for quantization, can be [64, 128, 256, 512, 1024]. 
The larger the size, the faster the speed and the lower the precision. public static void ToQuantize4BitModule( this T model, - string quantizedDType = "nf4", - int blockSize = 512) + string quantizedDType = "fp4", + int blockSize = 64) where T : nn.Module { var config = new Quantize4BitConfig(quantizedDType, blockSize); From e52cfd67094516e64e95075cc666ec4c03e4439d Mon Sep 17 00:00:00 2001 From: Xiaoyun Zhang Date: Sat, 1 Mar 2025 17:07:50 -0800 Subject: [PATCH 4/7] Update src/Microsoft.ML.GenAI.Core/Module/IQuantizeModule.cs Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- src/Microsoft.ML.GenAI.Core/Module/IQuantizeModule.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Microsoft.ML.GenAI.Core/Module/IQuantizeModule.cs b/src/Microsoft.ML.GenAI.Core/Module/IQuantizeModule.cs index 6bdf6f89df..a756b43b47 100644 --- a/src/Microsoft.ML.GenAI.Core/Module/IQuantizeModule.cs +++ b/src/Microsoft.ML.GenAI.Core/Module/IQuantizeModule.cs @@ -11,7 +11,7 @@ public interface IQuantizeModule /// /// Quantize using BitsAndBytes.FP4 /// - /// "/> + /// public void Quantize4Bit(Quantize4BitConfig config); } From cd7a6c5b9cbeedb4ba806cfa47a4b5406ca49ee0 Mon Sep 17 00:00:00 2001 From: Xiaoyun Zhang Date: Sat, 1 Mar 2025 17:08:46 -0800 Subject: [PATCH 5/7] Update default block size in Quantize4BitConfig --- src/Microsoft.ML.GenAI.Core/Module/IQuantizeModule.cs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/Microsoft.ML.GenAI.Core/Module/IQuantizeModule.cs b/src/Microsoft.ML.GenAI.Core/Module/IQuantizeModule.cs index a756b43b47..c6052e840c 100644 --- a/src/Microsoft.ML.GenAI.Core/Module/IQuantizeModule.cs +++ b/src/Microsoft.ML.GenAI.Core/Module/IQuantizeModule.cs @@ -1,4 +1,4 @@ -// Licensed to the .NET Foundation under one or more agreements. +// Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
@@ -20,7 +20,7 @@ public interface IQuantizeModule /// public record Quantize4BitConfig { - public Quantize4BitConfig(string quantizedDType = "fp4", int blockSize = 1024) + public Quantize4BitConfig(string quantizedDType = "fp4", int blockSize = 64) { QuantizedDType = quantizedDType; BlockSize = blockSize; From 080cf48403d421e6be7e291f17665cd63a4a37e1 Mon Sep 17 00:00:00 2001 From: XiaoYun Zhang Date: Fri, 7 Mar 2025 12:48:33 -0800 Subject: [PATCH 6/7] kernelSize -> kernel_size --- src/Microsoft.ML.TorchSharp/AutoFormerV2/ConvModule.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Microsoft.ML.TorchSharp/AutoFormerV2/ConvModule.cs b/src/Microsoft.ML.TorchSharp/AutoFormerV2/ConvModule.cs index 8568239fc5..cfb9cf48da 100644 --- a/src/Microsoft.ML.TorchSharp/AutoFormerV2/ConvModule.cs +++ b/src/Microsoft.ML.TorchSharp/AutoFormerV2/ConvModule.cs @@ -35,7 +35,7 @@ public class ConvModule : Module public ConvModule(int inChannel, int outChannel, int kernelSize, int stride = 1, int padding = 0, int dilation = 1, bool bias = true, bool useRelu = true) : base(nameof(ConvModule)) { - this.conv = nn.Conv2d(in_channels: inChannel, out_channels: outChannel, kernelSize: kernelSize, stride: stride, padding: padding, dilation: dilation, bias: bias); + this.conv = nn.Conv2d(in_channels: inChannel, out_channels: outChannel, kernel_size: kernelSize, stride: stride, padding: padding, dilation: dilation, bias: bias); this.useRelu = useRelu; if (this.useRelu) { From 849a6c8436dcb9f9700148bb35128c1db3eaa5f6 Mon Sep 17 00:00:00 2001 From: XiaoYun Zhang Date: Tue, 11 Mar 2025 13:40:32 -0700 Subject: [PATCH 7/7] fix comments --- .../Microsoft.ML.GenAI.Samples/Llama/LlamaSample.cs | 2 +- .../Microsoft.ML.GenAI.Samples.csproj | 2 +- eng/Versions.props | 1 + src/Microsoft.ML.GenAI.Core/Extension/ModuleExtension.cs | 2 +- .../Microsoft.ML.GenAI.Core.csproj | 2 +- src/Microsoft.ML.GenAI.LLaMA/LlamaForCausalLM.cs | 8 ++++---- 6 files changed, 9 insertions(+), 8 deletions(-) diff --git a/docs/samples/Microsoft.ML.GenAI.Samples/Llama/LlamaSample.cs b/docs/samples/Microsoft.ML.GenAI.Samples/Llama/LlamaSample.cs index bdc81fc768..140475c61e 100644 --- a/docs/samples/Microsoft.ML.GenAI.Samples/Llama/LlamaSample.cs +++ b/docs/samples/Microsoft.ML.GenAI.Samples/Llama/LlamaSample.cs @@ -34,7 +34,7 @@ public static async Task RunLlama(string weightFolder, string checkPointName = " var stopWatch = System.Diagnostics.Stopwatch.StartNew(); stopWatch.Start(); var tokenizer = LlamaTokenizerHelper.FromPretrained(originalWeightFolder); - var model = LlamaForCausalLM.FromPretrained(weightFolder, configName, checkPointName: checkPointName, layersOnTargetDevice: -1, quantizeToInt4: true); + var model = LlamaForCausalLM.FromPretrained(weightFolder, configName, checkPointName: checkPointName, layersOnTargetDevice: -1, quantizeTo4Bit: true); var pipeline = new CausalLMPipeline(tokenizer, model, device); diff --git a/docs/samples/Microsoft.ML.GenAI.Samples/Microsoft.ML.GenAI.Samples.csproj b/docs/samples/Microsoft.ML.GenAI.Samples/Microsoft.ML.GenAI.Samples.csproj index 6996640f19..bc348e044a 100644 --- a/docs/samples/Microsoft.ML.GenAI.Samples/Microsoft.ML.GenAI.Samples.csproj +++ b/docs/samples/Microsoft.ML.GenAI.Samples/Microsoft.ML.GenAI.Samples.csproj @@ -20,7 +20,7 @@ - + diff --git a/eng/Versions.props b/eng/Versions.props index 4a738f610f..37168d7341 100644 --- a/eng/Versions.props +++ b/eng/Versions.props @@ -75,6 +75,7 @@ 1.15.0 0.105.0 2.5.1 + 0.0.4 1.12.4 6.0.2 diff --git 
a/src/Microsoft.ML.GenAI.Core/Extension/ModuleExtension.cs b/src/Microsoft.ML.GenAI.Core/Extension/ModuleExtension.cs index d52bc9e66c..dd59a9a118 100644 --- a/src/Microsoft.ML.GenAI.Core/Extension/ModuleExtension.cs +++ b/src/Microsoft.ML.GenAI.Core/Extension/ModuleExtension.cs @@ -90,7 +90,7 @@ public static void ToInt8QuantizeModule( /// /// /// - /// Quantized data type, can be "fp4" or "nf4". + /// Quantized data type, can be "fp4" or "nf4". "fp4" means 4-bits floating point (1-bit sign, 2-bit exponent and 1-bit mantissa) and "nf4" means normalized 4-bits floating point, which uses a specialized non-uniform quantization aimed at neural network weight distributions (ranged from -1 to 1). /// Block size for quantization, can be [64, 128, 256, 512, 1024]. The larger the size, the faster the speed and the lower the precision. public static void ToQuantize4BitModule( this T model, diff --git a/src/Microsoft.ML.GenAI.Core/Microsoft.ML.GenAI.Core.csproj b/src/Microsoft.ML.GenAI.Core/Microsoft.ML.GenAI.Core.csproj index 233507bcb2..d13592c964 100644 --- a/src/Microsoft.ML.GenAI.Core/Microsoft.ML.GenAI.Core.csproj +++ b/src/Microsoft.ML.GenAI.Core/Microsoft.ML.GenAI.Core.csproj @@ -18,7 +18,7 @@ - + diff --git a/src/Microsoft.ML.GenAI.LLaMA/LlamaForCausalLM.cs b/src/Microsoft.ML.GenAI.LLaMA/LlamaForCausalLM.cs index dcf43d3629..baf2b7f0d1 100644 --- a/src/Microsoft.ML.GenAI.LLaMA/LlamaForCausalLM.cs +++ b/src/Microsoft.ML.GenAI.LLaMA/LlamaForCausalLM.cs @@ -120,12 +120,12 @@ public static LlamaForCausalLM FromPretrained( string configName = "config.json", string checkPointName = "model.safetensors.index.json", bool quantizeToInt8 = false, - bool quantizeToInt4 = false, + bool quantizeTo4Bit = false, int layersOnTargetDevice = -1, ScalarType torchDtype = ScalarType.BFloat16, string targetDevice = "cuda") { - if (layersOnTargetDevice == -1 && quantizeToInt4 == false && quantizeToInt8 == false) + if (layersOnTargetDevice == -1 && quantizeTo4Bit == false && quantizeToInt8 == false) { return FromPretrained(modelFolder, configName, checkPointName, torchDtype, targetDevice); } @@ -141,7 +141,7 @@ public static LlamaForCausalLM FromPretrained( { model.ToInt8QuantizeModule(); } - else if (quantizeToInt4) + else if (quantizeTo4Bit) { model.ToQuantize4BitModule(); } @@ -161,7 +161,7 @@ public static LlamaForCausalLM FromPretrained( { model.ToInt8QuantizeModule(); } - else if (quantizeToInt4) + else if (quantizeTo4Bit) { model.ToQuantize4BitModule(); }
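
A minimal usage sketch of the API this series introduces, for reviewers trying the change locally. This is a hedged example, not part of the patch: the model path is a placeholder, Phi3Config.Phi3Mini4kInstruct is reused from the removed approval test, the using directives are assumed from the project layout, and a CUDA device is required because Quantize4Bit moves the weight to CUDA before calling into BitsAndBytes.

    // Namespaces assumed from the project layout; adjust to the actual
    // namespaces of ModuleExtension, LlamaForCausalLM and Phi3ForCasualLM.
    using Microsoft.ML.GenAI.Core.Extension;
    using Microsoft.ML.GenAI.LLaMA;
    using Microsoft.ML.GenAI.Phi;

    // Option 1: quantize while loading. quantizeTo4Bit uses the defaults set in
    // patch 5 ("fp4", blockSize = 64); layersOnTargetDevice: -1 keeps every layer
    // on the target device, matching the updated LlamaSample.
    var llama = LlamaForCausalLM.FromPretrained(
        @"C:\models\Llama-3.2-3B-Instruct",   // placeholder path
        layersOnTargetDevice: -1,
        quantizeTo4Bit: true);

    // Option 2: quantize an already-constructed module explicitly, choosing "nf4"
    // and a larger block size. Per the doc comment on ToQuantize4BitModule, larger
    // blocks are faster but less precise; "nf4" uses the non-uniform 4-bit code
    // aimed at neural-network weight distributions in [-1, 1].
    var phi3 = new Phi3ForCasualLM(Phi3Config.Phi3Mini4kInstruct);
    phi3.ToQuantize4BitModule(quantizedDType: "nf4", blockSize: 256);

Note that QuantizedLinear.Quantize4Bit calls this.weight.cuda() before quantizing, so both paths assume a CUDA-enabled TorchSharp build; on the meta device the call returns early without quantizing, which is what lets FromPretrained build the module and infer the device map before the real weights are loaded.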