@@ -4,7 +4,7 @@ func.func @pass1(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>, %arg2: tensor<2
44 %alpha = stablehlo.constant dense <2.0 > : tensor <f32 >
55 %beta = stablehlo.constant dense <3.0 > : tensor <f32 >
66 %0 = enzymexla.lapack.symm %arg0 , %arg1 , %arg2 , %alpha , %beta {side = #enzymexla.side <left >, uplo = #enzymexla.uplo <U >} : (tensor <2 x2 xf32 >, tensor <2 x2 xf32 >, tensor <2 x2 xf32 >, tensor <f32 >, tensor <f32 >) -> tensor <2 x2 xf32 >
7- %1 = enzymexla.lapack.symm %arg2 , %arg1 , % arg0, %alpha , %beta { side = # enzymexla.side < left >, uplo = #enzymexla.uplo < U > } : (tensor <2 x2 xf32 >, tensor < 2 x 2 x f32 >, tensor < 2 x 2 x f32 >, tensor < f32 >, tensor < f32 >) -> tensor <2 x2 xf32 >
7+ %1 = stablehlo.reshape % arg0 { enzymexla.guaranteed_symmetric = true } : (tensor <2 x2 xf32 >) -> tensor <2 x2 xf32 >
88 %2 = stablehlo.subtract %1 , %0 : tensor <2 x2 xf32 >
99 %3 = stablehlo.dot_general %2 , %1 , contracting_dims = [1 ] x [0 ], precision = [DEFAULT , DEFAULT ] : (tensor <2 x2 xf32 >, tensor <2 x2 xf32 >) -> tensor <2 x2 xf32 >
1010 %4 = stablehlo.transpose %3 , dims = [1 , 0 ] : (tensor <2 x2 xf32 >) -> tensor <2 x2 xf32 >
@@ -15,7 +15,7 @@ func.func @pass1(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>, %arg2: tensor<2
1515// CHECK-NEXT: %cst = stablehlo.constant dense<2.000000e+00> : tensor<f32>
1616// CHECK-NEXT: %cst_0 = stablehlo.constant dense<3.000000e+00> : tensor<f32>
1717// CHECK-NEXT: %0 = enzymexla.lapack.symm %arg0, %arg1, %arg2, %cst, %cst_0 {enzymexla.guaranteed_symmetric = true, side = #enzymexla.side<left>, uplo = #enzymexla.uplo<U>} : (tensor<2x2xf32>, tensor<2x2xf32>, tensor<2x2xf32>, tensor<f32>, tensor<f32>) -> tensor<2x2xf32>
18- // CHECK-NEXT: %1 = enzymexla.lapack.symm %arg2, %arg1, % arg0, %cst, %cst_0 {enzymexla.guaranteed_symmetric = true, side = #enzymexla.side<left>, uplo = #enzymexla.uplo<U> } : (tensor<2x2xf32>, tensor<2x2xf32>, tensor<2x2xf32>, tensor<f32>, tensor<f32 >) -> tensor<2x2xf32>
18+ // CHECK-NEXT: %1 = stablehlo.reshape % arg0 {enzymexla.guaranteed_symmetric = true} : (tensor<2x2xf32>) -> tensor<2x2xf32>
1919// CHECK-NEXT: %2 = stablehlo.subtract %1, %0 : tensor<2x2xf32>
2020// CHECK-NEXT: %3 = stablehlo.dot_general %2, %1, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] {enzymexla.guaranteed_symmetric = true} : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
2121// CHECK-NEXT: return %3 : tensor<2x2xf32>
0 commit comments