Skip to content

Commit

Permalink
Add boolean option to do RELU in mlp plugin (iree-org#17058)
Browse files Browse the repository at this point in the history
The linalg pdl script does not match the RELU so sets this option to
false. All other cases we set it to true.
  • Loading branch information
nirvedhmeshram authored Apr 16, 2024
1 parent 872f0b6 commit ace6397
Show file tree
Hide file tree
Showing 13 changed files with 73 additions and 37 deletions.
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
// RUN: iree-opt --pass-pipeline="builtin.module(iree-preprocessing-apply-pdl-patterns{patterns-file=%p/tosa.pdl.mlir})" %s | FileCheck %s

// CHECK-LABEL: stream.executable private @mlp_external_f32_f32_f32_i32_i32_i32_executable
// CHECK-LABEL: stream.executable private @mlp_external_f32_f32_f32_i32_i32_i32_i1_executable
// CHECK: stream.executable.export public @mlp_external_entry_point
// CHECK: builtin.module
// CHECK: func.func private @mlp_external
// CHECK-SAME: (memref<f32>, index, memref<f32>, index, memref<f32>, index, i32, i32, i32)
// CHECK-SAME: (memref<f32>, index, memref<f32>, index, memref<f32>, index, i32, i32, i32, i1)
// CHECK-SAME: attributes {llvm.bareptr = [true]}
// CHECK: func.func @mlp_external_entry_point
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: !stream.binding
Expand All @@ -13,6 +13,7 @@
// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: i32
// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: i32
// CHECK-SAME: %[[ARG5:[a-zA-Z0-9]+]]: i32
// CHECK-SAME: %[[ARG6:[a-zA-Z0-9]+]]: i1
// CHECK: %[[C0:.+]] = arith.constant 0 : index
// CHECK: %[[STREAM0:.+]] = stream.binding.subspan %[[ARG0]][%[[C0]]] : !stream.binding -> memref<1x2x4xf32, strided<[8, 4, 1], offset: ?>>
// CHECK-NEXT: %[[STREAM0_BASE:[a-zA-Z0-9_]+]], %[[OFFSET0:[a-zA-Z0-9_]+]],
Expand All @@ -24,18 +25,19 @@
// CHECK-NEXT: %[[STREAM2_BASE:[a-zA-Z0-9_]+]], %[[OFFSET2:[a-zA-Z0-9_]+]],
// CHECK-SAME: = iree_codegen.extract_strided_metadata %[[STREAM2]]
// CHECK: call @mlp_external
// CHECK-SAME: %[[STREAM0_BASE]], %[[OFFSET0]], %[[STREAM1_BASE]], %[[OFFSET1]], %[[STREAM2_BASE]], %[[OFFSET2]], %[[ARG3]], %[[ARG4]], %[[ARG5]]
// CHECK-SAME: %[[STREAM0_BASE]], %[[OFFSET0]], %[[STREAM1_BASE]], %[[OFFSET1]], %[[STREAM2_BASE]], %[[OFFSET2]], %[[ARG3]], %[[ARG4]], %[[ARG5]], %[[ARG6]]

// CHECK: func.func @mlp_invocation
// CHECK-SAME: (%[[ARG0:.+]]: tensor<2x4xf32>, %[[ARG1:.+]]: tensor<4x8xf32>)
// CHECK-DAG: %[[C4:.+]] = arith.constant 4 : i32
// CHECK-DAG: %[[C8:.+]] = arith.constant 8 : i32
// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : i32
// CHECK-DAG: %[[DORELU:.+]] = arith.constant true
// CHECK-DAG: %[[LHS:.+]] = tosa.reshape %[[ARG0]]
// CHECK-DAG: %[[RHS:.+]] = tosa.reshape %[[ARG1]]
// CHECK: %[[RESULT:.+]] = flow.dispatch
// CHECK-SAME: @mlp_external_f32_f32_f32_i32_i32_i32_executable::@mlp_external_entry_point
// CHECK-SAME: (%[[LHS]], %[[RHS]], %[[C2]], %[[C8]], %[[C4]])
// CHECK-SAME: @mlp_external_f32_f32_f32_i32_i32_i32_i1_executable::@mlp_external_entry_point
// CHECK-SAME: (%[[LHS]], %[[RHS]], %[[C2]], %[[C8]], %[[C4]], %[[DORELU]])
// CHECK: tosa.negate %[[RESULT]]

func.func @mlp_invocation(%lhs: tensor<2x4xf32>, %rhs : tensor<4x8xf32>) -> tensor<2x8xf32> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ pdl.pattern @mlp : benefit(1) {
%one_val = pdl.attribute = 1 : index
%two_val = pdl.attribute = 2 : index
%index_type = pdl.type : index
%bool_type = pdl.type : i1
%one_op = pdl.operation "arith.constant" {"value" = %one_val} -> (%index_type : !pdl.type)
%one = pdl.result 0 of %one_op
%two_op = pdl.operation "arith.constant" {"value" = %two_val} -> (%index_type : !pdl.type)
Expand All @@ -92,11 +93,15 @@ pdl.pattern @mlp : benefit(1) {
%k_i32_op = pdl.operation "arith.index_cast"(%k : !pdl.value) -> (%i32_type : !pdl.type)
%k_i32 = pdl.result 0 of %k_i32_op

%true_val = pdl.attribute = 1 : i1
%do_relu_op = pdl.operation "arith.constant" {"value" = %true_val} -> (%bool_type : !pdl.type)
%do_relu = pdl.result 0 of %do_relu_op

%replaced_values_dims = pdl.range : !pdl.range<value>
%input_values = pdl.range %lhs, %rhs : !pdl.value, !pdl.value
%replaced_value = pdl.result 0 of %relu
%replaced_values = pdl.range %replaced_value : !pdl.value
%other_operands = pdl.range %m_i32, %n_i32, %k_i32 : !pdl.value, !pdl.value, !pdl.value
%other_operands = pdl.range %m_i32, %n_i32, %k_i32, %do_relu : !pdl.value, !pdl.value, !pdl.value, !pdl.value

// The `rewriteAsFlowDispatch` is a rewrite function that allows
// converting the matched dag into a call to the external function call
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
// RUN: iree-opt --pass-pipeline="builtin.module(iree-preprocessing-apply-pdl-patterns{patterns-file=%p/torch.pdl.mlir}, cse)" %s | FileCheck %s

// CHECK-LABEL: stream.executable private @mlp_external_f32_f32_f32_i32_i32_i32_executable
// CHECK-LABEL: stream.executable private @mlp_external_f32_f32_f32_i32_i32_i32_i1_executable
// CHECK: stream.executable.export public @mlp_external_entry_point
// CHECK: builtin.module
// CHECK: func.func private @mlp_external
// CHECK-SAME: (memref<f32>, index, memref<f32>, index, memref<f32>, index, i32, i32, i32)
// CHECK-SAME: (memref<f32>, index, memref<f32>, index, memref<f32>, index, i32, i32, i32, i1)
// CHECK-SAME: attributes {llvm.bareptr = [true]}
// CHECK: func.func @mlp_external_entry_point
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: !stream.binding
Expand All @@ -13,29 +13,31 @@
// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: i32
// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: i32
// CHECK-SAME: %[[ARG5:[a-zA-Z0-9]+]]: i32
// CHECK-SAME: %[[ARG6:[a-zA-Z0-9]+]]: index
// CHECK-SAME: %[[ARG6:[a-zA-Z0-9]+]]: i1
// CHECK-SAME: %[[ARG7:[a-zA-Z0-9]+]]: index
// CHECK-SAME: %[[ARG8:[a-zA-Z0-9]+]]: index
// CHECK-SAME: %[[ARG9:[a-zA-Z0-9]+]]: index
// CHECK-SAME: %[[ARG10:[a-zA-Z0-9]+]]: index
// CHECK-SAME: %[[ARG11:[a-zA-Z0-9]+]]: index
// CHECK-SAME: %[[ARG12:[a-zA-Z0-9]+]]: index
// CHECK: %[[C0:.+]] = arith.constant 0 : index
// CHECK: %[[STREAM0:.+]] = stream.binding.subspan %[[ARG0]][%[[C0]]] : !stream.binding -> memref<?x?xf32, strided<[?, 1], offset: ?>>{%[[ARG6]], %[[ARG7]]}
// CHECK: %[[STREAM0:.+]] = stream.binding.subspan %[[ARG0]][%[[C0]]] : !stream.binding -> memref<?x?xf32, strided<[?, 1], offset: ?>>{%[[ARG7]], %[[ARG8]]}
// CHECK-NEXT: %[[STREAM0_BASE:[a-zA-Z0-9_]+]], %[[OFFSET0:[a-zA-Z0-9_]+]],
// CHECK-SAME: = iree_codegen.extract_strided_metadata %[[STREAM0]]
// CHECK: %[[STREAM1:.+]] = stream.binding.subspan %[[ARG1]][%[[C0]]] : !stream.binding -> memref<?x?xf32, strided<[?, 1], offset: ?>>{%[[ARG8]], %[[ARG9]]}
// CHECK: %[[STREAM1:.+]] = stream.binding.subspan %[[ARG1]][%[[C0]]] : !stream.binding -> memref<?x?xf32, strided<[?, 1], offset: ?>>{%[[ARG9]], %[[ARG10]]}
// CHECK-NEXT: %[[STREAM1_BASE:[a-zA-Z0-9_]+]], %[[OFFSET1:[a-zA-Z0-9_]+]],
// CHECK-SAME: = iree_codegen.extract_strided_metadata %[[STREAM1]]
// CHECK: %[[STREAM2:.+]] = stream.binding.subspan %[[ARG2]][%[[C0]]] : !stream.binding -> memref<?x?xf32, strided<[?, 1], offset: ?>>{%[[ARG10]], %[[ARG11]]}
// CHECK: %[[STREAM2:.+]] = stream.binding.subspan %[[ARG2]][%[[C0]]] : !stream.binding -> memref<?x?xf32, strided<[?, 1], offset: ?>>{%[[ARG11]], %[[ARG12]]}
// CHECK-NEXT: %[[STREAM2_BASE:[a-zA-Z0-9_]+]], %[[OFFSET2:[a-zA-Z0-9_]+]],
// CHECK-SAME: = iree_codegen.extract_strided_metadata %[[STREAM2]]
// CHECK: call @mlp_external
// CHECK-SAME: %[[STREAM0_BASE]], %[[OFFSET0]], %[[STREAM1_BASE]], %[[OFFSET1]], %[[STREAM2_BASE]], %[[OFFSET2]], %[[ARG3]], %[[ARG4]], %[[ARG5]]
// CHECK-SAME: %[[STREAM0_BASE]], %[[OFFSET0]], %[[STREAM1_BASE]], %[[OFFSET1]], %[[STREAM2_BASE]], %[[OFFSET2]], %[[ARG3]], %[[ARG4]], %[[ARG5]], %[[ARG6]]

// CHECK: func.func @mlp_invocation
// CHECK-SAME: (%[[LHS:[a-zA-Z0-9]+]]: tensor<?x?xf32>, %[[RHS:[a-zA-Z0-9]+]]: tensor<?x?xf32>)
// CHECK-DAG: %[[C0:.+]] = arith.constant 0
// CHECK-DAG: %[[C1:.+]] = arith.constant 1
// CHECK-DAG: %[[DORELU:.+]] = arith.constant true
// CHECK: %[[M:.+]] = tensor.dim %[[LHS]], %[[C0]]
// CHECK: %[[N:.+]] = tensor.dim %[[RHS]], %[[C1]]
// CHECK: %[[K:.+]] = tensor.dim %[[LHS]], %[[C1]]
Expand All @@ -44,8 +46,8 @@
// CHECK: %[[K_I32:.+]] = arith.index_cast %[[K]] : index to i32
// CHECK: %[[K_0:.+]] = tensor.dim %[[RHS]], %[[C0]]
// CHECK: %[[RESULT:.+]] = flow.dispatch
// CHECK-SAME: @mlp_external_f32_f32_f32_i32_i32_i32_executable::@mlp_external_entry_point
// CHECK-SAME: (%[[LHS]], %[[RHS]], %[[M_I32]], %[[N_I32]], %[[K_I32]], %[[M]], %[[K]], %[[K_0]], %[[N]], %[[M]], %[[N]])
// CHECK-SAME: @mlp_external_f32_f32_f32_i32_i32_i32_i1_executable::@mlp_external_entry_point
// CHECK-SAME: (%[[LHS]], %[[RHS]], %[[M_I32]], %[[N_I32]], %[[K_I32]], %[[DORELU]], %[[M]], %[[K]], %[[K_0]], %[[N]], %[[M]], %[[N]])
// CHECK: linalg.generic
// CHECK-SAME: ins(%[[RESULT]] :

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ pdl.pattern @mlp : benefit(1) {
%zero_val = pdl.attribute = 0 : index
%one_val = pdl.attribute = 1 : index
%index_type = pdl.type : index
%bool_type = pdl.type : i1
%zero_op = pdl.operation "arith.constant" {"value" = %zero_val} -> (%index_type : !pdl.type)
%zero = pdl.result 0 of %zero_op
%one_op = pdl.operation "arith.constant" {"value" = %one_val} -> (%index_type : !pdl.type)
Expand All @@ -85,11 +86,15 @@ pdl.pattern @mlp : benefit(1) {
%k_i32_op = pdl.operation "arith.index_cast"(%k : !pdl.value) -> (%i32_type : !pdl.type)
%k_i32 = pdl.result 0 of %k_i32_op

%true_val = pdl.attribute = 1 : i1
%do_relu_op = pdl.operation "arith.constant" {"value" = %true_val} -> (%bool_type : !pdl.type)
%do_relu = pdl.result 0 of %do_relu_op

%replaced_values_dims = pdl.range %m, %n : !pdl.value, !pdl.value
%input_values = pdl.range %lhs, %rhs : !pdl.value, !pdl.value
%replaced_value = pdl.result 0 of %cast
%replaced_values = pdl.range %replaced_value : !pdl.value
%other_operands = pdl.range %m_i32, %n_i32, %k_i32 : !pdl.value, !pdl.value, !pdl.value
%other_operands = pdl.range %m_i32, %n_i32, %k_i32, %do_relu : !pdl.value, !pdl.value, !pdl.value, !pdl.value

// The `rewriteAsFlowDispatch` is a rewrite function that allows
// converting the matched dag into a call to the external function call
Expand Down
2 changes: 1 addition & 1 deletion samples/custom_dispatch/cpu/mlp_plugin/mlp.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
module @example attributes {hal.device.targets = [#cpu_target]} {

// CHECK-LABEL: EXEC @mlp_invocation
// CHECK: [Plugin]: M = 2, N = 2, K = 2
// CHECK: [Plugin]: M = 2, N = 2, K = 2, doRelu = 1
// CHECK: 2x2xf32=[-12 -0][-0 -12]
func.func @mlp_invocation(%lhs: tensor<?x?xf32>,
%rhs: tensor<?x?xf32>) -> (tensor<?x?xf32>) {
Expand Down
22 changes: 12 additions & 10 deletions samples/custom_dispatch/cpu/mlp_plugin/mlp_linalg.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@
// RUN: --input="4x8xf32=[[3.0, -3.0, 3.0, -3.0], [3.0, -3.0, 3.0, -3.0], [3.0, -3.0, 3.0, -3.0], [3.0, -3.0, 3.0, -3.0], [3.0, -3.0, 3.0, -3.0], [3.0, -3.0, 3.0, -3.0], [3.0, -3.0, 3.0, -3.0], [3.0, -3.0, 3.0, -3.0]]" | \
// RUN: FileCheck %s --check-prefix=OUTPUT

// CHECK-LABEL: stream.executable private @mlp_external_f32_f32_f32_i32_i32_i32_executable
// CHECK-LABEL: stream.executable private @mlp_external_f32_f32_f32_i32_i32_i32_i1_executable
// CHECK: stream.executable.export public @mlp_external_entry_point
// CHECK: builtin.module
// CHECK: func.func private @mlp_external
// CHECK-SAME: (memref<f32>, index, memref<f32>, index, memref<f32>, index, i32, i32, i32)
// CHECK-SAME: (memref<f32>, index, memref<f32>, index, memref<f32>, index, i32, i32, i32, i1)
// CHECK-SAME: attributes {llvm.bareptr = [true]}
// CHECK: func.func @mlp_external_entry_point
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: !stream.binding
Expand All @@ -22,27 +22,29 @@
// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: i32
// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: i32
// CHECK-SAME: %[[ARG5:[a-zA-Z0-9]+]]: i32
// CHECK-SAME: %[[ARG6:[a-zA-Z0-9]+]]: index
// CHECK-SAME: %[[ARG6:[a-zA-Z0-9]+]]: i1
// CHECK-SAME: %[[ARG7:[a-zA-Z0-9]+]]: index
// CHECK-SAME: %[[ARG8:[a-zA-Z0-9]+]]: index
// CHECK-SAME: %[[ARG9:[a-zA-Z0-9]+]]: index
// CHECK-SAME: %[[ARG10:[a-zA-Z0-9]+]]: index
// CHECK-SAME: %[[ARG11:[a-zA-Z0-9]+]]: index
// CHECK-SAME: %[[ARG12:[a-zA-Z0-9]+]]: index
// CHECK: %[[C0:.+]] = arith.constant 0 : index
// CHECK: %[[STREAM0:.+]] = stream.binding.subspan %[[ARG0]][%[[C0]]] : !stream.binding -> memref<?x?xf32, strided<[?, 1], offset: ?>>{%[[ARG6]], %[[ARG7]]}
// CHECK: %[[STREAM0:.+]] = stream.binding.subspan %[[ARG0]][%[[C0]]] : !stream.binding -> memref<?x?xf32, strided<[?, 1], offset: ?>>{%[[ARG7]], %[[ARG8]]}
// CHECK-NEXT: %[[STREAM0_BASE:[a-zA-Z0-9_]+]], %[[OFFSET0:[a-zA-Z0-9_]+]],
// CHECK-SAME: = iree_codegen.extract_strided_metadata %[[STREAM0]]
// CHECK: %[[STREAM1:.+]] = stream.binding.subspan %[[ARG1]][%[[C0]]] : !stream.binding -> memref<?x?xf32, strided<[?, 1], offset: ?>>{%[[ARG8]], %[[ARG9]]}
// CHECK: %[[STREAM1:.+]] = stream.binding.subspan %[[ARG1]][%[[C0]]] : !stream.binding -> memref<?x?xf32, strided<[?, 1], offset: ?>>{%[[ARG9]], %[[ARG10]]}
// CHECK-NEXT: %[[STREAM1_BASE:[a-zA-Z0-9_]+]], %[[OFFSET1:[a-zA-Z0-9_]+]],
// CHECK-SAME: = iree_codegen.extract_strided_metadata %[[STREAM1]]
// CHECK: %[[STREAM2:.+]] = stream.binding.subspan %[[ARG2]][%[[C0]]] : !stream.binding -> memref<?x?xf32, strided<[?, 1], offset: ?>>{%[[ARG10]], %[[ARG11]]}
// CHECK: %[[STREAM2:.+]] = stream.binding.subspan %[[ARG2]][%[[C0]]] : !stream.binding -> memref<?x?xf32, strided<[?, 1], offset: ?>>{%[[ARG11]], %[[ARG12]]}
// CHECK-NEXT: %[[STREAM2_BASE:[a-zA-Z0-9_]+]], %[[OFFSET2:[a-zA-Z0-9_]+]],
// CHECK-SAME: = iree_codegen.extract_strided_metadata %[[STREAM2]]
// CHECK: call @mlp_external
// CHECK-SAME: %[[STREAM0_BASE]], %[[OFFSET0]], %[[STREAM1_BASE]], %[[OFFSET1]], %[[STREAM2_BASE]], %[[OFFSET2]], %[[ARG3]], %[[ARG4]], %[[ARG5]]
// CHECK-SAME: %[[STREAM0_BASE]], %[[OFFSET0]], %[[STREAM1_BASE]], %[[OFFSET1]], %[[STREAM2_BASE]], %[[OFFSET2]], %[[ARG3]], %[[ARG4]], %[[ARG5]], %[[ARG6]]

// CHECK: util.func public @mlp_invocation
// CHECK-SAME: (%[[ARG0:[a-zA-Z0-9]+]]: !hal.buffer_view, %[[ARG1:[a-zA-Z0-9]+]]: !hal.buffer_view)
// CHECK-DAG: %[[DORELU:.+]] = arith.constant false
// CHECK-DAG: %[[MDIM0:.+]] = hal.buffer_view.dim<%[[ARG0]] : !hal.buffer_view>[0] : index
// CHECK-DAG: %[[MDIM1:.+]] = hal.buffer_view.dim<%[[ARG0]] : !hal.buffer_view>[1] : index
// CHECK-DAG: %[[LHS:.+]] = hal.tensor.import %[[ARG0]] "input0" : !hal.buffer_view -> tensor<?x?xf32>{%[[MDIM0]], %[[MDIM1]]}
Expand All @@ -53,8 +55,8 @@
// CHECK-DAG: %[[N_I32:.+]] = arith.index_cast %[[NDIM1]] : index to i32
// CHECK-DAG: %[[K_I32:.+]] = arith.index_cast %[[MDIM1]] : index to i32
// CHECK: %[[RESULT:.+]] = flow.dispatch
// CHECK-SAME: @mlp_external_f32_f32_f32_i32_i32_i32_executable::@mlp_external_entry_point
// CHECK-SAME: (%[[LHS]], %[[RHS]], %[[M_I32]], %[[N_I32]], %[[K_I32]], %[[MDIM0]], %[[MDIM1]], %[[NDIM0]], %[[NDIM1]], %[[MDIM0]], %[[NDIM1]])
// CHECK-SAME: @mlp_external_f32_f32_f32_i32_i32_i32_i1_executable::@mlp_external_entry_point
// CHECK-SAME: (%[[LHS]], %[[RHS]], %[[M_I32]], %[[N_I32]], %[[K_I32]], %[[DORELU]], %[[MDIM0]], %[[MDIM1]], %[[NDIM0]], %[[NDIM1]], %[[MDIM0]], %[[NDIM1]])
// CHECK: linalg.generic
// CHECK-SAME: ins(%[[RESULT]] :

Expand Down Expand Up @@ -106,5 +108,5 @@ module @example attributes {hal.device.targets = [#cpu_target]} {
} // module

// OUTPUT-LABEL: EXEC @mlp_invocation
// OUTPUT: [Plugin]: M = 2, N = 8, K = 4
// OUTPUT: [Plugin]: M = 2, N = 8, K = 4, doRelu = 0
// OUTPUT: 2x8xf32=[-24 -0 -24 -0 -24 -0 -24 -0][-0 -24 -0 -24 -0 -24 -0 -24]
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
// int32_t M;
// int32_t N;
// int32_t K;
// bool doRelu;
// float *restrict result;
// size_t result_offset;
// };
Expand Down Expand Up @@ -69,6 +70,7 @@ pdl.pattern @mlp : benefit(1) {
// external function call. The values of `%M`, `%N` and `%K` need to
// be generated.
%i32_type = pdl.type : i32
%bool_type = pdl.type : i1
%zero_val = pdl.attribute = 0 : index
%one_val = pdl.attribute = 1 : index
%zero_op = pdl.operation "arith.constant" {"value" = %zero_val} -> (%index_type : !pdl.type)
Expand All @@ -88,11 +90,15 @@ pdl.pattern @mlp : benefit(1) {
%k_i32_op = pdl.operation "arith.index_cast"(%k : !pdl.value) -> (%i32_type : !pdl.type)
%k_i32 = pdl.result 0 of %k_i32_op

%false_val = pdl.attribute = 0 : i1
%do_relu_op = pdl.operation "arith.constant" {"value" = %false_val} -> (%bool_type : !pdl.type)
%do_relu = pdl.result 0 of %do_relu_op

%replaced_values_dims = pdl.range %m, %n : !pdl.value, !pdl.value
%input_values = pdl.range %lhs, %rhs : !pdl.value, !pdl.value
%replaced_value = pdl.result 0 of %matmul
%replaced_values = pdl.range %replaced_value : !pdl.value
%other_operands = pdl.range %m_i32, %n_i32, %k_i32 : !pdl.value, !pdl.value, !pdl.value
%other_operands = pdl.range %m_i32, %n_i32, %k_i32, %do_relu : !pdl.value, !pdl.value, !pdl.value, !pdl.value

// The `rewriteAsFlowDispatch` is a rewrite function that allows
// converting the matched dag into a call to the external function call
Expand Down
9 changes: 6 additions & 3 deletions samples/custom_dispatch/cpu/mlp_plugin/mlp_plugin.c
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,11 @@ static int mlp_external(void* params_ptr, void* context, void* reserved) {
int32_t M;
int32_t N;
int32_t K;
bool doRelu;
} params_t;
const params_t* params = (const params_t*)params_ptr;
fprintf(plugin->file, "[Plugin]: M = %d, N = %d, K = %d\n", params->M,
params->N, params->K);
fprintf(plugin->file, "[Plugin]: M = %d, N = %d, K = %d, doRelu = %d\n",
params->M, params->N, params->K, params->doRelu);
for (int32_t i = 0; i < params->M; i++) {
for (int32_t j = 0; j < params->N; j++) {
float curr_result = 0.0;
Expand All @@ -77,7 +78,9 @@ static int mlp_external(void* params_ptr, void* context, void* reserved) {
get_index(k, j, params->rhs_offset, (size_t)params->N);
curr_result += params->lhs[lhs_index] * params->rhs[rhs_index];
}
curr_result = curr_result < 0.0 ? 0.0 : curr_result;
if (params->doRelu) {
curr_result = curr_result < 0.0 ? 0.0 : curr_result;
}
size_t result_index = get_index(i, j, params->result_offset, params->N);
params->result[result_index] = curr_result;
}
Expand Down
Loading

0 comments on commit ace6397

Please sign in to comment.