TPP in destination passing style
lorenzo chelini edited this page Jul 4, 2023
- Do we plan to have a custom bufferization, or do we rely on the one-shot one? If we use DPS, we already commit to a bufferization strategy at the tensor level; thus, we may be able to remove our entire bufferization pass. It is unclear what benefit this would bring.
- Properly define what `outs` means for TPPs. Consider the following Linalg op:
```mlir
#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
%6 = tensor.empty() : tensor<64x32x8x64xf32>
%7 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%expanded_8, %cst_6 : tensor<64x32x8x64xf32>, tensor<64x32x8x64xf32>) outs(%6 : tensor<64x32x8x64xf32>) {
^bb0(%in: f32, %in_21: f32, %out: f32):
  %46 = arith.mulf %in, %in_21 : f32
  linalg.yield %46 : f32
} -> tensor<64x32x8x64xf32>
```
This pattern is fairly common in Linalg. If we assume TPP ops are in DPS, we could convert the op as:
```mlir
%6 = tensor.empty() : tensor<64x32x8x64xf32>
%7 = tpp.add ins(%expanded_8, %cst : tensor<64x32x8x64xf32>, tensor<64x32x8x64xf32>) outs(%6 : tensor<64x32x8x64xf32>)
```
And then bufferize over `%6`:

```mlir
%6 = memref.alloc() : memref<64x32x8x64xf32>
%7 = tpp.add ins(%expanded_8, %cst : memref<64x32x8x64xf32>, memref<64x32x8x64xf32>) outs(%6 : memref<64x32x8x64xf32>)
```
This is not optimal. As a solution, we could "canonicalize" the Linalg IR to drop the unused output and use `%expanded_8` as the destination.
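A sketch of what that canonicalized form could look like (hypothetical IR, assuming the empty `outs` tensor can be dropped and `%expanded_8` reused as the destination; bufferization could then update it in place without a fresh allocation):

```mlir
// Hypothetical canonicalized form: %expanded_8 doubles as input and destination.
%7 = tpp.add ins(%expanded_8, %cst_6 : tensor<64x32x8x64xf32>, tensor<64x32x8x64xf32>) outs(%expanded_8 : tensor<64x32x8x64xf32>)

// After in-place bufferization: no memref.alloc is needed.
tpp.add ins(%expanded_8, %cst_6 : memref<64x32x8x64xf32>, memref<64x32x8x64xf32>) outs(%expanded_8 : memref<64x32x8x64xf32>)
```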
- It is unclear how to express a DPS operation using non-DPS ones without losing performance. An example is the upstream `tensor.unpack` decomposition. Consider:
```mlir
func.func @entry(%arg0: tensor<64x64xf32>, %arg1: tensor<64x64xf32>, %arg2: tensor<64x64xf32>) -> tensor<64x64xf32> {
  %0 = linalg.matmul ins(%arg0, %arg1 : tensor<64x64xf32>, tensor<64x64xf32>) outs(%arg2 : tensor<64x64xf32>) -> tensor<64x64xf32>
  return %0 : tensor<64x64xf32>
}
```
It is packed as:
```mlir
...
%3 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%pack, %pack_0 : tensor<2x2x32x32xf32>, tensor<2x2x32x32xf32>) outs(%pack_1 : tensor<2x2x32x32xf32>) {
^bb0(%in: f32, %in_2: f32, %out: f32):
  %4 = arith.mulf %in, %in_2 : f32
  %5 = arith.addf %out, %4 : f32
  linalg.yield %5 : f32
} -> tensor<2x2x32x32xf32>
// So far so good: DPS is respected, as `tensor.unpack` "writes" into %arg2.
%unpack = tensor.unpack %3 inner_dims_pos = [0, 1] inner_tiles = [32, 32] into %arg2 : tensor<2x2x32x32xf32> -> tensor<64x64xf32>
```
But when decomposing `tensor.unpack`:
```mlir
// `tensor.unpack` decomposition. Legit rewrite at the tensor level, but it leaks memory.
%alloc_4 = memref.alloc() {alignment = 64 : i64} : memref<2x32x2x32xf32>
linalg.transpose ins(%alloc : memref<2x2x32x32xf32>) outs(%alloc_4 : memref<2x32x2x32xf32>) permutation = [0, 2, 1, 3]
%collapse_shape = memref.collapse_shape %alloc_4 [[0, 1], [2, 3]] : memref<2x32x2x32xf32> into memref<64x64xf32>
// %alloc_4 leaks (`collapse_shape` should "write" into %arg2).
```
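One possible leak-free shape (a sketch only; it assumes an explicit copy into the destination `%arg2` plus a matching deallocation, and the extra copy is exactly the performance cost the open question is about):

```mlir
// Hypothetical DPS-respecting decomposition: materialize the transpose in a
// scratch buffer, copy the collapsed result into the destination, then free.
%alloc_4 = memref.alloc() {alignment = 64 : i64} : memref<2x32x2x32xf32>
linalg.transpose ins(%alloc : memref<2x2x32x32xf32>) outs(%alloc_4 : memref<2x32x2x32xf32>) permutation = [0, 2, 1, 3]
%collapse_shape = memref.collapse_shape %alloc_4 [[0, 1], [2, 3]] : memref<2x32x2x32xf32> into memref<64x64xf32>
memref.copy %collapse_shape, %arg2 : memref<64x64xf32> to memref<64x64xf32>
memref.dealloc %alloc_4 : memref<2x32x2x32xf32>
```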
If we move to DPS, here is the todo list:
- update parser and printer (https://github.com/plaidml/tpp-mlir/blob/2f8937366774f10bd5e3a423ba21f6fa62cebbcb/lib/TPP/Dialect/Tpp/TppOps.cpp#L45) and (https://github.com/plaidml/tpp-mlir/blob/2f8937366774f10bd5e3a423ba21f6fa62cebbcb/lib/TPP/Dialect/Tpp/TppOps.cpp#L118). We may want to use custom assembly and remove the C++ code.
- Update the tensor builder in TPP and align it with the Memref builder (https://github.com/plaidml/tpp-mlir/blob/2f8937366774f10bd5e3a423ba21f6fa62cebbcb/lib/TPP/Dialect/Tpp/TppOps.cpp#L146)
- Update matchers in TppUtils (i.e., `isIdentityOp`, `isTppUnary`, and `isTppBinary`)
- Update the pass `ConvertLinalgToTpp.cpp`
- Update tests:
```
TPP_OPT :: BF16/brgemm-vnni.mlir
TPP_OPT :: BF16/matmul-vnni.mlir
TPP_OPT :: Conversion/LinalgToTpp/linalg-to-tpp-tensor.mlir
TPP_OPT :: Dialect/Tpp/tpp-ops-tensor.mlir
TPP_OPT :: Integration/matmul_64x64x64.mlir
TPP_OPT :: Integration/mlp-fp32-1layer-512.mlir
TPP_OPT :: Integration/pack-unpack-conversion.mlir
TPP_OPT :: Integration/packed-convolution.mlir
TPP_OPT :: Integration/packed-matmul.mlir
TPP_OPT :: Models/mobilenet-without-batchnorm.mlir
TPP_OPT :: Models/multi-head-attention.mlir
TPP_OPT :: Models/resnet50-bottleneck-block.mlir
TPP_OPT :: Models/resnet50v1.mlir
TPP_OPT :: Passes/DefaultPipeline/linalg-tensor.mlir
TPP_OPT :: Passes/DefaultPipeline/tpp-conversion.mlir
TPP_OPT :: Passes/DefaultPipeline/tpp-mapping.mlir
TPP_OPT :: Passes/DefaultPipeline/tpp-to-loops.mlir
TPP_OPT :: Passes/DefaultPipeline/vnni.mlir
TPP_OPT :: Passes/pass-combine-tpp.mlir
TPP_OPT :: Passes/pass-map-to-brgemm.mlir
```
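Regarding the first item above (parser and printer), one way to remove the hand-written C++ is MLIR's declarative assembly format in ODS. A sketch, with hypothetical operand names (`$inputs`, `$outputs`); the real op definitions may differ:

```tablegen
// Hypothetical declarative assembly for a DPS tpp op. The `ins`/`outs`
// keywords mirror linalg's syntax; operand names are illustrative only.
let assemblyFormat = [{
  `ins` `(` $inputs `:` type($inputs) `)`
  `outs` `(` $outputs `:` type($outputs) `)`
  attr-dict
}];
```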