Skip to content

Commit

Permalink
WIP update vnni matchers
Browse files Browse the repository at this point in the history
  • Loading branch information
adam-smnk committed Jan 13, 2025
1 parent 9514fc2 commit 2e0294d
Show file tree
Hide file tree
Showing 5 changed files with 22 additions and 22 deletions.
2 changes: 1 addition & 1 deletion test/BF16/Integration/tpp-run-splat-shape.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ func.func @entry(%arg0: tensor<4x8x8x8xbf16>, %output: tensor<4x8x8x8xbf16>) ->
// due to compile time packing.
// CHECK-NOT: memref.global "private" constant @__constant_{{.*}}: memref<8x8xbf16>
// CHECK-DAG: memref.global "private" constant @__constant_{{.*}}: memref<4x8x8x8xbf16>
// CHECK-DAG: memref.global "private" constant @__constant_{{.*}}: memref<8x8x4x8x{{[2|4|8]}}xbf16>
// CHECK-DAG: memref.global "private" constant @__constant_{{.*}}: memref<8x8x{{4|2}}x8x{{2|4}}xbf16>
// CHECK: xsmm_brgemm_invoke
// CHECK: xsmm_binary_invoke
// CHECK: xsmm_unary_invoke
8 changes: 4 additions & 4 deletions test/BF16/brgemm-tpp.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@ func.func @brgemm(%arg0: tensor<32x4x4xbf16>, %arg1: tensor<32x4x4xbf16>,
// CHECK-LABEL: brgemm
// CHECK-SAME: %[[ARG0:.+]]: tensor<32x4x4xbf16>, %[[ARG1:.+]]: tensor<32x4x4xbf16>,
// CHECK-SAME: %[[ARG2:.+]]: tensor<4x4xbf16>
// CHECK: %[[VNNI_A:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1], [2, 3]] output_shape [32, 4, 2, 2] : tensor<32x4x4xbf16> into tensor<32x4x2x2xbf16>
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<32x2x4x2xbf16>
// CHECK: %[[PACK:.+]] = tensor.pack %[[ARG1]] inner_dims_pos = [1] inner_tiles = [2]
// CHECK-SAME: into %[[EMPTY]] : tensor<32x4x4xbf16> -> tensor<32x2x4x2xbf16>
// CHECK: %[[VNNI_A:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1], [2, 3]] output_shape{{.*}}: tensor<32x4x4xbf16> into tensor<32x4x{{2|1}}x{{2|4}}xbf16>
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<32x{{2|1}}x4x{{2|4}}xbf16>
// CHECK: %[[PACK:.+]] = tensor.pack %[[ARG1]] inner_dims_pos = [1] inner_tiles = [{{2|4}}]
// CHECK-SAME: into %[[EMPTY]] : tensor<32x4x4xbf16> -> tensor<32x{{2|1}}x4x{{2|4}}xbf16>
// CHECK: %{{.+}} = linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP1]], #[[MAP2]]]
// CHECK-SAME: iterator_types = ["reduction", "parallel", "parallel", "reduction", "reduction"]
Expand Down
16 changes: 8 additions & 8 deletions test/BF16/brgemm-vnni.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@ func.func @brgemm(%arg0: tensor<32x4x4xbf16>, %arg1: tensor<32x4x4xbf16>,
// CHECK-LABEL: brgemm
// CHECK-SAME: %[[ARG0:.+]]: tensor<32x4x4xbf16>, %[[ARG1:.+]]: tensor<32x4x4xbf16>,
// CHECK-SAME: %[[ARG2:.+]]: tensor<4x4xbf16>
// CHECK: %[[VNNI_A:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1], [2, 3]] output_shape [32, 4, 2, 2] : tensor<32x4x4xbf16> into tensor<32x4x2x2xbf16>
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<32x2x4x2xbf16>
// CHECK: %[[VNNI_A:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1], [2, 3]] output_shape{{.*}}: tensor<32x4x4xbf16> into tensor<32x4x{{2|1}}x{{2|4}}xbf16>
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<32x{{2|1}}x4x{{2|4}}xbf16>
// CHECK: %[[PACK:.+]] = tensor.pack %[[ARG1]]
// CHECK-SAME: inner_dims_pos = [1] inner_tiles = [2] into %[[EMPTY]]
// CHECK-SAME: : tensor<32x4x4xbf16> -> tensor<32x2x4x2xbf16>
// CHECK-SAME: inner_dims_pos = [1] inner_tiles = [{{2|4}}] into %[[EMPTY]]
// CHECK-SAME: : tensor<32x4x4xbf16> -> tensor<32x{{2|1}}x4x{{2|4}}xbf16>
// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP1]], #[[MAP2]]]
// CHECK-SAME: iterator_types = ["reduction", "parallel", "parallel", "reduction", "reduction"]
Expand Down Expand Up @@ -69,10 +69,10 @@ func.func @prepacked_matmul(%pack: tensor<4x4x32x32xbf16>, %pack_0: tensor<4x4x3
// CHECK-SAME: %[[ARG0:.+]]: tensor<4x4x32x32xbf16>, %[[ARG1:.+]]: tensor<4x4x32x32xbf16>,
// CHECK-SAME: %[[ARG2:.+]]: tensor<4x4x32x32xbf16>
// CHECK: %[[VNNI_A:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1], [2], [3, 4]]
// CHECK-SAME: output_shape [4, 4, 32, 16, 2] : tensor<4x4x32x32xbf16> into tensor<4x4x32x16x2xbf16>
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<4x4x16x32x2xbf16>
// CHECK: %[[PACK:.+]] = tensor.pack %[[ARG1]] inner_dims_pos = [2] inner_tiles = [2] into %[[EMPTY]]
// CHECK-SAME: : tensor<4x4x32x32xbf16> -> tensor<4x4x16x32x2xbf16>
// CHECK-SAME: output_shape{{.*}}: tensor<4x4x32x32xbf16> into tensor<4x4x32x{{16|8}}x{{2|4}}xbf16>
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<4x4x{{16|8}}x32x{{2|4}}xbf16>
// CHECK: %[[PACK:.+]] = tensor.pack %[[ARG1]] inner_dims_pos = [2] inner_tiles = [{{2|4}}] into %[[EMPTY]]
// CHECK-SAME: : tensor<4x4x32x32xbf16> -> tensor<4x4x{{16|8}}x32x{{2|4}}xbf16>
// CHECK: {{.+}} = linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP1]], #[[MAP2]]]
// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "reduction"]
Expand Down
2 changes: 1 addition & 1 deletion test/BF16/matmul-untiled-vnni.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ func.func @blocked_matmul(%arg0: tensor<32x64x4x4xbf16>, %arg1: tensor<128x64x4x
// CHECK: %[[ARG0:.*]]: tensor<32x64x4x4xbf16>,
// CHECK: %[[ARG1:.*]]: tensor<128x64x4x4xbf16>,
// CHECK: %[[ARG2:.*]]: tensor<32x128x4x4xbf16>) -> tensor<32x128x4x4xbf16> {
// CHECK: %[[PACKBUF:.*]] = tensor.empty() : tensor<128x64x2x4x2xbf16>
// CHECK: %[[PACKBUF:.*]] = tensor.empty() : tensor<128x64x{{2|1}}x4x{{2|4}}xbf16>
// CHECK: linalg.generic
// CHECK: indexing_maps = [#[[MAP]], #[[MAP1]], #[[MAP2]]]
// CHECK: iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction", "reduction"]
16 changes: 8 additions & 8 deletions test/BF16/matmul-vnni.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -25,17 +25,17 @@ func.func @matmul_static(
// CHECK: %[[PACK_0:.+]] = tensor.pack %[[ARG1]] outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 32]
// CHECK-SAME: into %{{.+}} : tensor<512x1024xbf16> -> tensor<32x16x32x32xbf16>
// CHECK: %[[VNNI_A:.+]] = tensor.expand_shape %[[PACK]] {{\[}}[0], [1], [2], [3, 4]]
// CHECK-SAME: output_shape [8, 16, 32, 16, 2] : tensor<8x16x32x32xbf16> into tensor<8x16x32x16x2xbf16>
// CHECK: %[[EMPTY_2:.+]] = tensor.empty() : tensor<32x16x16x32x2xbf16>
// CHECK: %[[PACK_1:.+]] = tensor.pack %[[PACK_0]] inner_dims_pos = [2] inner_tiles = [2] into %[[EMPTY_2]]
// CHECK-SAME: : tensor<32x16x32x32xbf16> -> tensor<32x16x16x32x2xbf16>
// CHECK-SAME: output_shape{{.*}}: tensor<8x16x32x32xbf16> into tensor<8x16x32x{{16|8}}x{{2|4}}xbf16>
// CHECK: %[[EMPTY_2:.+]] = tensor.empty() : tensor<32x16x{{16|8}}x32x{{2|4}}xbf16>
// CHECK: %[[PACK_1:.+]] = tensor.pack %[[PACK_0]] inner_dims_pos = [2] inner_tiles = [{{2|4}}] into %[[EMPTY_2]]
// CHECK-SAME: : tensor<32x16x32x32xbf16> -> tensor<32x16x{{16|8}}x32x{{2|4}}xbf16>
// CHECK: %{{.+}} = scf.forall (%[[ARG3:.+]], %[[ARG4:.+]]) in (8, 32) shared_outs(%[[ARG5:.+]] = %[[ARG2]])
// CHECK: %[[APPLY:.+]] = affine.apply #[[MAP]](%[[ARG3]])
// CHECK: %[[APPLY_1:.+]] = affine.apply #[[MAP]](%[[ARG4]])
// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[VNNI_A]][%[[ARG3]], 0, 0, 0, 0] [1, 16, 32, 16, 2] [1, 1, 1, 1, 1]
// CHECK-SAME: : tensor<8x16x32x16x2xbf16> to tensor<16x32x16x2xbf16>
// CHECK: %[[SLICE_2:.+]] = tensor.extract_slice %[[PACK_1]][%[[ARG4]], 0, 0, 0, 0] [1, 16, 16, 32, 2] [1, 1, 1, 1, 1]
// CHECK-SAME: : tensor<32x16x16x32x2xbf16> to tensor<16x16x32x2xbf16>
// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[VNNI_A]][%[[ARG3]], 0, 0, 0, 0] [1, 16, 32, {{16|8}}, {{2|4}}] [1, 1, 1, 1, 1]
// CHECK-SAME: : tensor<8x16x32x{{16|8}}x{{2|4}}xbf16> to tensor<16x32x{{16|8}}x{{2|4}}xbf16>
// CHECK: %[[SLICE_2:.+]] = tensor.extract_slice %[[PACK_1]][%[[ARG4]], 0, 0, 0, 0] [1, 16, {{16|8}}, 32, {{2|4}}] [1, 1, 1, 1, 1]
// CHECK-SAME: : tensor<32x16x{{16|8}}x32x{{2|4}}xbf16> to tensor<16x{{16|8}}x32x{{2|4}}xbf16>
// CHECK: %[[SLICE_3:.+]] = tensor.extract_slice %[[ARG5]][%[[APPLY]], %[[APPLY_1]]] [32, 32] [1, 1]
// CHECK-SAME: : tensor<256x1024xbf16> to tensor<32x32xbf16>
// CHECK: %[[GEMM:.+]] = linalg.generic
Expand Down

0 comments on commit 2e0294d

Please sign in to comment.