
Commit dd97e83: cache context

Committed Feb 11, 2025
1 parent: 9828c50

3 files changed: +22 -13 lines


TorchSharp.BitsAndBytes.Benchmark/CudaBenchmark.cs (+6 -6)
@@ -32,34 +32,34 @@ public void Setup()
         (quantizedTensor, absMax, _, _) = BitsAndByteUtils.Quantize4Bit(b, "fp4", blockSize);
     }
 
-    //[Benchmark]
+    [Benchmark]
     public void Quantize4Bit()
     {
         var result = BitsAndByteUtils.Quantize4Bit(a1, quantizedDType, blockSize);
     }
 
-    //[Benchmark]
+    [Benchmark]
     public void Dequantize4Bit()
     {
         var (quantizedTensor, absMax, _, n) = BitsAndByteUtils.Quantize4Bit(a1, quantizedDType, blockSize);
         var result = BitsAndByteUtils.Dequantize4Bit(quantizedTensor, absMax, ScalarType.Float32, quantizedDType, n, a1.shape, blockSize);
     }
 
-    //[Benchmark]
+    [Benchmark]
     public void GEMV_4Bit_FP4()
     {
         using var input = torch.rand(new long[] { 1, dim }, dtype: ScalarType.Float32).cuda();
         using var result = BitsAndByteUtils.Gemv4Bit(input, quantizedTensor, [4*dim, dim], absMax, blockSize, quantizedDType);
     }
 
-    //[Benchmark]
+    [Benchmark]
     public void GEMV_4Bit_NF4()
     {
         using var input = torch.rand(new long[] { 1, dim }, dtype: ScalarType.Float32).cuda();
         using var result = BitsAndByteUtils.Gemv4Bit(input, quantizedTensor, [4 * dim, dim], absMax, blockSize, "nf4");
     }
 
-    //[Benchmark]
+    [Benchmark]
     public void GEMV_FP32()
     {
         using var input = torch.rand([1, dim], dtype: ScalarType.Float32).cuda();
@@ -74,7 +74,7 @@ public void GEMM_INT8()
         using var result = Function.Int8GEMM(input, weight);
     }
 
-    //[Benchmark]
+    [Benchmark]
     public void GEMM_FP32()
     {
         using var input = torch.randint(-128, 127, new long[] { 1, dim }, dtype: ScalarType.Float32).cuda();

TorchSharp.BitsAndBytes.Benchmark/Program.cs (+0 -1)

@@ -1,4 +1,3 @@
 using BenchmarkDotNet.Running;
 using TorchSharp.BitsAndBytes.Benchmark;
-new CudaBenchmark().GEMM_INT8();
 BenchmarkRunner.Run<CudaBenchmark>();
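
A side note on the toggling above: instead of commenting [Benchmark] attributes in and out (or pinning a single case with a manual call, as the removed line did), BenchmarkDotNet can select benchmarks at run time. A minimal sketch, assuming a BenchmarkDotNet version whose Run overload forwards command-line args (0.12 or later):

    using BenchmarkDotNet.Running;
    using TorchSharp.BitsAndBytes.Benchmark;

    // Forwarding args lets one benchmark be selected without editing attributes,
    // e.g.: dotnet run -c Release -- --filter '*GEMM_INT8*'
    BenchmarkRunner.Run<CudaBenchmark>(args: args);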

TorchSharp.BitsAndBytes/Function.cs (+16 -6)
@@ -9,6 +9,7 @@ namespace TorchSharp.BitsAndBytes;
 
 public class Function
 {
+    private static readonly Lazy<Dictionary<int, IntPtr>> _context = new(() => new Dictionary<int, IntPtr>());
     /// <summary>
     /// Integer General Matrix Multiplication (IGEMM) for 8-bit integer data types.
     /// </summary>
@@ -24,7 +25,7 @@ public static Tensor Int8GEMM(
         bool transposeInput = false)
     {
         var sout = BitsAndByteUtils.CheckMatmul(input, weight, transposeWeight, transposeInput);
-        var @out = torch.zeros((long[])[.. sout], dtype: torch.int32, device: input.device);
+        var result = torch.zeros((long[])[.. sout], dtype: torch.int32, device: input.device);
         if (input.shape.Length == 3 && weight.shape.Length == 3)
         {
             if (input.shape[0] == weight.shape[0] && input.shape[2] == weight.shape[1])
@@ -130,16 +131,25 @@ public static Tensor Int8GEMM(
             ldc = m;
         }
 
-        var context = BitsAndBytesCudaNative.get_context();
+        IntPtr context;
+        if (_context.Value.TryGetValue(input.device_index, out var ctx))
+        {
+            context = ctx;
+        }
+        else
+        {
+            context = BitsAndBytesCudaNative.get_context();
+            _context.Value[input.device_index] = context;
+        }
+
         var A = LibTorchNativeMethod.THSStorage_data_ptr(input.Handle);
         var B = LibTorchNativeMethod.THSStorage_data_ptr(weight.Handle);
-        var C = LibTorchNativeMethod.THSStorage_data_ptr(@out.Handle);
-
+        var C = LibTorchNativeMethod.THSStorage_data_ptr(result.Handle);
         BitsAndBytesCudaNative.cigemm(
             context: context,
             transposeA: transposeWeight, // cuBLAS expects column major, but PyTorch is row major
             transposeB: transposeInput, // So we have to transpose A and B
-            m: m,
+            m: m,
             n: n,
             k: k,
             A: B, // out_T = B_T @ A_T
@@ -148,7 +158,7 @@ public static Tensor Int8GEMM(
             lda: lda,
             ldb: ldb,
             ldc: ldc);
+        return result;
 
-        return @out;
     }
 }
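
The Function.cs change is the commit's point: the native context handle is now cached per CUDA device index in a Lazy<Dictionary<int, IntPtr>>, so BitsAndBytesCudaNative.get_context() runs once per device instead of on every Int8GEMM call. A minimal sketch of the same idea under one extra assumption, that Int8GEMM may be called from several threads at once: a plain Dictionary is not thread-safe, so this variant uses ConcurrentDictionary.GetOrAdd (NativeGetContext is a hypothetical stand-in for the P/Invoke call):

    using System;
    using System.Collections.Concurrent;

    static class ContextCache
    {
        // One native context per CUDA device index, created on first use.
        private static readonly ConcurrentDictionary<int, IntPtr> Contexts = new();

        public static IntPtr For(int deviceIndex) =>
            Contexts.GetOrAdd(deviceIndex, _ => NativeGetContext());

        // Hypothetical stand-in for BitsAndBytesCudaNative.get_context().
        private static IntPtr NativeGetContext() => IntPtr.Zero;
    }

One caveat: GetOrAdd may invoke the value factory more than once under contention and discard the extra result, so if the native context must be created and released exactly once per device, storing a Lazy<IntPtr> value (ConcurrentDictionary<int, Lazy<IntPtr>>) is the safer shape.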

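The "out_T = B_T @ A_T" comment in the cigemm call compresses a standard trick: cuBLAS is column-major while PyTorch tensors are row-major, and a row-major m x n buffer reinterpreted as column-major is exactly the n x m transpose. So rather than transposing any data, the call computes out^T = B^T @ A^T by swapping the operand roles, and the row-major view of the column-major result is out = A @ B. A small self-contained check of that identity (illustration only, plain C# arrays):

    using System;
    using System.Linq;

    static class ColumnMajorDemo
    {
        // Row-major C = A (m x k) @ B (k x n).
        static double[] MatMul(double[] a, double[] b, int m, int k, int n)
        {
            var c = new double[m * n];
            for (int i = 0; i < m; i++)
                for (int j = 0; j < n; j++)
                    for (int p = 0; p < k; p++)
                        c[i * n + j] += a[i * k + p] * b[p * n + j];
            return c;
        }

        // Row-major transpose of an (m x n) matrix.
        static double[] Transpose(double[] x, int m, int n)
        {
            var t = new double[m * n];
            for (int i = 0; i < m; i++)
                for (int j = 0; j < n; j++)
                    t[j * m + i] = x[i * n + j];
            return t;
        }

        static void Main()
        {
            int m = 2, k = 3, n = 4;
            var rng = new Random(0);
            double[] a = Enumerable.Range(0, m * k).Select(_ => rng.NextDouble()).ToArray();
            double[] b = Enumerable.Range(0, k * n).Select(_ => rng.NextDouble()).ToArray();

            var c  = MatMul(a, b, m, k, n);                                   // out = A @ B (row-major)
            var cT = MatMul(Transpose(b, k, n), Transpose(a, m, k), n, k, m); // out^T = B^T @ A^T

            // cT, read row-major, is out^T: the column-major buffer of out.
            Console.WriteLine(Transpose(cT, n, m).SequenceEqual(c) ? "identity holds" : "mismatch");
        }
    }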