Skip to content

More generic matcher for BRGEMM

lorenzo chelini edited this page Jul 12, 2023 · 8 revisions

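Consider a matmul C += A * B on packed (blocked) operands. Each outer dimension (I, J, K) is paired with an inner block dimension (i, j, k):

C[I][i][J][j] += A[I][i][K][k] * B[K][k][J][j]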

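This contraction can be mapped onto a BRGEMM (batch-reduce GEMM) over the inner dimensions:

C[..][i][..][j] += A[..][i][K][k] * B[K][k][..][j]

Here [..] marks the outer parallel dimensions (I and J), which stay outside the BRGEMM call. The BRGEMM parameters are: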

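- batch = K (the outer reduction dimension becomes the batch-reduce dimension)
- lda = K * k (stride between consecutive rows i of A)
- ldb = J * j (stride between consecutive rows k of B)
- ldc = J * j (stride between consecutive rows i of C)
- m = i
- n = j
- k = k (the inner, "small" reduction dimension, not the outer K)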

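Parallelize over the remaining ("other") parallel dimensions, i.e. the outer I and J dimensions marked [..] above. Each (I, J) pair issues one BRGEMM call.

For example, in the first linalg.generic below, A is 64x32x16x32 and B is 16x32x8x64, so I = 64, i = 32, K = 16, k = 32, J = 8, j = 64. This gives batch = 16, m = 32, n = 64, k = 32 and lda = ldb = ldc = 512, with 64 * 8 = 512 parallel BRGEMM calls.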

// Iteration space of the packed contraction: (d0, d1, d2, d3, d4, d5) = (I, i, K, k, J, j).
// A is indexed as [I][i][K][k].
#mapA = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
// B is indexed as [K][k][J][j].
#mapB = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d3, d4, d5)>
// C is indexed as [I][i][J][j].
#mapC = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d4, d5)>
// Identity access for the 4-D elementwise op.
#id4d = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
module {
  // %arg0 is the A operand (I=64, i=32, K=16, k=32). The splat-constant B
  // operand is 16x32x8x64 (K=16, k=32, J=8, j=64).
  func.func @test(%arg0: tensor<64x32x16x32xf32>, %arg1: tensor<64x32x8x64xf32>) -> tensor<64x32x8x64xf32> {
    %zero = arith.constant 0.000000e+00 : f32
    %ones = arith.constant dense<1.000000e+00> : tensor<16x32x8x64xf32>
    %twos = arith.constant dense<2.000000e+00> : tensor<64x32x8x64xf32>
    // Zero-initialized accumulator, shared as `outs` by both generics below.
    %empty = tensor.empty() : tensor<64x32x8x64xf32>
    %acc = linalg.fill ins(%zero : f32) outs(%empty : tensor<64x32x8x64xf32>) -> tensor<64x32x8x64xf32>
    // Packed GEMM C += A * B. d2 (K) and d3 (k) are reductions; the rest are
    // parallel. Per the BRGEMM mapping: batch=16, m=32, n=64, k=32,
    // lda=ldb=ldc=512.
    %contract = linalg.generic {
        indexing_maps = [#mapA, #mapB, #mapC],
        iterator_types = ["parallel", "parallel", "reduction", "reduction", "parallel", "parallel"]}
        ins(%arg0, %ones : tensor<64x32x16x32xf32>, tensor<16x32x8x64xf32>)
        outs(%acc : tensor<64x32x8x64xf32>) {
    ^bb0(%a: f32, %b: f32, %c: f32):
      %prod = arith.mulf %a, %b : f32
      %sum = arith.addf %c, %prod : f32
      linalg.yield %sum : f32
    } -> tensor<64x32x8x64xf32>
    // Elementwise %arg1 + 2.0. The body never reads its `outs` value.
    // NOTE(review): the returned value does not depend on %contract, so a
    // DCE/canonicalize pass may erase the contraction before any matcher runs.
    // Confirm this is intended.
    %result = linalg.generic {
        indexing_maps = [#id4d, #id4d, #id4d],
        iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
        ins(%arg1, %twos : tensor<64x32x8x64xf32>, tensor<64x32x8x64xf32>)
        outs(%acc : tensor<64x32x8x64xf32>) {
    ^bb0(%x: f32, %y: f32, %unused: f32):
      %elt = arith.addf %x, %y : f32
      linalg.yield %elt : f32
    } -> tensor<64x32x8x64xf32>
    return %result : tensor<64x32x8x64xf32>
  }
}