|
1 | 1 | // RUN: enzymexlamlir-opt --enzyme-hlo-generate-td="patterns=transpose_symmetric_simplify" --transform-interpreter --enzyme-hlo-remove-transform %s | FileCheck %s |
2 | 2 |
|
3 | | -func.func @pass1(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> { |
| 3 | +func.func @pass1(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>, %arg2: tensor<2x2xf32>) -> tensor<2x2xf32> { |
4 | 4 | %alpha = stablehlo.constant dense<2.0> : tensor<f32> |
5 | 5 | %beta = stablehlo.constant dense<3.0> : tensor<f32> |
6 | | - %c = stablehlo.constant dense<[[4.0, 3.0], [3.0, 4.0]]> : tensor<2x2xf32> |
7 | | - %0 = enzymexla.lapack.symm %c, %arg0, %arg1, %alpha, %beta {side = #enzymexla.side<left>, uplo = #enzymexla.uplo<U>} : (tensor<2x2xf32>, tensor<2x2xf32>, tensor<2x2xf32>, tensor<f32>, tensor<f32>) -> tensor<2x2xf32> |
8 | | - %1 = stablehlo.subtract %0, %c : tensor<2x2xf32> |
9 | | - %2 = stablehlo.dot_general %1, %0, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> |
10 | | - %3 = stablehlo.transpose %2, dims = [1, 0] : (tensor<2x2xf32>) -> tensor<2x2xf32> |
11 | | - return %3 : tensor<2x2xf32> |
| 6 | + %0 = enzymexla.lapack.symm %arg0, %arg1, %arg2, %alpha, %beta {side = #enzymexla.side<left>, uplo = #enzymexla.uplo<U>} : (tensor<2x2xf32>, tensor<2x2xf32>, tensor<2x2xf32>, tensor<f32>, tensor<f32>) -> tensor<2x2xf32> |
| 7 | + %1 = enzymexla.lapack.symm %arg2, %arg1, %arg0, %alpha, %beta {side = #enzymexla.side<left>, uplo = #enzymexla.uplo<U>} : (tensor<2x2xf32>, tensor<2x2xf32>, tensor<2x2xf32>, tensor<f32>, tensor<f32>) -> tensor<2x2xf32> |
| 8 | + %2 = stablehlo.subtract %1, %0 : tensor<2x2xf32> |
| 9 | + %3 = stablehlo.dot_general %2, %1, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> |
| 10 | + %4 = stablehlo.transpose %3, dims = [1, 0] : (tensor<2x2xf32>) -> tensor<2x2xf32> |
| 11 | + return %4 : tensor<2x2xf32> |
12 | 12 | } |
13 | 13 |
|
14 | | -// CHECK: func.func @pass1(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> { |
| 14 | +// CHECK: func.func @pass1(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>, %arg2: tensor<2x2xf32>) -> tensor<2x2xf32> { |
15 | 15 | // CHECK-NEXT: %cst = stablehlo.constant dense<2.000000e+00> : tensor<f32> |
16 | 16 | // CHECK-NEXT: %cst_0 = stablehlo.constant dense<3.000000e+00> : tensor<f32> |
17 | | -// CHECK-NEXT: %cst_1 = stablehlo.constant {enzymexla.guaranteed_symmetric = true} dense<{{\[\[}}4.000000e+00, 3.000000e+00], [3.000000e+00, 4.000000e+00{{\]\]}}> : tensor<2x2xf32> |
18 | | -// CHECK-NEXT: %0 = enzymexla.lapack.symm %cst_1, %arg0, %arg1, %cst, %cst_0 {enzymexla.guaranteed_symmetric = true, side = #enzymexla.side<left>, uplo = #enzymexla.uplo<U>} : (tensor<2x2xf32>, tensor<2x2xf32>, tensor<2x2xf32>, tensor<f32>, tensor<f32>) -> tensor<2x2xf32> |
19 | | -// CHECK-NEXT: %1 = stablehlo.subtract %0, %cst_1 {enzymexla.guaranteed_symmetric = true} : tensor<2x2xf32> |
20 | | -// CHECK-NEXT: %2 = stablehlo.dot_general %1, %0, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] {enzymexla.guaranteed_symmetric = true} : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> |
21 | | -// CHECK-NEXT: return %2 : tensor<2x2xf32> |
| 17 | +// CHECK-NEXT: %0 = enzymexla.lapack.symm %arg0, %arg1, %arg2, %cst, %cst_0 {enzymexla.guaranteed_symmetric = true, side = #enzymexla.side<left>, uplo = #enzymexla.uplo<U>} : (tensor<2x2xf32>, tensor<2x2xf32>, tensor<2x2xf32>, tensor<f32>, tensor<f32>) -> tensor<2x2xf32> |
| 18 | +// CHECK-NEXT: %1 = enzymexla.lapack.symm %arg2, %arg1, %arg0, %cst, %cst_0 {enzymexla.guaranteed_symmetric = true, side = #enzymexla.side<left>, uplo = #enzymexla.uplo<U>} : (tensor<2x2xf32>, tensor<2x2xf32>, tensor<2x2xf32>, tensor<f32>, tensor<f32>) -> tensor<2x2xf32> |
| 19 | +// CHECK-NEXT: %2 = stablehlo.subtract %1, %0 : tensor<2x2xf32> |
| 20 | +// CHECK-NEXT: %3 = stablehlo.dot_general %2, %1, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] {enzymexla.guaranteed_symmetric = true} : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> |
| 21 | +// CHECK-NEXT: return %3 : tensor<2x2xf32> |
22 | 22 | // CHECK-NEXT: } |
0 commit comments