Skip to content

Commit

Permalink
Restore deleted tests
Browse files Browse the repository at this point in the history
  • Loading branch information
adam-smnk committed Jan 16, 2025
1 parent a1e54ce commit ad31e0f
Show file tree
Hide file tree
Showing 2 changed files with 187 additions and 0 deletions.
50 changes: 50 additions & 0 deletions test/BF16/Integration/matmul-pbf16.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
// RUN: tpp-run %s -print \
// RUN: -e entry -entry-point-result=void | \
// RUN: FileCheck %s

// Indexing maps for the packed-bf16 matmul in @matmultpp, over the
// iteration space (d0, d1, d2, d3). From the operand shapes and the
// iterator types used there:
//   d0 - reduction over the innermost size-2 (bf16 pair) dimension
//   d1 - parallel, rows of the result
//   d2 - parallel, columns of the result
//   d3 - reduction over the remaining (outer) part of the contraction
// Left operand, viewed as <rows x K/2 x 2>.
#map = affine_map<(d0, d1, d2, d3) -> (d1, d3, d0)>
// Right operand, stored as <K/2 x cols x 2>.
#map1 = affine_map<(d0, d1, d2, d3) -> (d3, d2, d0)>
// Result / accumulator, <rows x cols>.
#map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)>

// Accumulating bf16 matmul: %C += %A * %B.
//   %A - 4x8 left operand in plain row-major form.
//   %B - right operand already stored as 4x4x2, i.e. with the contraction
//        dimension split into (outer, inner pair of 2).
//   %C - 4x4 accumulator; read and updated in place.
func.func @matmultpp(%A: memref<4x8xbf16>, %B: memref<4x4x2xbf16>,
                     %C: memref<4x4xbf16>) {
  // Reinterpret %A's 8 columns as 4 groups of 2 so its shape lines up with
  // the 3-D layout of %B. This is a view; no data is moved.
  %a_packed = memref.expand_shape %A [[0], [1, 2]] output_shape [4, 4, 2]
      : memref<4x8xbf16> into memref<4x4x2xbf16>

  // Contraction over both halves of the split reduction dimension (the first
  // and last iterators), with the two parallel iterators indexing %C.
  linalg.generic {
      indexing_maps = [#map, #map1, #map2],
      iterator_types = ["reduction", "parallel", "parallel", "reduction"]
    }
    ins(%a_packed, %B : memref<4x4x2xbf16>, memref<4x4x2xbf16>)
    outs(%C : memref<4x4xbf16>) {
  ^bb0(%lhs: bf16, %rhs: bf16, %acc: bf16):
    %prod = arith.mulf %lhs, %rhs : bf16
    %sum = arith.addf %acc, %prod : bf16
    linalg.yield %sum : bf16
  }
  return
}

// Test driver: multiplies an all-ones 4x8 matrix by an all-ones packed
// 4x4x2 matrix into a zero-initialised 4x4 result, then prints the result.
// Every output element is a sum of eight 1.0 * 1.0 products.
func.func @entry() {
  // Scalar constants used below.
  %idx0 = arith.constant 0 : index
  %one = arith.constant 1.0 : bf16
  %zero = arith.constant 0.0 : bf16
  // Padding value for the vector read; the read covers exactly the 4x4
  // buffer, so it should never show up in the printed output.
  %pad = arith.constant -1.0 : bf16

  // Left operand: 4x8, all ones.
  %a = memref.alloc() : memref<4x8xbf16>
  linalg.fill ins(%one : bf16) outs(%a : memref<4x8xbf16>)

  // Right operand: 4x4x2 (pre-packed), all ones.
  %b = memref.alloc() : memref<4x4x2xbf16>
  linalg.fill ins(%one : bf16) outs(%b : memref<4x4x2xbf16>)

  // Accumulator: 4x4, cleared to zero because the kernel adds into it.
  %c = memref.alloc() : memref<4x4xbf16>
  linalg.fill ins(%zero : bf16) outs(%c : memref<4x4xbf16>)

  // Call kernel.
  call @matmultpp(%a, %b, %c)
      : (memref<4x8xbf16>, memref<4x4x2xbf16>, memref<4x4xbf16>) -> ()

  //
  // CHECK:( ( 8, 8, 8, 8 ), ( 8, 8, 8, 8 ), ( 8, 8, 8, 8 ), ( 8, 8, 8, 8 ) )
  //
  // Load the whole result, widen to f32 and print it.
  %result_bf16 = vector.transfer_read %c[%idx0, %idx0], %pad
      : memref<4x4xbf16>, vector<4x4xbf16>
  %result_f32 = arith.extf %result_bf16 : vector<4x4xbf16> to vector<4x4xf32>
  vector.print %result_f32 : vector<4x4xf32>

  return
}
137 changes: 137 additions & 0 deletions test/BF16/Integration/mlp-all-bf16-tpprun.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
// RUN: tpp-run %s \
// RUN: -e entry -entry-point-result=void

// Constant weight tensors, one per layer of @entry, each splatted with 1.0.
// Every global is laid out as <K/2 x N x 2>: the contraction dimension of
// the layer is split into an outer part and an innermost pair of bf16
// values, matching how @entry indexes them (#map1).
//   @arg1 - layer 1 weights (K = 256,  N = 512)
//   @arg3 - layer 2 weights (K = 512,  N = 1024)
//   @arg5 - layer 3 weights (K = 1024, N = 2048)
//   @arg7 - layer 4 weights (K = 2048, N = 1000)
memref.global "private" constant @arg1 : memref<128x512x2xbf16> = dense<1.00e+00>
memref.global "private" constant @arg3 : memref<256x1024x2xbf16> = dense<1.00e+00>
memref.global "private" constant @arg5 : memref<512x2048x2xbf16> = dense<1.00e+00>
memref.global "private" constant @arg7 : memref<1024x1000x2xbf16> = dense<1.00e+00>

// Indexing maps shared by all layers of @entry.
//
// Packed-bf16 matmul, iteration space (d0, d1, d2, d3) where d0 and d3 are
// the reduction iterators (inner size-2 pair and outer K/2) and d1, d2 are
// the parallel row / column iterators:
// Left operand (activations), viewed as <rows x K/2 x 2>.
#map = affine_map<(d0, d1, d2, d3) -> (d1, d3, d0)>
// Right operand (weights), stored as <K/2 x cols x 2>.
#map1 = affine_map<(d0, d1, d2, d3) -> (d3, d2, d0)>
// Result / accumulator, <rows x cols>.
#map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)>
// Element-wise 2-D identity, used for the in-place ReLU and as the output
// map of the bias broadcast.
#map3 = affine_map<(d0, d1) -> (d0, d1)>
// Reads a 1-D operand by column only, i.e. broadcasts it across rows.
#map4 = affine_map<(d0, d1) -> (d1)>

// Four-layer bf16 MLP. Each layer does, in order:
//   1. broadcast the 1-D bias across the rows of the layer's output buffer,
//   2. accumulate activations * weights into that buffer (packed matmul),
//   3. apply ReLU in place.
//
//   %arg0                - 128x256 input activations
//   %arg2/4/6/8          - per-layer bias vectors (512, 1024, 2048, 1000)
//   %arg9/10/11/12       - per-layer output buffers; %arg12 is the final one
// Weights come from the module-level globals @arg1, @arg3, @arg5 and @arg7.
// The final output is compared against a constant reference at the end.
func.func @entry(%arg0: memref<128x256xbf16>, %arg2: memref<512xbf16>,
                 %arg4: memref<1024xbf16>, %arg6: memref<2048xbf16>,
                 %arg8: memref<1000xbf16>, %arg9: memref<128x512xbf16>,
                 %arg10: memref<128x1024xbf16>, %arg11: memref<128x2048xbf16>,
                 %arg12: memref<128x1000xbf16>) {
  // Lower bound used by every ReLU below.
  %zero = arith.constant 0.0 : bf16

  // ------------------------------------------------------------------
  // Layer 1: %arg9 = relu(%arg0 * @arg1 + %arg2)      [128x256 -> 128x512]
  // ------------------------------------------------------------------
  // Seed the output with the bias, broadcast along the rows.
  linalg.generic {
      indexing_maps = [#map4, #map3],
      iterator_types = ["parallel", "parallel"]
    }
    ins(%arg2 : memref<512xbf16>) outs(%arg9 : memref<128x512xbf16>) {
  ^bb0(%bias: bf16, %unused: bf16):
    linalg.yield %bias : bf16
  }

  // View the activations with the contraction dimension split into pairs so
  // that they line up with the packed weight layout.
  %act0_packed = memref.expand_shape %arg0 [[0], [1, 2]] output_shape [128, 128, 2]
      : memref<128x256xbf16> into memref<128x128x2xbf16>
  %w1 = memref.get_global @arg1 : memref<128x512x2xbf16>
  linalg.generic {
      indexing_maps = [#map, #map1, #map2],
      iterator_types = ["reduction", "parallel", "parallel", "reduction"]
    }
    ins(%act0_packed, %w1 : memref<128x128x2xbf16>, memref<128x512x2xbf16>)
    outs(%arg9 : memref<128x512xbf16>) {
  ^bb0(%lhs: bf16, %rhs: bf16, %acc: bf16):
    %prod = arith.mulf %lhs, %rhs : bf16
    %sum = arith.addf %acc, %prod : bf16
    linalg.yield %sum : bf16
  }

  // In-place ReLU.
  linalg.generic {
      indexing_maps = [#map3, #map3],
      iterator_types = ["parallel", "parallel"]
    }
    ins(%arg9 : memref<128x512xbf16>) outs(%arg9 : memref<128x512xbf16>) {
  ^bb0(%x: bf16, %unused: bf16):
    %relu = arith.maximumf %x, %zero : bf16
    linalg.yield %relu : bf16
  }

  // ------------------------------------------------------------------
  // Layer 2: %arg10 = relu(%arg9 * @arg3 + %arg4)     [128x512 -> 128x1024]
  // ------------------------------------------------------------------
  linalg.generic {
      indexing_maps = [#map4, #map3],
      iterator_types = ["parallel", "parallel"]
    }
    ins(%arg4 : memref<1024xbf16>) outs(%arg10 : memref<128x1024xbf16>) {
  ^bb0(%bias: bf16, %unused: bf16):
    linalg.yield %bias : bf16
  }

  %act1_packed = memref.expand_shape %arg9 [[0], [1, 2]] output_shape [128, 256, 2]
      : memref<128x512xbf16> into memref<128x256x2xbf16>
  %w2 = memref.get_global @arg3 : memref<256x1024x2xbf16>
  linalg.generic {
      indexing_maps = [#map, #map1, #map2],
      iterator_types = ["reduction", "parallel", "parallel", "reduction"]
    }
    ins(%act1_packed, %w2 : memref<128x256x2xbf16>, memref<256x1024x2xbf16>)
    outs(%arg10 : memref<128x1024xbf16>) {
  ^bb0(%lhs: bf16, %rhs: bf16, %acc: bf16):
    %prod = arith.mulf %lhs, %rhs : bf16
    %sum = arith.addf %acc, %prod : bf16
    linalg.yield %sum : bf16
  }

  linalg.generic {
      indexing_maps = [#map3, #map3],
      iterator_types = ["parallel", "parallel"]
    }
    ins(%arg10 : memref<128x1024xbf16>) outs(%arg10 : memref<128x1024xbf16>) {
  ^bb0(%x: bf16, %unused: bf16):
    %relu = arith.maximumf %x, %zero : bf16
    linalg.yield %relu : bf16
  }

  // ------------------------------------------------------------------
  // Layer 3: %arg11 = relu(%arg10 * @arg5 + %arg6)    [128x1024 -> 128x2048]
  // ------------------------------------------------------------------
  linalg.generic {
      indexing_maps = [#map4, #map3],
      iterator_types = ["parallel", "parallel"]
    }
    ins(%arg6 : memref<2048xbf16>) outs(%arg11 : memref<128x2048xbf16>) {
  ^bb0(%bias: bf16, %unused: bf16):
    linalg.yield %bias : bf16
  }

  %act2_packed = memref.expand_shape %arg10 [[0], [1, 2]] output_shape [128, 512, 2]
      : memref<128x1024xbf16> into memref<128x512x2xbf16>
  %w3 = memref.get_global @arg5 : memref<512x2048x2xbf16>
  linalg.generic {
      indexing_maps = [#map, #map1, #map2],
      iterator_types = ["reduction", "parallel", "parallel", "reduction"]
    }
    ins(%act2_packed, %w3 : memref<128x512x2xbf16>, memref<512x2048x2xbf16>)
    outs(%arg11 : memref<128x2048xbf16>) {
  ^bb0(%lhs: bf16, %rhs: bf16, %acc: bf16):
    %prod = arith.mulf %lhs, %rhs : bf16
    %sum = arith.addf %acc, %prod : bf16
    linalg.yield %sum : bf16
  }

  linalg.generic {
      indexing_maps = [#map3, #map3],
      iterator_types = ["parallel", "parallel"]
    }
    ins(%arg11 : memref<128x2048xbf16>) outs(%arg11 : memref<128x2048xbf16>) {
  ^bb0(%x: bf16, %unused: bf16):
    %relu = arith.maximumf %x, %zero : bf16
    linalg.yield %relu : bf16
  }

  // ------------------------------------------------------------------
  // Layer 4: %arg12 = relu(%arg11 * @arg7 + %arg8)    [128x2048 -> 128x1000]
  // ------------------------------------------------------------------
  linalg.generic {
      indexing_maps = [#map4, #map3],
      iterator_types = ["parallel", "parallel"]
    }
    ins(%arg8 : memref<1000xbf16>) outs(%arg12 : memref<128x1000xbf16>) {
  ^bb0(%bias: bf16, %unused: bf16):
    linalg.yield %bias : bf16
  }

  %act3_packed = memref.expand_shape %arg11 [[0], [1, 2]] output_shape [128, 1024, 2]
      : memref<128x2048xbf16> into memref<128x1024x2xbf16>
  %w4 = memref.get_global @arg7 : memref<1024x1000x2xbf16>
  linalg.generic {
      indexing_maps = [#map, #map1, #map2],
      iterator_types = ["reduction", "parallel", "parallel", "reduction"]
    }
    ins(%act3_packed, %w4 : memref<128x1024x2xbf16>, memref<1024x1000x2xbf16>)
    outs(%arg12 : memref<128x1000xbf16>) {
  ^bb0(%lhs: bf16, %rhs: bf16, %acc: bf16):
    %prod = arith.mulf %lhs, %rhs : bf16
    %sum = arith.addf %acc, %prod : bf16
    linalg.yield %sum : bf16
  }

  linalg.generic {
      indexing_maps = [#map3, #map3],
      iterator_types = ["parallel", "parallel"]
    }
    ins(%arg12 : memref<128x1000xbf16>) outs(%arg12 : memref<128x1000xbf16>) {
  ^bb0(%x: bf16, %unused: bf16):
    %relu = arith.maximumf %x, %zero : bf16
    linalg.yield %relu : bf16
  }

  // ------------------------------------------------------------------
  // Verification: compare the final output against a constant reference.
  // ------------------------------------------------------------------
  // NOTE(review): 2.74878e+11 is about 2^38 = 256 * 512 * 1024 * 2048, the
  // product of the four contraction sizes. That is the value to expect if the
  // harness feeds all-ones activations and biases (the weights are all ones);
  // confirm against the tpp-run input initialisation.
  %threshold = arith.constant 1.0 : bf16
  %expected = arith.constant 2.74878e+11 : bf16
  %reference = memref.alloc() : memref<128x1000xbf16>
  linalg.fill ins(%expected : bf16) outs(%reference : memref<128x1000xbf16>)
  check.expect_almost_eq(%reference, %arg12, %threshold) : memref<128x1000xbf16>, memref<128x1000xbf16>, bf16
  return
}

0 comments on commit ad31e0f

Please sign in to comment.