-
Notifications
You must be signed in to change notification settings - Fork 13.7k
[CIR] Implement folder for VecCmpOp #143322
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
@llvm/pr-subscribers-clang @llvm/pr-subscribers-clangir Author: Amr Hesham (AmrDeveloper) ChangesThis change adds a folder for the VecCmpOp Issue #136487 Full diff: https://github.com/llvm/llvm-project/pull/143322.diff 4 Files Affected:
diff --git a/clang/include/clang/CIR/Dialect/IR/CIROps.td b/clang/include/clang/CIR/Dialect/IR/CIROps.td
index 038a59b8ff4eb..6f0957b263407 100644
--- a/clang/include/clang/CIR/Dialect/IR/CIROps.td
+++ b/clang/include/clang/CIR/Dialect/IR/CIROps.td
@@ -2154,6 +2154,8 @@ def VecCmpOp : CIR_Op<"vec.cmp", [Pure, SameTypeOperands]> {
`(` $kind `,` $lhs `,` $rhs `)` `:` qualified(type($lhs)) `,`
qualified(type($result)) attr-dict
}];
+
+ let hasFolder = 1;
}
//===----------------------------------------------------------------------===//
diff --git a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
index bfd3a0a62a8e7..bb505d6b386ff 100644
--- a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
+++ b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
@@ -1579,6 +1579,109 @@ OpFoldResult cir::VecExtractOp::fold(FoldAdaptor adaptor) {
return elements[index];
}
+//===----------------------------------------------------------------------===//
+// VecCmpOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult cir::VecCmpOp::fold(FoldAdaptor adaptor) {
+ mlir::Attribute lhs = adaptor.getLhs();
+ mlir::Attribute rhs = adaptor.getRhs();
+ if (!mlir::isa_and_nonnull<cir::ConstVectorAttr>(lhs) ||
+ !mlir::isa_and_nonnull<cir::ConstVectorAttr>(rhs))
+ return {};
+
+ auto lhsVecAttr = mlir::cast<cir::ConstVectorAttr>(lhs);
+ auto rhsVecAttr = mlir::cast<cir::ConstVectorAttr>(rhs);
+
+ auto inputElemTy =
+ mlir::cast<cir::VectorType>(lhsVecAttr.getType()).getElementType();
+ if (!mlir::isa<cir::IntType>(inputElemTy) &&
+ !mlir::isa<cir::CIRFPTypeInterface>(inputElemTy))
+ return {};
+
+ cir::CmpOpKind opKind = adaptor.getKind();
+ mlir::ArrayAttr lhsVecElhs = lhsVecAttr.getElts();
+ mlir::ArrayAttr rhsVecElhs = rhsVecAttr.getElts();
+ uint64_t vecSize = lhsVecElhs.size();
+
+ auto resultVecTy = mlir::cast<cir::VectorType>(getType());
+
+ SmallVector<mlir::Attribute, 16> elements(vecSize);
+ for (uint64_t i = 0; i < vecSize; i++) {
+ mlir::Attribute lhsAttr = lhsVecElhs[i];
+ mlir::Attribute rhsAttr = rhsVecElhs[i];
+
+ int cmpResult = 0;
+ switch (opKind) {
+ case cir::CmpOpKind::lt: {
+ if (mlir::isa<cir::IntAttr>(lhsAttr)) {
+ cmpResult = mlir::cast<cir::IntAttr>(lhsAttr).getSInt() <
+ mlir::cast<cir::IntAttr>(rhsAttr).getSInt();
+ } else {
+ cmpResult = mlir::cast<cir::FPAttr>(lhsAttr).getValue() <
+ mlir::cast<cir::FPAttr>(rhsAttr).getValue();
+ }
+ break;
+ }
+ case cir::CmpOpKind::le: {
+ if (mlir::isa<cir::IntAttr>(lhsAttr)) {
+ cmpResult = mlir::cast<cir::IntAttr>(lhsAttr).getSInt() <=
+ mlir::cast<cir::IntAttr>(rhsAttr).getSInt();
+ } else {
+ cmpResult = mlir::cast<cir::FPAttr>(lhsAttr).getValue() <=
+ mlir::cast<cir::FPAttr>(rhsAttr).getValue();
+ }
+ break;
+ }
+ case cir::CmpOpKind::gt: {
+ if (mlir::isa<cir::IntAttr>(lhsAttr)) {
+ cmpResult = mlir::cast<cir::IntAttr>(lhsAttr).getSInt() >
+ mlir::cast<cir::IntAttr>(rhsAttr).getSInt();
+ } else {
+ cmpResult = mlir::cast<cir::FPAttr>(lhsAttr).getValue() >
+ mlir::cast<cir::FPAttr>(rhsAttr).getValue();
+ }
+ break;
+ }
+ case cir::CmpOpKind::ge: {
+ if (mlir::isa<cir::IntAttr>(lhsAttr)) {
+ cmpResult = mlir::cast<cir::IntAttr>(lhsAttr).getSInt() >=
+ mlir::cast<cir::IntAttr>(rhsAttr).getSInt();
+ } else {
+ cmpResult = mlir::cast<cir::FPAttr>(lhsAttr).getValue() >=
+ mlir::cast<cir::FPAttr>(rhsAttr).getValue();
+ }
+ break;
+ }
+ case cir::CmpOpKind::eq: {
+ if (mlir::isa<cir::IntAttr>(lhsAttr)) {
+ cmpResult = mlir::cast<cir::IntAttr>(lhsAttr).getSInt() ==
+ mlir::cast<cir::IntAttr>(rhsAttr).getSInt();
+ } else {
+ cmpResult = mlir::cast<cir::FPAttr>(lhsAttr).getValue() ==
+ mlir::cast<cir::FPAttr>(rhsAttr).getValue();
+ }
+ break;
+ }
+ case cir::CmpOpKind::ne: {
+ if (mlir::isa<cir::IntAttr>(lhsAttr)) {
+ cmpResult = mlir::cast<cir::IntAttr>(lhsAttr).getSInt() !=
+ mlir::cast<cir::IntAttr>(rhsAttr).getSInt();
+ } else {
+ cmpResult = mlir::cast<cir::FPAttr>(lhsAttr).getValue() !=
+ mlir::cast<cir::FPAttr>(rhsAttr).getValue();
+ }
+ break;
+ }
+ }
+
+ elements[i] = cir::IntAttr::get(resultVecTy.getElementType(), cmpResult);
+ }
+
+ return cir::ConstVectorAttr::get(
+ getType(), mlir::ArrayAttr::get(getContext(), elements));
+}
+
//===----------------------------------------------------------------------===//
// VecShuffle
//===----------------------------------------------------------------------===//
diff --git a/clang/lib/CIR/Dialect/Transforms/CIRCanonicalize.cpp b/clang/lib/CIR/Dialect/Transforms/CIRCanonicalize.cpp
index 33881c69eec5f..65c7bf7158883 100644
--- a/clang/lib/CIR/Dialect/Transforms/CIRCanonicalize.cpp
+++ b/clang/lib/CIR/Dialect/Transforms/CIRCanonicalize.cpp
@@ -142,7 +142,7 @@ void CIRCanonicalizePass::runOnOperation() {
// Many operations are here to perform a manual `fold` in
// applyOpPatternsGreedily.
if (isa<BrOp, BrCondOp, CastOp, ScopeOp, SwitchOp, SelectOp, UnaryOp,
- VecExtractOp, VecShuffleDynamicOp, VecTernaryOp>(op))
+ VecExtractOp, VecShuffleDynamicOp, VecTernaryOp, VecCmpOp>(op))
ops.push_back(op);
});
diff --git a/clang/test/CIR/Transforms/vector-cmp-fold.cir b/clang/test/CIR/Transforms/vector-cmp-fold.cir
new file mode 100644
index 0000000000000..b207fc08748e2
--- /dev/null
+++ b/clang/test/CIR/Transforms/vector-cmp-fold.cir
@@ -0,0 +1,227 @@
+// RUN: cir-opt %s -cir-canonicalize -o - -split-input-file | FileCheck %s
+
+!s32i = !cir.int<s, 32>
+
+module {
+ cir.func @fold_cmp_vector_op_test() -> !cir.vector<4 x !s32i> {
+ %vec_1 = cir.const #cir.const_vector<[#cir.int<1> : !s32i, #cir.int<3> : !s32i, #cir.int<5> : !s32i, #cir.int<7> : !s32i]> : !cir.vector<4 x !s32i>
+ %vec_2 = cir.const #cir.const_vector<[#cir.int<2> : !s32i, #cir.int<4> : !s32i, #cir.int<6> : !s32i, #cir.int<8> : !s32i]> : !cir.vector<4 x !s32i>
+ %new_vec = cir.vec.cmp(eq, %vec_1, %vec_2) : !cir.vector<4 x !s32i>, !cir.vector<4 x !s32i>
+ cir.return %new_vec : !cir.vector<4 x !s32i>
+ }
+
+ // CHECK: cir.func @fold_cmp_vector_op_test() -> !cir.vector<4 x !s32i> {
+ // CHECK-NEXT: %[[RES:.*]] = cir.const #cir.const_vector<[#cir.int<0> : !s32i, #cir.int<0> : !s32i,
+ // CHECK-SAME: #cir.int<0> : !s32i, #cir.int<0> : !s32i]> : !cir.vector<4 x !s32i>
+ // CHECK-NEXT: cir.return %[[RES]] : !cir.vector<4 x !s32i>
+}
+
+// -----
+
+!s32i = !cir.int<s, 32>
+
+module {
+ cir.func @fold_cmp_vector_op_test() -> !cir.vector<4 x !s32i> {
+ %vec_1 = cir.const #cir.const_vector<[#cir.int<1> : !s32i, #cir.int<3> : !s32i, #cir.int<5> : !s32i, #cir.int<7> : !s32i]> : !cir.vector<4 x !s32i>
+ %vec_2 = cir.const #cir.const_vector<[#cir.int<2> : !s32i, #cir.int<4> : !s32i, #cir.int<6> : !s32i, #cir.int<8> : !s32i]> : !cir.vector<4 x !s32i>
+ %new_vec = cir.vec.cmp(ne, %vec_1, %vec_2) : !cir.vector<4 x !s32i>, !cir.vector<4 x !s32i>
+ cir.return %new_vec : !cir.vector<4 x !s32i>
+ }
+
+ // CHECK: cir.func @fold_cmp_vector_op_test() -> !cir.vector<4 x !s32i> {
+ // CHECK-NEXT: %[[RES:.*]] = cir.const #cir.const_vector<[#cir.int<1> : !s32i, #cir.int<1> : !s32i,
+ // CHECK-SAME: #cir.int<1> : !s32i, #cir.int<1> : !s32i]> : !cir.vector<4 x !s32i>
+ // CHECK-NEXT: cir.return %[[RES]] : !cir.vector<4 x !s32i>
+}
+
+// -----
+
+!s32i = !cir.int<s, 32>
+
+module {
+ cir.func @fold_cmp_vector_op_test() -> !cir.vector<4 x !s32i> {
+ %vec_1 = cir.const #cir.const_vector<[#cir.int<1> : !s32i, #cir.int<3> : !s32i, #cir.int<5> : !s32i, #cir.int<7> : !s32i]> : !cir.vector<4 x !s32i>
+ %vec_2 = cir.const #cir.const_vector<[#cir.int<2> : !s32i, #cir.int<4> : !s32i, #cir.int<6> : !s32i, #cir.int<8> : !s32i]> : !cir.vector<4 x !s32i>
+ %new_vec = cir.vec.cmp(lt, %vec_1, %vec_2) : !cir.vector<4 x !s32i>, !cir.vector<4 x !s32i>
+ cir.return %new_vec : !cir.vector<4 x !s32i>
+ }
+
+ // CHECK: cir.func @fold_cmp_vector_op_test() -> !cir.vector<4 x !s32i> {
+ // CHECK-NEXT: %[[RES:.*]] = cir.const #cir.const_vector<[#cir.int<1> : !s32i, #cir.int<1> : !s32i,
+ // CHECK-SAME: #cir.int<1> : !s32i, #cir.int<1> : !s32i]> : !cir.vector<4 x !s32i>
+ // CHECK-NEXT: cir.return %[[RES]] : !cir.vector<4 x !s32i>
+}
+
+// -----
+
+!s32i = !cir.int<s, 32>
+
+module {
+ cir.func @fold_cmp_vector_op_test() -> !cir.vector<4 x !s32i> {
+ %vec_1 = cir.const #cir.const_vector<[#cir.int<1> : !s32i, #cir.int<3> : !s32i, #cir.int<5> : !s32i, #cir.int<7> : !s32i]> : !cir.vector<4 x !s32i>
+ %vec_2 = cir.const #cir.const_vector<[#cir.int<2> : !s32i, #cir.int<4> : !s32i, #cir.int<6> : !s32i, #cir.int<8> : !s32i]> : !cir.vector<4 x !s32i>
+ %new_vec = cir.vec.cmp(le, %vec_1, %vec_2) : !cir.vector<4 x !s32i>, !cir.vector<4 x !s32i>
+ cir.return %new_vec : !cir.vector<4 x !s32i>
+ }
+
+ // CHECK: cir.func @fold_cmp_vector_op_test() -> !cir.vector<4 x !s32i> {
+ // CHECK-NEXT: %[[RES:.*]] = cir.const #cir.const_vector<[#cir.int<1> : !s32i, #cir.int<1> : !s32i,
+ // CHECK-SAME: #cir.int<1> : !s32i, #cir.int<1> : !s32i]> : !cir.vector<4 x !s32i>
+ // CHECK-NEXT: cir.return %[[RES]] : !cir.vector<4 x !s32i>
+}
+
+// -----
+
+!s32i = !cir.int<s, 32>
+
+module {
+ cir.func @fold_cmp_vector_op_test() -> !cir.vector<4 x !s32i> {
+ %vec_1 = cir.const #cir.const_vector<[#cir.int<1> : !s32i, #cir.int<3> : !s32i, #cir.int<5> : !s32i, #cir.int<7> : !s32i]> : !cir.vector<4 x !s32i>
+ %vec_2 = cir.const #cir.const_vector<[#cir.int<2> : !s32i, #cir.int<4> : !s32i, #cir.int<6> : !s32i, #cir.int<8> : !s32i]> : !cir.vector<4 x !s32i>
+ %new_vec = cir.vec.cmp(gt, %vec_1, %vec_2) : !cir.vector<4 x !s32i>, !cir.vector<4 x !s32i>
+ cir.return %new_vec : !cir.vector<4 x !s32i>
+ }
+
+ // CHECK: cir.func @fold_cmp_vector_op_test() -> !cir.vector<4 x !s32i> {
+ // CHECK-NEXT: %[[RES:.*]] = cir.const #cir.const_vector<[#cir.int<0> : !s32i, #cir.int<0> : !s32i,
+ // CHECK-SAME: #cir.int<0> : !s32i, #cir.int<0> : !s32i]> : !cir.vector<4 x !s32i>
+ // CHECK-NEXT: cir.return %[[RES]] : !cir.vector<4 x !s32i>
+}
+
+// -----
+
+!s32i = !cir.int<s, 32>
+
+module {
+ cir.func @fold_cmp_vector_op_test() -> !cir.vector<4 x !s32i> {
+ %vec_1 = cir.const #cir.const_vector<[#cir.int<1> : !s32i, #cir.int<3> : !s32i, #cir.int<5> : !s32i, #cir.int<7> : !s32i]> : !cir.vector<4 x !s32i>
+ %vec_2 = cir.const #cir.const_vector<[#cir.int<2> : !s32i, #cir.int<4> : !s32i, #cir.int<6> : !s32i, #cir.int<8> : !s32i]> : !cir.vector<4 x !s32i>
+ %new_vec = cir.vec.cmp(gt, %vec_1, %vec_2) : !cir.vector<4 x !s32i>, !cir.vector<4 x !s32i>
+ cir.return %new_vec : !cir.vector<4 x !s32i>
+ }
+
+ // CHECK: cir.func @fold_cmp_vector_op_test() -> !cir.vector<4 x !s32i> {
+ // CHECK-NEXT: %[[RES:.*]] = cir.const #cir.const_vector<[#cir.int<0> : !s32i, #cir.int<0> : !s32i,
+ // CHECK-SAME: #cir.int<0> : !s32i, #cir.int<0> : !s32i]> : !cir.vector<4 x !s32i>
+ // CHECK-NEXT: cir.return %[[RES]] : !cir.vector<4 x !s32i>
+}
+
+// -----
+
+!s32i = !cir.int<s, 32>
+
+module {
+ cir.func @fold_cmp_vector_op_test() -> !cir.vector<4 x !s32i> {
+ %vec_1 = cir.const #cir.const_vector<[#cir.fp<1.000000e+00> : !cir.float, #cir.fp<2.000000e+00>
+ : !cir.float, #cir.fp<3.000000e+00> : !cir.float, #cir.fp<4.000000e+00> : !cir.float]> : !cir.vector<4 x !cir.float>
+ %vec_2 = cir.const #cir.const_vector<[#cir.fp<5.000000e+00> : !cir.float, #cir.fp<6.000000e+00>
+ : !cir.float, #cir.fp<7.000000e+00> : !cir.float, #cir.fp<8.000000e+00> : !cir.float]> : !cir.vector<4 x !cir.float>
+ %new_vec = cir.vec.cmp(eq, %vec_1, %vec_2) : !cir.vector<4 x !cir.float>, !cir.vector<4 x !s32i>
+ cir.return %new_vec : !cir.vector<4 x !s32i>
+ }
+
+ // CHECK: cir.func @fold_cmp_vector_op_test() -> !cir.vector<4 x !s32i> {
+ // CHECK-NEXT: %[[RES:.*]] = cir.const #cir.const_vector<[#cir.int<0> : !s32i, #cir.int<0> : !s32i,
+ // CHECK-SAME: #cir.int<0> : !s32i, #cir.int<0> : !s32i]> : !cir.vector<4 x !s32i>
+ // CHECK-NEXT: cir.return %[[RES]] : !cir.vector<4 x !s32i>
+}
+
+// -----
+
+!s32i = !cir.int<s, 32>
+
+module {
+ cir.func @fold_cmp_vector_op_test() -> !cir.vector<4 x !s32i> {
+ %vec_1 = cir.const #cir.const_vector<[#cir.fp<1.000000e+00> : !cir.float, #cir.fp<2.000000e+00>
+ : !cir.float, #cir.fp<3.000000e+00> : !cir.float, #cir.fp<4.000000e+00> : !cir.float]> : !cir.vector<4 x !cir.float>
+ %vec_2 = cir.const #cir.const_vector<[#cir.fp<5.000000e+00> : !cir.float, #cir.fp<6.000000e+00>
+ : !cir.float, #cir.fp<7.000000e+00> : !cir.float, #cir.fp<8.000000e+00> : !cir.float]> : !cir.vector<4 x !cir.float>
+ %new_vec = cir.vec.cmp(ne, %vec_1, %vec_2) : !cir.vector<4 x !cir.float>, !cir.vector<4 x !s32i>
+ cir.return %new_vec : !cir.vector<4 x !s32i>
+ }
+
+ // CHECK: cir.func @fold_cmp_vector_op_test() -> !cir.vector<4 x !s32i> {
+ // CHECK-NEXT: %[[RES:.*]] = cir.const #cir.const_vector<[#cir.int<1> : !s32i, #cir.int<1> : !s32i,
+ // CHECK-SAME: #cir.int<1> : !s32i, #cir.int<1> : !s32i]> : !cir.vector<4 x !s32i>
+ // CHECK-NEXT: cir.return %[[RES]] : !cir.vector<4 x !s32i>
+}
+
+// -----
+
+!s32i = !cir.int<s, 32>
+
+module {
+ cir.func @fold_cmp_vector_op_test() -> !cir.vector<4 x !s32i> {
+ %vec_1 = cir.const #cir.const_vector<[#cir.fp<1.000000e+00> : !cir.float, #cir.fp<2.000000e+00>
+ : !cir.float, #cir.fp<3.000000e+00> : !cir.float, #cir.fp<4.000000e+00> : !cir.float]> : !cir.vector<4 x !cir.float>
+ %vec_2 = cir.const #cir.const_vector<[#cir.fp<5.000000e+00> : !cir.float, #cir.fp<6.000000e+00>
+ : !cir.float, #cir.fp<7.000000e+00> : !cir.float, #cir.fp<8.000000e+00> : !cir.float]> : !cir.vector<4 x !cir.float>
+ %new_vec = cir.vec.cmp(lt, %vec_1, %vec_2) : !cir.vector<4 x !cir.float>, !cir.vector<4 x !s32i>
+ cir.return %new_vec : !cir.vector<4 x !s32i>
+ }
+
+ // CHECK: cir.func @fold_cmp_vector_op_test() -> !cir.vector<4 x !s32i> {
+ // CHECK-NEXT: %[[RES:.*]] = cir.const #cir.const_vector<[#cir.int<1> : !s32i, #cir.int<1> : !s32i,
+ // CHECK-SAME: #cir.int<1> : !s32i, #cir.int<1> : !s32i]> : !cir.vector<4 x !s32i>
+ // CHECK-NEXT: cir.return %[[RES]] : !cir.vector<4 x !s32i>
+}
+
+// -----
+
+!s32i = !cir.int<s, 32>
+
+module {
+ cir.func @fold_cmp_vector_op_test() -> !cir.vector<4 x !s32i> {
+ %vec_1 = cir.const #cir.const_vector<[#cir.fp<1.000000e+00> : !cir.float, #cir.fp<2.000000e+00>
+ : !cir.float, #cir.fp<3.000000e+00> : !cir.float, #cir.fp<4.000000e+00> : !cir.float]> : !cir.vector<4 x !cir.float>
+ %vec_2 = cir.const #cir.const_vector<[#cir.fp<5.000000e+00> : !cir.float, #cir.fp<6.000000e+00>
+ : !cir.float, #cir.fp<7.000000e+00> : !cir.float, #cir.fp<8.000000e+00> : !cir.float]> : !cir.vector<4 x !cir.float>
+ %new_vec = cir.vec.cmp(le, %vec_1, %vec_2) : !cir.vector<4 x !cir.float>, !cir.vector<4 x !s32i>
+ cir.return %new_vec : !cir.vector<4 x !s32i>
+ }
+
+ // CHECK: cir.func @fold_cmp_vector_op_test() -> !cir.vector<4 x !s32i> {
+ // CHECK-NEXT: %[[RES:.*]] = cir.const #cir.const_vector<[#cir.int<1> : !s32i, #cir.int<1> : !s32i,
+ // CHECK-SAME: #cir.int<1> : !s32i, #cir.int<1> : !s32i]> : !cir.vector<4 x !s32i>
+ // CHECK-NEXT: cir.return %[[RES]] : !cir.vector<4 x !s32i>
+}
+
+// -----
+
+!s32i = !cir.int<s, 32>
+
+module {
+ cir.func @fold_cmp_vector_op_test() -> !cir.vector<4 x !s32i> {
+ %vec_1 = cir.const #cir.const_vector<[#cir.fp<1.000000e+00> : !cir.float, #cir.fp<2.000000e+00>
+ : !cir.float, #cir.fp<3.000000e+00> : !cir.float, #cir.fp<4.000000e+00> : !cir.float]> : !cir.vector<4 x !cir.float>
+ %vec_2 = cir.const #cir.const_vector<[#cir.fp<5.000000e+00> : !cir.float, #cir.fp<6.000000e+00>
+ : !cir.float, #cir.fp<7.000000e+00> : !cir.float, #cir.fp<8.000000e+00> : !cir.float]> : !cir.vector<4 x !cir.float>
+ %new_vec = cir.vec.cmp(gt, %vec_1, %vec_2) : !cir.vector<4 x !cir.float>, !cir.vector<4 x !s32i>
+ cir.return %new_vec : !cir.vector<4 x !s32i>
+ }
+
+ // CHECK: cir.func @fold_cmp_vector_op_test() -> !cir.vector<4 x !s32i> {
+ // CHECK-NEXT: %[[RES:.*]] = cir.const #cir.const_vector<[#cir.int<0> : !s32i, #cir.int<0> : !s32i,
+ // CHECK-SAME: #cir.int<0> : !s32i, #cir.int<0> : !s32i]> : !cir.vector<4 x !s32i>
+ // CHECK-NEXT: cir.return %[[RES]] : !cir.vector<4 x !s32i>
+}
+
+// -----
+
+!s32i = !cir.int<s, 32>
+
+module {
+ cir.func @fold_cmp_vector_op_test() -> !cir.vector<4 x !s32i> {
+ %vec_1 = cir.const #cir.const_vector<[#cir.fp<1.000000e+00> : !cir.float, #cir.fp<2.000000e+00>
+ : !cir.float, #cir.fp<3.000000e+00> : !cir.float, #cir.fp<4.000000e+00> : !cir.float]> : !cir.vector<4 x !cir.float>
+ %vec_2 = cir.const #cir.const_vector<[#cir.fp<5.000000e+00> : !cir.float, #cir.fp<6.000000e+00>
+ : !cir.float, #cir.fp<7.000000e+00> : !cir.float, #cir.fp<8.000000e+00> : !cir.float]> : !cir.vector<4 x !cir.float>
+ %new_vec = cir.vec.cmp(ge, %vec_1, %vec_2) : !cir.vector<4 x !cir.float>, !cir.vector<4 x !s32i>
+ cir.return %new_vec : !cir.vector<4 x !s32i>
+ }
+
+ // CHECK: cir.func @fold_cmp_vector_op_test() -> !cir.vector<4 x !s32i> {
+ // CHECK-NEXT: %[[RES:.*]] = cir.const #cir.const_vector<[#cir.int<0> : !s32i, #cir.int<0> : !s32i,
+ // CHECK-SAME: #cir.int<0> : !s32i, #cir.int<0> : !s32i]> : !cir.vector<4 x !s32i>
+ // CHECK-NEXT: cir.return %[[RES]] : !cir.vector<4 x !s32i>
+}
|
@@ -142,7 +142,7 @@ void CIRCanonicalizePass::runOnOperation() { | |||
// Many operations are here to perform a manual `fold` in | |||
// applyOpPatternsGreedily. | |||
if (isa<BrOp, BrCondOp, CastOp, ScopeOp, SwitchOp, SelectOp, UnaryOp, | |||
VecExtractOp, VecShuffleDynamicOp, VecTernaryOp>(op)) | |||
VecExtractOp, VecShuffleDynamicOp, VecTernaryOp, VecCmpOp>(op)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: Can you put these in lexical order?
auto rhsVecAttr = mlir::cast<cir::ConstVectorAttr>(rhs); | ||
|
||
auto inputElemTy = | ||
mlir::cast<cir::VectorType>(lhsVecAttr.getType()).getElementType(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Don't use auto here. The cast gives the misleading impression that it isn't needed, but this is actually the result of getElementType(), right?
if (!mlir::isa<cir::IntType>(inputElemTy) && | ||
!mlir::isa<cir::CIRFPTypeInterface>(inputElemTy)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if (!mlir::isa<cir::IntType>(inputElemTy) && | |
!mlir::isa<cir::CIRFPTypeInterface>(inputElemTy)) | |
if (!isAnyIntegerOrFloatingPointType(inputElemTy)) |
mlir::Attribute lhs = adaptor.getLhs(); | ||
mlir::Attribute rhs = adaptor.getRhs(); | ||
if (!mlir::isa_and_nonnull<cir::ConstVectorAttr>(lhs) || | ||
!mlir::isa_and_nonnull<cir::ConstVectorAttr>(rhs)) | ||
return {}; | ||
|
||
auto lhsVecAttr = mlir::cast<cir::ConstVectorAttr>(lhs); | ||
auto rhsVecAttr = mlir::cast<cir::ConstVectorAttr>(rhs); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
simplify with dyn_cast_if_present
mlir::ArrayAttr rhsVecElhs = rhsVecAttr.getElts(); | ||
uint64_t vecSize = lhsVecElhs.size(); | ||
|
||
auto resultVecTy = mlir::cast<cir::VectorType>(getType()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
isa cast needed here?
for (uint64_t i = 0; i < vecSize; i++) { | ||
mlir::Attribute lhsAttr = lhsVecElhs[i]; | ||
mlir::Attribute rhsAttr = rhsVecElhs[i]; | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
is it possible to refactor to check mlir::isa<cir::IntAttr>(lhsAttr)
only once based on that obtain directly lhsValue
and rhsValue
and use those in subsequent cases?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I eliminated the check to be only once. I was thinking of eliminating the casts and creating a lambda or helper function, but not sure if it's worth
b748be2
to
78af912
Compare
uint64_t vecSize = lhsVecElhs.size(); | ||
|
||
SmallVector<mlir::Attribute, 16> elements(vecSize); | ||
bool isIntAttr = vecSize ? mlir::isa<cir::IntAttr>(lhsVecElhs[0]) : false; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
bool isIntAttr = vecSize ? mlir::isa<cir::IntAttr>(lhsVecElhs[0]) : false; | |
bool isIntAttr = vecSize && mlir::isa<cir::IntAttr>(lhsVecElhs[0]); |
78af912
to
655fd1f
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
lgtm
41c53fb
to
6560137
Compare
This change adds a folder for the VecCmpOp
Issue #136487