-
Notifications
You must be signed in to change notification settings - Fork 31
More generic matcher for BRGEMM
lorenzo chelini edited this page Jul 12, 2023
·
8 revisions
Consider a GEMM, C += A * B, where each operand is stored in a blocked (tiled) 4-D layout:

`C[I][i][J][j] += A[I][i][K][k] * B[K][k][J][j]`
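This contraction can be lowered to a batch-reduce GEMM (BRGEMM) over the innermost dimensions. Below, `[..]` stands for the outer parallel dimensions that are not part of the BRGEMM call:

`[..][i][..][j] += [..][i][K][k] * [K][k][..][j]`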
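with the following BRGEMM parameters:

- batch = K (the outer reduction dimension; the BRGEMM accumulates over K tile products)
- lda = K * k (leading dimension of A: the stride between consecutive rows `i` of an A tile)
- ldb = J * j (leading dimension of B: the stride between consecutive rows `k` of a B tile)
- ldc = J * j (leading dimension of C: the stride between consecutive rows `i` of a C tile)
- m = i
- n = j
- k = k (the inner reduction block size, not the outer dimension K)

Consecutive batch entries are k elements apart in A and k * J * j elements apart in B.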
Parallelize over the remaining outer parallel dimensions (the `[..]` dimensions, i.e. I and J in the example above).
// Loop dimensions of the blocked contraction:
//   (d0, d1, d2, d3, d4, d5) = (I, i, K, k, J, j)
// Operand layouts: A[I][i][K][k], B[K][k][J][j], C[I][i][J][j].
#map_a = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
#map_b = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d3, d4, d5)>
#map_c = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d5)>
// Identity access for the 4-D elementwise op.
#map_id = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
module {
  // With the shapes below: I = 64, i = 32, K = 16, k = 32, J = 8, j = 64,
  // so the matching BRGEMM would be batch = 16, m = 32, n = 64, k = 32,
  // lda = 16 * 32, ldb = ldc = 8 * 64.
  func.func @test(%arg0: tensor<64x32x16x32xf32>, %arg1: tensor<64x32x8x64xf32>) -> tensor<64x32x8x64xf32> {
    %ones = arith.constant dense<1.000000e+00> : tensor<16x32x8x64xf32>
    %twos = arith.constant dense<2.000000e+00> : tensor<64x32x8x64xf32>
    %zero = arith.constant 0.000000e+00 : f32

    // Zero-initialized accumulator shared by both generics below.
    %init = tensor.empty() : tensor<64x32x8x64xf32>
    %acc_init = linalg.fill ins(%zero : f32) outs(%init : tensor<64x32x8x64xf32>) -> tensor<64x32x8x64xf32>

    // Blocked contraction: parallel over (I, i, J, j), reduction over (K, k).
    // NOTE(review): %gemm has no uses and is not returned; if canonicalization
    // or DCE runs before the matcher, this op may be removed -- confirm intent.
    %gemm = linalg.generic {
        indexing_maps = [#map_a, #map_b, #map_c],
        iterator_types = ["parallel", "parallel", "reduction", "reduction", "parallel", "parallel"]}
        ins(%arg0, %ones : tensor<64x32x16x32xf32>, tensor<16x32x8x64xf32>)
        outs(%acc_init : tensor<64x32x8x64xf32>) {
    ^bb0(%a: f32, %b: f32, %acc: f32):
      %prod = arith.mulf %a, %b : f32
      %next = arith.addf %acc, %prod : f32
      linalg.yield %next : f32
    } -> tensor<64x32x8x64xf32>

    // Elementwise add of %arg1 and a constant; the output operand's value is ignored.
    %sum = linalg.generic {
        indexing_maps = [#map_id, #map_id, #map_id],
        iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
        ins(%arg1, %twos : tensor<64x32x8x64xf32>, tensor<64x32x8x64xf32>)
        outs(%acc_init : tensor<64x32x8x64xf32>) {
    ^bb0(%lhs: f32, %rhs: f32, %unused: f32):
      %add = arith.addf %lhs, %rhs : f32
      linalg.yield %add : f32
    } -> tensor<64x32x8x64xf32>

    return %sum : tensor<64x32x8x64xf32>
  }
}