Skip to content

Commit aa7d10f

Browse files
committed
Get rid of constants from symmetry test
1 parent f3166ea commit aa7d10f

File tree

1 file changed

+13
-13
lines changed

1 file changed

+13
-13
lines changed
Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,22 @@
11
// RUN: enzymexlamlir-opt --enzyme-hlo-generate-td="patterns=transpose_symmetric_simplify" --transform-interpreter --enzyme-hlo-remove-transform %s | FileCheck %s
22

3-
func.func @pass1(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> {
3+
func.func @pass1(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>, %arg2: tensor<2x2xf32>) -> tensor<2x2xf32> {
44
%alpha = stablehlo.constant dense<2.0> : tensor<f32>
55
%beta = stablehlo.constant dense<3.0> : tensor<f32>
6-
%c = stablehlo.constant dense<[[4.0, 3.0], [3.0, 4.0]]> : tensor<2x2xf32>
7-
%0 = enzymexla.lapack.symm %c, %arg0, %arg1, %alpha, %beta {side = #enzymexla.side<left>, uplo = #enzymexla.uplo<U>} : (tensor<2x2xf32>, tensor<2x2xf32>, tensor<2x2xf32>, tensor<f32>, tensor<f32>) -> tensor<2x2xf32>
8-
%1 = stablehlo.subtract %0, %c : tensor<2x2xf32>
9-
%2 = stablehlo.dot_general %1, %0, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
10-
%3 = stablehlo.transpose %2, dims = [1, 0] : (tensor<2x2xf32>) -> tensor<2x2xf32>
11-
return %3 : tensor<2x2xf32>
6+
%0 = enzymexla.lapack.symm %arg0, %arg1, %arg2, %alpha, %beta {side = #enzymexla.side<left>, uplo = #enzymexla.uplo<U>} : (tensor<2x2xf32>, tensor<2x2xf32>, tensor<2x2xf32>, tensor<f32>, tensor<f32>) -> tensor<2x2xf32>
7+
%1 = enzymexla.lapack.symm %arg2, %arg1, %arg0, %alpha, %beta {side = #enzymexla.side<left>, uplo = #enzymexla.uplo<U>} : (tensor<2x2xf32>, tensor<2x2xf32>, tensor<2x2xf32>, tensor<f32>, tensor<f32>) -> tensor<2x2xf32>
8+
%2 = stablehlo.subtract %1, %0 : tensor<2x2xf32>
9+
%3 = stablehlo.dot_general %2, %1, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
10+
%4 = stablehlo.transpose %3, dims = [1, 0] : (tensor<2x2xf32>) -> tensor<2x2xf32>
11+
return %4 : tensor<2x2xf32>
1212
}
1313

14-
// CHECK: func.func @pass1(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> {
14+
// CHECK: func.func @pass1(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>, %arg2: tensor<2x2xf32>) -> tensor<2x2xf32> {
1515
// CHECK-NEXT: %cst = stablehlo.constant dense<2.000000e+00> : tensor<f32>
1616
// CHECK-NEXT: %cst_0 = stablehlo.constant dense<3.000000e+00> : tensor<f32>
17-
// CHECK-NEXT: %cst_1 = stablehlo.constant {enzymexla.guaranteed_symmetric = true} dense<{{\[\[}}4.000000e+00, 3.000000e+00], [3.000000e+00, 4.000000e+00{{\]\]}}> : tensor<2x2xf32>
18-
// CHECK-NEXT: %0 = enzymexla.lapack.symm %cst_1, %arg0, %arg1, %cst, %cst_0 {enzymexla.guaranteed_symmetric = true, side = #enzymexla.side<left>, uplo = #enzymexla.uplo<U>} : (tensor<2x2xf32>, tensor<2x2xf32>, tensor<2x2xf32>, tensor<f32>, tensor<f32>) -> tensor<2x2xf32>
19-
// CHECK-NEXT: %1 = stablehlo.subtract %0, %cst_1 {enzymexla.guaranteed_symmetric = true} : tensor<2x2xf32>
20-
// CHECK-NEXT: %2 = stablehlo.dot_general %1, %0, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] {enzymexla.guaranteed_symmetric = true} : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
21-
// CHECK-NEXT: return %2 : tensor<2x2xf32>
17+
// CHECK-NEXT: %0 = enzymexla.lapack.symm %arg0, %arg1, %arg2, %cst, %cst_0 {enzymexla.guaranteed_symmetric = true, side = #enzymexla.side<left>, uplo = #enzymexla.uplo<U>} : (tensor<2x2xf32>, tensor<2x2xf32>, tensor<2x2xf32>, tensor<f32>, tensor<f32>) -> tensor<2x2xf32>
18+
// CHECK-NEXT: %1 = enzymexla.lapack.symm %arg2, %arg1, %arg0, %cst, %cst_0 {enzymexla.guaranteed_symmetric = true, side = #enzymexla.side<left>, uplo = #enzymexla.uplo<U>} : (tensor<2x2xf32>, tensor<2x2xf32>, tensor<2x2xf32>, tensor<f32>, tensor<f32>) -> tensor<2x2xf32>
19+
// CHECK-NEXT: %2 = stablehlo.subtract %1, %0 : tensor<2x2xf32>
20+
// CHECK-NEXT: %3 = stablehlo.dot_general %2, %1, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] {enzymexla.guaranteed_symmetric = true} : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
21+
// CHECK-NEXT: return %3 : tensor<2x2xf32>
2222
// CHECK-NEXT: }

0 commit comments

Comments
 (0)