diff --git a/pyproject.toml b/pyproject.toml index ddf69724..b18d6367 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,9 +49,7 @@ include = ["xdsl_smt", "tests"] ignore = [ "xdsl_smt/utils/z3_to_dialect.py", "xdsl_smt/utils/integer_to_z3.py", - "xdsl_smt/utils/lower_utils.py", "xdsl_smt/passes/calculate_smt.py", - "xdsl_smt/passes/transfer_lower.py", "xdsl_smt/cli/xdsl_translate.py", "xdsl_smt/cli/transfer_smt_verifier.py", ] diff --git a/tests/filecheck/lower-to-cpp/arith.mlir b/tests/filecheck/lower-to-cpp/arith.mlir new file mode 100644 index 00000000..3d4fa0eb --- /dev/null +++ b/tests/filecheck/lower-to-cpp/arith.mlir @@ -0,0 +1,131 @@ +// RUN: cpp-translate -i %s | filecheck %s + +"builtin.module"() ({ + "func.func"() ({ + ^0(%x : i32, %y : i32): + %r = "arith.addi"(%x, %y) : (i32, i32) -> i32 + "func.return"(%r) : (i32) -> () + }) {"sym_name" = "add_test", "function_type" = (i32, i32) -> i32} : () -> () + + "func.func"() ({ + ^0(%x : i32, %y : i32): + %r = "arith.subi"(%x, %y) : (i32, i32) -> i32 + "func.return"(%r) : (i32) -> () + }) {"sym_name" = "sub_test", "function_type" = (i32, i32) -> i32} : () -> () + + "func.func"() ({ + ^0(%x : i32, %y : i32): + %r = "arith.andi"(%x, %y) : (i32, i32) -> i32 + "func.return"(%r) : (i32) -> () + }) {"sym_name" = "and_test", "function_type" = (i32, i32) -> i32} : () -> () + + "func.func"() ({ + ^0(%x : i32, %y : i32): + %r = "arith.ori"(%x, %y) : (i32, i32) -> i32 + "func.return"(%r) : (i32) -> () + }) {"sym_name" = "or_test", "function_type" = (i32, i32) -> i32} : () -> () + + "func.func"() ({ + ^0(%x : i32, %y : i32): + %r = "arith.xori"(%x, %y) : (i32, i32) -> i32 + "func.return"(%r) : (i32) -> () + }) {"sym_name" = "xor_test", "function_type" = (i32, i32) -> i32} : () -> () + + "func.func"() ({ + ^0(%x : i32, %y : i32): + %r = "arith.cmpi"(%x, %y) {"predicate" = 0 : i64} : (i32, i32) -> i1 + "func.return"(%r) : (i1) -> () + }) {"sym_name" = "eq_test", "function_type" = (i32, i32) -> i1} : () -> () + + "func.func"() ({ + ^0(%x : i32, %y : i32): + %r = "arith.cmpi"(%x, %y) {"predicate" = 1 : i64} : (i32, i32) -> i1 + "func.return"(%r) : (i1) -> () + }) {"sym_name" = "neq_test", "function_type" = (i32, i32) -> i1} : () -> () + + "func.func"() ({ + ^0(%x : i32, %y : i32): + %r = "arith.cmpi"(%x, %y) {"predicate" = 2 : i64} : (i32, i32) -> i1 + "func.return"(%r) : (i1) -> () + }) {"sym_name" = "lt_test", "function_type" = (i32, i32) -> i1} : () -> () + + "func.func"() ({ + ^0(%x : i32, %y : i32): + %r = "arith.cmpi"(%x, %y) {"predicate" = 3 : i64} : (i32, i32) -> i1 + "func.return"(%r) : (i1) -> () + }) {"sym_name" = "leq_test", "function_type" = (i32, i32) -> i1} : () -> () + + "func.func"() ({ + ^0(%x : i32, %y : i32): + %r = "arith.cmpi"(%x, %y) {"predicate" = 4 : i64} : (i32, i32) -> i1 + "func.return"(%r) : (i1) -> () + }) {"sym_name" = "gt_test", "function_type" = (i32, i32) -> i1} : () -> () + + "func.func"() ({ + ^0(%x : i32, %y : i32): + %r = "arith.cmpi"(%x, %y) {"predicate" = 5 : i64} : (i32, i32) -> i1 + "func.return"(%r) : (i1) -> () + }) {"sym_name" = "geq_test", "function_type" = (i32, i32) -> i1} : () -> () + + "func.func"() ({ + %x = "arith.constant"() {value = 3 : i32} : () -> i32 + "func.return"(%x) : (i32) -> () + }) {"sym_name" = "const_test", "function_type" = () -> i32} : () -> () + + "func.func"() ({ + ^0(%x : i32): + "func.return"(%x) : (i32) -> () + }) {"sym_name" = "empty_func_test", "function_type" = (i32) -> i32} : () -> () +}) : () -> () + +// CHECK: int add_test(int x,int y){ +// CHECK-NEXT: int r = x+y; +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: int sub_test(int x,int y){ +// CHECK-NEXT: int r = x-y; +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: int and_test(int x,int y){ +// CHECK-NEXT: int r = x&y; +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: int or_test(int x,int y){ +// CHECK-NEXT: int r = x|y; +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: int xor_test(int x,int y){ +// CHECK-NEXT: int r = x^y; +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: int eq_test(int x,int y){ +// CHECK-NEXT: int r = (x==y); +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: int neq_test(int x,int y){ +// CHECK-NEXT: int r = (x!=y); +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: int lt_test(int x,int y){ +// CHECK-NEXT: int r = (xy); +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: int geq_test(int x,int y){ +// CHECK-NEXT: int r = (x>=y); +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: int const_test(){ +// CHECK-NEXT: int x = 3; +// CHECK-NEXT: return x; +// CHECK-NEXT: } +// CHECK-NEXT: int empty_func_test(int x){ +// CHECK-NEXT: return x; +// CHECK-NEXT: } diff --git a/tests/filecheck/lower-to-cpp/special-ops.mlir b/tests/filecheck/lower-to-cpp/special-ops.mlir new file mode 100644 index 00000000..a0680584 --- /dev/null +++ b/tests/filecheck/lower-to-cpp/special-ops.mlir @@ -0,0 +1,62 @@ +// RUN: cpp-translate -i %s | filecheck %s + +"builtin.module"() ({ + "func.func"() ({ + ^0(%c : i1, %x : !transfer.integer, %y : !transfer.integer): + %r = "transfer.select"(%c, %x, %y) : (i1, !transfer.integer, !transfer.integer) -> !transfer.integer + "func.return"(%r) : (!transfer.integer) -> () + }) {"sym_name" = "select_test", "function_type" = (i1, !transfer.integer, !transfer.integer) -> !transfer.integer} : () -> () + + "func.func"() ({ + ^0(%lhs : !transfer.abs_value<[!transfer.integer, !transfer.integer]>, %rhs : !transfer.abs_value<[!transfer.integer, !transfer.integer]>): + %lhs0 = "transfer.get"(%lhs) {index = 0} : (!transfer.abs_value<[!transfer.integer, !transfer.integer]>) -> !transfer.integer + %lhs1 = "transfer.get"(%lhs) {index = 1} : (!transfer.abs_value<[!transfer.integer, !transfer.integer]>) -> !transfer.integer + %rhs0 = "transfer.get"(%rhs) {index = 0} : (!transfer.abs_value<[!transfer.integer, !transfer.integer]>) -> !transfer.integer + %rhs1 = "transfer.get"(%rhs) {index = 1} : (!transfer.abs_value<[!transfer.integer, !transfer.integer]>) -> !transfer.integer + %res0 = "transfer.or"(%lhs0, %rhs0) : (!transfer.integer, !transfer.integer) -> !transfer.integer + %res1 = "transfer.and"(%lhs1, %rhs1) : (!transfer.integer, !transfer.integer) -> !transfer.integer + %r = "transfer.make"(%res0, %res1) : (!transfer.integer, !transfer.integer) -> !transfer.abs_value<[!transfer.integer, !transfer.integer]> + "func.return"(%r) : (!transfer.abs_value<[!transfer.integer, !transfer.integer]>) -> () + }) {"sym_name" = "kb_and_test", "function_type" = (!transfer.abs_value<[!transfer.integer, !transfer.integer]>, !transfer.abs_value<[!transfer.integer, !transfer.integer]>) -> !transfer.abs_value<[!transfer.integer, !transfer.integer]>} : () -> () + + "func.func"() ({ + ^0(%x : !transfer.abs_value<[!transfer.integer, !transfer.integer]>): + %r = "transfer.get"(%x) {index = 0} : (!transfer.abs_value<[!transfer.integer, !transfer.integer]>) -> !transfer.integer + "func.return"(%r) : (!transfer.integer) -> () + }) {"sym_name" = "get_test", "function_type" = (!transfer.abs_value<[!transfer.integer, !transfer.integer]>) -> !transfer.integer} : () -> () + + "func.func"() ({ + ^0(%x : !transfer.integer, %y : !transfer.integer): + %r = "transfer.make"(%x, %y) : (!transfer.integer, !transfer.integer) -> !transfer.abs_value<[!transfer.integer, !transfer.integer]> + "func.return"(%r) : (!transfer.abs_value<[!transfer.integer, !transfer.integer]>) -> () + }) {"sym_name" = "make_2_test", "function_type" = (!transfer.integer, !transfer.integer) -> !transfer.abs_value<[!transfer.integer, !transfer.integer]>} : () -> () + + // "func.func"() ({ + // ^0(%x : !transfer.integer, %y : !transfer.integer, %z : !transfer.integer): + // %r = "transfer.make"(%x, %y, %z) : (!transfer.integer, !transfer.integer, !transfer.integer) -> !transfer.abs_value<[!transfer.integer, !transfer.integer, !transfer.integer]> + // "func.return"(%r) : (!transfer.abs_value<[!transfer.integer, !transfer.integer, !transfer.integer]>) -> () + // }) {"sym_name" = "make_3_test", "function_type" = (!transfer.integer, !transfer.integer, !transfer.integer) -> !transfer.abs_value<[!transfer.integer, !transfer.integer, !transfer.integer]>} : () -> () +}) : () -> () + +// CHECK: const APInt select_test(int c,const APInt &x,const APInt &y){ +// CHECK-NEXT: const APInt r = c ? x : y ; +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: std::vector kb_and_test(std::vector &lhs,std::vector &rhs){ +// CHECK-NEXT: const APInt lhs0 = lhs[0]; +// CHECK-NEXT: const APInt lhs1 = lhs[1]; +// CHECK-NEXT: const APInt rhs0 = rhs[0]; +// CHECK-NEXT: const APInt rhs1 = rhs[1]; +// CHECK-NEXT: const APInt res0 = lhs0|rhs0; +// CHECK-NEXT: const APInt res1 = lhs1&rhs1; +// CHECK-NEXT: std::vector r = std::vector{res0,res1}; +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: const APInt get_test(std::vector &x){ +// CHECK-NEXT: const APInt r = x[0]; +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: std::vector make_2_test(const APInt &x,const APInt &y){ +// CHECK-NEXT: std::vector r = std::vector{x,y}; +// CHECK-NEXT: return r; +// CHECK-NEXT: } diff --git a/tests/filecheck/lower-to-cpp/transfer-bin-ops.mlir b/tests/filecheck/lower-to-cpp/transfer-bin-ops.mlir new file mode 100644 index 00000000..793d643e --- /dev/null +++ b/tests/filecheck/lower-to-cpp/transfer-bin-ops.mlir @@ -0,0 +1,291 @@ +// RUN: cpp-translate -i %s | filecheck %s + +"builtin.module"() ({ + "func.func"() ({ + ^0(%x : !transfer.integer, %y : !transfer.integer): + %r = "transfer.add"(%x, %y) : (!transfer.integer,!transfer.integer) -> !transfer.integer + "func.return"(%r) : (!transfer.integer) -> () + }) {"sym_name" = "add_test", "function_type" = (!transfer.integer, !transfer.integer) -> !transfer.integer} : () -> () + + "func.func"() ({ + ^0(%x : !transfer.integer, %y : !transfer.integer): + %r = "transfer.sub"(%x, %y) : (!transfer.integer,!transfer.integer) -> !transfer.integer + "func.return"(%r) : (!transfer.integer) -> () + }) {"sym_name" = "sub_test", "function_type" = (!transfer.integer, !transfer.integer) -> !transfer.integer} : () -> () + + "func.func"() ({ + ^0(%x : !transfer.integer, %y : !transfer.integer): + %r = "transfer.mul"(%x, %y) : (!transfer.integer,!transfer.integer) -> !transfer.integer + "func.return"(%r) : (!transfer.integer) -> () + }) {"sym_name" = "mul_test", "function_type" = (!transfer.integer, !transfer.integer) -> !transfer.integer} : () -> () + + "func.func"() ({ + ^0(%x : !transfer.integer, %y : !transfer.integer): + %r = "transfer.and"(%x, %y) : (!transfer.integer,!transfer.integer) -> !transfer.integer + "func.return"(%r) : (!transfer.integer) -> () + }) {"sym_name" = "and_test", "function_type" = (!transfer.integer, !transfer.integer) -> !transfer.integer} : () -> () + + "func.func"() ({ + ^0(%x : !transfer.integer, %y : !transfer.integer): + %r = "transfer.or"(%x, %y) : (!transfer.integer,!transfer.integer) -> !transfer.integer + "func.return"(%r) : (!transfer.integer) -> () + }) {"sym_name" = "or_test", "function_type" = (!transfer.integer, !transfer.integer) -> !transfer.integer} : () -> () + + "func.func"() ({ + ^0(%x : !transfer.integer, %y : !transfer.integer): + %r = "transfer.xor"(%x, %y) : (!transfer.integer,!transfer.integer) -> !transfer.integer + "func.return"(%r) : (!transfer.integer) -> () + }) {"sym_name" = "xor_test", "function_type" = (!transfer.integer, !transfer.integer) -> !transfer.integer} : () -> () + + "func.func"() ({ + ^0(%x : !transfer.integer, %y : !transfer.integer): + %r = "transfer.udiv"(%x, %y) : (!transfer.integer,!transfer.integer) -> !transfer.integer + "func.return"(%r) : (!transfer.integer) -> () + }) {"sym_name" = "udiv_test", "function_type" = (!transfer.integer, !transfer.integer) -> !transfer.integer} : () -> () + + "func.func"() ({ + ^0(%x : !transfer.integer, %y : !transfer.integer): + %r = "transfer.sdiv"(%x, %y) : (!transfer.integer,!transfer.integer) -> !transfer.integer + "func.return"(%r) : (!transfer.integer) -> () + }) {"sym_name" = "sdiv_test", "function_type" = (!transfer.integer, !transfer.integer) -> !transfer.integer} : () -> () + + "func.func"() ({ + ^0(%x : !transfer.integer, %y : !transfer.integer): + %r = "transfer.urem"(%x, %y) : (!transfer.integer,!transfer.integer) -> !transfer.integer + "func.return"(%r) : (!transfer.integer) -> () + }) {"sym_name" = "urem_test", "function_type" = (!transfer.integer, !transfer.integer) -> !transfer.integer} : () -> () + + "func.func"() ({ + ^0(%x : !transfer.integer, %y : !transfer.integer): + %r = "transfer.srem"(%x, %y) : (!transfer.integer,!transfer.integer) -> !transfer.integer + "func.return"(%r) : (!transfer.integer) -> () + }) {"sym_name" = "srem_test", "function_type" = (!transfer.integer, !transfer.integer) -> !transfer.integer} : () -> () + + "func.func"() ({ + ^0(%x : !transfer.integer, %y : !transfer.integer): + %r = "transfer.shl"(%x, %y) : (!transfer.integer,!transfer.integer) -> !transfer.integer + "func.return"(%r) : (!transfer.integer) -> () + }) {"sym_name" = "shl_test", "function_type" = (!transfer.integer, !transfer.integer) -> !transfer.integer} : () -> () + + "func.func"() ({ + ^0(%x : !transfer.integer, %y : !transfer.integer): + %r = "transfer.ashr"(%x, %y) : (!transfer.integer,!transfer.integer) -> !transfer.integer + "func.return"(%r) : (!transfer.integer) -> () + }) {"sym_name" = "ashr_test", "function_type" = (!transfer.integer, !transfer.integer) -> !transfer.integer} : () -> () + + "func.func"() ({ + ^0(%x : !transfer.integer, %y : !transfer.integer): + %r = "transfer.lshr"(%x, %y) : (!transfer.integer,!transfer.integer) -> !transfer.integer + "func.return"(%r) : (!transfer.integer) -> () + }) {"sym_name" = "lshr_test", "function_type" = (!transfer.integer, !transfer.integer) -> !transfer.integer} : () -> () + + "func.func"() ({ + ^0(%x : !transfer.integer, %y : !transfer.integer): + %r = "transfer.umin"(%x, %y) : (!transfer.integer,!transfer.integer) -> !transfer.integer + "func.return"(%r) : (!transfer.integer) -> () + }) {"sym_name" = "umin_test", "function_type" = (!transfer.integer, !transfer.integer) -> !transfer.integer} : () -> () + + "func.func"() ({ + ^0(%x : !transfer.integer, %y : !transfer.integer): + %r = "transfer.smin"(%x, %y) : (!transfer.integer,!transfer.integer) -> !transfer.integer + "func.return"(%r) : (!transfer.integer) -> () + }) {"sym_name" = "smin_test", "function_type" = (!transfer.integer, !transfer.integer) -> !transfer.integer} : () -> () + + "func.func"() ({ + ^0(%x : !transfer.integer, %y : !transfer.integer): + %r = "transfer.umax"(%x, %y) : (!transfer.integer,!transfer.integer) -> !transfer.integer + "func.return"(%r) : (!transfer.integer) -> () + }) {"sym_name" = "umax_test", "function_type" = (!transfer.integer, !transfer.integer) -> !transfer.integer} : () -> () + + "func.func"() ({ + ^0(%x : !transfer.integer, %y : !transfer.integer): + %r = "transfer.smax"(%x, %y) : (!transfer.integer,!transfer.integer) -> !transfer.integer + "func.return"(%r) : (!transfer.integer) -> () + }) {"sym_name" = "smax_test", "function_type" = (!transfer.integer, !transfer.integer) -> !transfer.integer} : () -> () + + "func.func"() ({ + ^0(%x : !transfer.integer, %y : !transfer.integer): + %r = "transfer.get_high_bits"(%x, %y) : (!transfer.integer, !transfer.integer) -> !transfer.integer + "func.return"(%r) : (!transfer.integer) -> () + }) {"sym_name" = "get_high_bits_test", "function_type" = (!transfer.integer, !transfer.integer) -> !transfer.integer} : () -> () + + "func.func"() ({ + ^0(%x : !transfer.integer, %y : !transfer.integer): + %r = "transfer.get_low_bits"(%x, %y) : (!transfer.integer, !transfer.integer) -> !transfer.integer + "func.return"(%r) : (!transfer.integer) -> () + }) {"sym_name" = "get_low_bits_test", "function_type" = (!transfer.integer, !transfer.integer) -> !transfer.integer} : () -> () + + "func.func"() ({ + ^0(%x : !transfer.integer, %y : !transfer.integer): + %r = "transfer.set_high_bits"(%x, %y) : (!transfer.integer, !transfer.integer) -> !transfer.integer + "func.return"(%r) : (!transfer.integer) -> () + }) {"sym_name" = "set_high_bits_test", "function_type" = (!transfer.integer, !transfer.integer) -> !transfer.integer} : () -> () + + "func.func"() ({ + ^0(%x : !transfer.integer, %y : !transfer.integer): + %r = "transfer.set_low_bits"(%x, %y) : (!transfer.integer, !transfer.integer) -> !transfer.integer + "func.return"(%r) : (!transfer.integer) -> () + }) {"sym_name" = "set_low_bits_test", "function_type" = (!transfer.integer, !transfer.integer) -> !transfer.integer} : () -> () + + "func.func"() ({ + ^0(%x : !transfer.integer, %y : !transfer.integer): + %r = "transfer.clear_high_bits"(%x, %y) : (!transfer.integer, !transfer.integer) -> !transfer.integer + "func.return"(%r) : (!transfer.integer) -> () + }) {"sym_name" = "clear_high_bits_test", "function_type" = (!transfer.integer, !transfer.integer) -> !transfer.integer} : () -> () + + "func.func"() ({ + ^0(%x : !transfer.integer, %y : !transfer.integer): + %r = "transfer.clear_low_bits"(%x, %y) : (!transfer.integer, !transfer.integer) -> !transfer.integer + "func.return"(%r) : (!transfer.integer) -> () + }) {"sym_name" = "clear_low_bits_test", "function_type" = (!transfer.integer, !transfer.integer) -> !transfer.integer} : () -> () +}) : () -> () + +// CHECK: const APInt add_test(const APInt &x,const APInt &y){ +// CHECK-NEXT: const APInt r = x+y; +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: const APInt sub_test(const APInt &x,const APInt &y){ +// CHECK-NEXT: const APInt r = x-y; +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: const APInt mul_test(const APInt &x,const APInt &y){ +// CHECK-NEXT: const APInt r = x*y; +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: const APInt and_test(const APInt &x,const APInt &y){ +// CHECK-NEXT: const APInt r = x&y; +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: const APInt or_test(const APInt &x,const APInt &y){ +// CHECK-NEXT: const APInt r = x|y; +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: const APInt xor_test(const APInt &x,const APInt &y){ +// CHECK-NEXT: const APInt r = x^y; +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: const APInt udiv_test(const APInt &x,const APInt &y){ +// CHECK-NEXT: const APInt r; +// CHECK-NEXT: if (y == 0) { +// CHECK-NEXT: r = APInt(x.getBitWidth(), -1); +// CHECK-NEXT: } else { +// CHECK-NEXT: r = x.udiv(y); +// CHECK-NEXT: } +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: const APInt sdiv_test(const APInt &x,const APInt &y){ +// CHECK-NEXT: const APInt r; +// CHECK-NEXT: if (x.isMinSignedValue() && y == -1) { +// CHECK-NEXT: r = APInt::getSignedMinValue(x.getBitWidth()); +// CHECK-NEXT: } else if (y == 0 && x.isNonNegative()) { +// CHECK-NEXT: r = APInt(x.getBitWidth(), -1); +// CHECK-NEXT: } else if (y == 0 && x.isNegative()) { +// CHECK-NEXT: r = APInt(x.getBitWidth(), 1); +// CHECK-NEXT: } else { +// CHECK-NEXT: r = x.sdiv(y); +// CHECK-NEXT: } +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: const APInt urem_test(const APInt &x,const APInt &y){ +// CHECK-NEXT: const APInt r; +// CHECK-NEXT: if (y == 0) { +// CHECK-NEXT: r = x; +// CHECK-NEXT: } else { +// CHECK-NEXT: r = x.urem(y); +// CHECK-NEXT: } +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: const APInt srem_test(const APInt &x,const APInt &y){ +// CHECK-NEXT: const APInt r; +// CHECK-NEXT: if (y == 0) { +// CHECK-NEXT: r = x; +// CHECK-NEXT: } else { +// CHECK-NEXT: r = x.srem(y); +// CHECK-NEXT: } +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: const APInt shl_test(const APInt &x,const APInt &y){ +// CHECK-NEXT: const APInt r; +// CHECK-NEXT: if (y.uge(y.getBitWidth())) { +// CHECK-NEXT: r = APInt(x.getBitWidth(), 0); +// CHECK-NEXT: } else { +// CHECK-NEXT: r = x.shl(y.getZExtValue()); +// CHECK-NEXT: } +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: const APInt ashr_test(const APInt &x,const APInt &y){ +// CHECK-NEXT: const APInt r; +// CHECK-NEXT: if (y.uge(y.getBitWidth()) && x.isSignBitSet()) { +// CHECK-NEXT: r = APInt(x.getBitWidth(), -1); +// CHECK-NEXT: } else if (y.uge(y.getBitWidth()) && x.isSignBitClear()) { +// CHECK-NEXT: r = APInt(x.getBitWidth(), 0); +// CHECK-NEXT: } else { +// CHECK-NEXT: r = x.ashr(y.getZExtValue()); +// CHECK-NEXT: } +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: const APInt lshr_test(const APInt &x,const APInt &y){ +// CHECK-NEXT: const APInt r; +// CHECK-NEXT: if (y.uge(y.getBitWidth())) { +// CHECK-NEXT: r = APInt(x.getBitWidth(), 0); +// CHECK-NEXT: } else { +// CHECK-NEXT: r = x.lshr(y.getZExtValue()); +// CHECK-NEXT: } +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: const APInt umin_test(const APInt &x,const APInt &y){ +// CHECK-NEXT: const APInt r = A::APIntOps::umin(x,y); +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: const APInt smin_test(const APInt &x,const APInt &y){ +// CHECK-NEXT: const APInt r = A::APIntOps::smin(x,y); +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: const APInt umax_test(const APInt &x,const APInt &y){ +// CHECK-NEXT: const APInt r = A::APIntOps::umax(x,y); +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: const APInt smax_test(const APInt &x,const APInt &y){ +// CHECK-NEXT: const APInt r = A::APIntOps::smax(x,y); +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: const APInt get_high_bits_test(const APInt &x,const APInt &y){ +// CHECK-NEXT: const APInt r = x.getHiBits(y.getZExtValue()); +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: const APInt get_low_bits_test(const APInt &x,const APInt &y){ +// CHECK-NEXT: const APInt r = x.getLoBits(y.getZExtValue()); +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: const APInt set_high_bits_test(const APInt &x,const APInt &y){ +// CHECK-NEXT: const APInt r = x; +// CHECK-NEXT: if (y.ule(y.getBitWidth())) +// CHECK-NEXT: r.setHighBits(y.getZExtValue()); +// CHECK-NEXT: else +// CHECK-NEXT: r.setHighBits(y.getBitWidth()); +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: const APInt set_low_bits_test(const APInt &x,const APInt &y){ +// CHECK-NEXT: const APInt r = x; +// CHECK-NEXT: if (y.ule(y.getBitWidth())) +// CHECK-NEXT: r.setLowBits(y.getZExtValue()); +// CHECK-NEXT: else +// CHECK-NEXT: r.setLowBits(y.getBitWidth()); +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: const APInt clear_high_bits_test(const APInt &x,const APInt &y){ +// CHECK-NEXT: const APInt r = x; +// CHECK-NEXT: if (y.ule(y.getBitWidth())) +// CHECK-NEXT: r.clearHighBits(y.getZExtValue()); +// CHECK-NEXT: else +// CHECK-NEXT: r.clearHighBits(y.getBitWidth()); +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: const APInt clear_low_bits_test(const APInt &x,const APInt &y){ +// CHECK-NEXT: const APInt r = x; +// CHECK-NEXT: if (y.ule(y.getBitWidth())) +// CHECK-NEXT: r.clearLowBits(y.getZExtValue()); +// CHECK-NEXT: else +// CHECK-NEXT: r.clearLowBits(y.getBitWidth()); +// CHECK-NEXT: return r; +// CHECK-NEXT: } diff --git a/tests/filecheck/lower-to-cpp/transfer-pred-ops.mlir b/tests/filecheck/lower-to-cpp/transfer-pred-ops.mlir new file mode 100644 index 00000000..293e2232 --- /dev/null +++ b/tests/filecheck/lower-to-cpp/transfer-pred-ops.mlir @@ -0,0 +1,180 @@ +// RUN: cpp-translate -i %s | filecheck %s + +"builtin.module"() ({ + "func.func"() ({ + ^0(%x : !transfer.integer, %y : !transfer.integer): + %r = "transfer.umul_overflow"(%x, %y) : (!transfer.integer, !transfer.integer) -> i1 + "func.return"(%r) : (i1) -> () + }) {"sym_name" = "umul_ov_test", "function_type" = (!transfer.integer, !transfer.integer) -> i1} : () -> () + + "func.func"() ({ + ^0(%x : !transfer.integer, %y : !transfer.integer): + %r = "transfer.smul_overflow"(%x, %y) : (!transfer.integer, !transfer.integer) -> i1 + "func.return"(%r) : (i1) -> () + }) {"sym_name" = "smul_ov_test", "function_type" = (!transfer.integer, !transfer.integer) -> i1} : () -> () + + "func.func"() ({ + ^0(%x : !transfer.integer, %y : !transfer.integer): + %r = "transfer.uadd_overflow"(%x, %y) : (!transfer.integer, !transfer.integer) -> i1 + "func.return"(%r) : (i1) -> () + }) {"sym_name" = "uadd_ov_test", "function_type" = (!transfer.integer, !transfer.integer) -> i1} : () -> () + + "func.func"() ({ + ^0(%x : !transfer.integer, %y : !transfer.integer): + %r = "transfer.sadd_overflow"(%x, %y) : (!transfer.integer, !transfer.integer) -> i1 + "func.return"(%r) : (i1) -> () + }) {"sym_name" = "sadd_ov_test", "function_type" = (!transfer.integer, !transfer.integer) -> i1} : () -> () + + "func.func"() ({ + ^0(%x : !transfer.integer, %y : !transfer.integer): + %r = "transfer.ushl_overflow"(%x, %y) : (!transfer.integer, !transfer.integer) -> i1 + "func.return"(%r) : (i1) -> () + }) {"sym_name" = "ushl_ov_test", "function_type" = (!transfer.integer, !transfer.integer) -> i1} : () -> () + + "func.func"() ({ + ^0(%x : !transfer.integer, %y : !transfer.integer): + %r = "transfer.sshl_overflow"(%x, %y) : (!transfer.integer, !transfer.integer) -> i1 + "func.return"(%r) : (i1) -> () + }) {"sym_name" = "sshl_ov_test", "function_type" = (!transfer.integer, !transfer.integer) -> i1} : () -> () + + "func.func"() ({ + ^0(%x : !transfer.integer, %y : !transfer.integer): + %r = "transfer.intersects"(%x, %y) : (!transfer.integer, !transfer.integer) -> i1 + "func.return"(%r) : (i1) -> () + }) {"sym_name" = "intersects_test", "function_type" = (!transfer.integer, !transfer.integer) -> i1} : () -> () + + "func.func"() ({ + ^0(%x : !transfer.integer, %y : !transfer.integer): + %r = "transfer.cmp"(%x, %y) {predicate = 0} : (!transfer.integer, !transfer.integer) -> i1 + "func.return"(%r) : (i1) -> () + }) {"sym_name" = "eq_test", "function_type" = (!transfer.integer, !transfer.integer) -> i1} : () -> () + + "func.func"() ({ + ^0(%x : !transfer.integer, %y : !transfer.integer): + %r = "transfer.cmp"(%x, %y) {predicate = 1} : (!transfer.integer, !transfer.integer) -> i1 + "func.return"(%r) : (i1) -> () + }) {"sym_name" = "neq_test", "function_type" = (!transfer.integer, !transfer.integer) -> i1} : () -> () + + "func.func"() ({ + ^0(%x : !transfer.integer, %y : !transfer.integer): + %r = "transfer.cmp"(%x, %y) {predicate = 2} : (!transfer.integer, !transfer.integer) -> i1 + "func.return"(%r) : (i1) -> () + }) {"sym_name" = "slt_test", "function_type" = (!transfer.integer, !transfer.integer) -> i1} : () -> () + + "func.func"() ({ + ^0(%x : !transfer.integer, %y : !transfer.integer): + %r = "transfer.cmp"(%x, %y) {predicate = 3} : (!transfer.integer, !transfer.integer) -> i1 + "func.return"(%r) : (i1) -> () + }) {"sym_name" = "sle_test", "function_type" = (!transfer.integer, !transfer.integer) -> i1} : () -> () + + "func.func"() ({ + ^0(%x : !transfer.integer, %y : !transfer.integer): + %r = "transfer.cmp"(%x, %y) {predicate = 4} : (!transfer.integer, !transfer.integer) -> i1 + "func.return"(%r) : (i1) -> () + }) {"sym_name" = "sgt_test", "function_type" = (!transfer.integer, !transfer.integer) -> i1} : () -> () + + "func.func"() ({ + ^0(%x : !transfer.integer, %y : !transfer.integer): + %r = "transfer.cmp"(%x, %y) {predicate = 5} : (!transfer.integer, !transfer.integer) -> i1 + "func.return"(%r) : (i1) -> () + }) {"sym_name" = "sge_test", "function_type" = (!transfer.integer, !transfer.integer) -> i1} : () -> () + + "func.func"() ({ + ^0(%x : !transfer.integer, %y : !transfer.integer): + %r = "transfer.cmp"(%x, %y) {predicate = 6} : (!transfer.integer, !transfer.integer) -> i1 + "func.return"(%r) : (i1) -> () + }) {"sym_name" = "ult_test", "function_type" = (!transfer.integer, !transfer.integer) -> i1} : () -> () + + "func.func"() ({ + ^0(%x : !transfer.integer, %y : !transfer.integer): + %r = "transfer.cmp"(%x, %y) {predicate = 7} : (!transfer.integer, !transfer.integer) -> i1 + "func.return"(%r) : (i1) -> () + }) {"sym_name" = "ule_test", "function_type" = (!transfer.integer, !transfer.integer) -> i1} : () -> () + + "func.func"() ({ + ^0(%x : !transfer.integer, %y : !transfer.integer): + %r = "transfer.cmp"(%x, %y) {predicate = 8} : (!transfer.integer, !transfer.integer) -> i1 + "func.return"(%r) : (i1) -> () + }) {"sym_name" = "ugt_test", "function_type" = (!transfer.integer, !transfer.integer) -> i1} : () -> () + + "func.func"() ({ + ^0(%x : !transfer.integer, %y : !transfer.integer): + %r = "transfer.cmp"(%x, %y) {predicate = 9} : (!transfer.integer, !transfer.integer) -> i1 + "func.return"(%r) : (i1) -> () + }) {"sym_name" = "uge_test", "function_type" = (!transfer.integer, !transfer.integer) -> i1} : () -> () +}) : () -> () + +// CHECK: int umul_ov_test(const APInt &x,const APInt &y){ +// CHECK-NEXT: bool r; +// CHECK-NEXT: x.umul_ov(y,r); +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: int smul_ov_test(const APInt &x,const APInt &y){ +// CHECK-NEXT: bool r; +// CHECK-NEXT: x.smul_ov(y,r); +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: int uadd_ov_test(const APInt &x,const APInt &y){ +// CHECK-NEXT: bool r; +// CHECK-NEXT: x.uadd_ov(y,r); +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: int sadd_ov_test(const APInt &x,const APInt &y){ +// CHECK-NEXT: bool r; +// CHECK-NEXT: x.sadd_ov(y,r); +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: int ushl_ov_test(const APInt &x,const APInt &y){ +// CHECK-NEXT: bool r; +// CHECK-NEXT: x.ushl_ov(y,r); +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: int sshl_ov_test(const APInt &x,const APInt &y){ +// CHECK-NEXT: bool r; +// CHECK-NEXT: x.sshl_ov(y,r); +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: int intersects_test(const APInt &x,const APInt &y){ +// CHECK-NEXT: int r = x.intersects(y); +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: int eq_test(const APInt &x,const APInt &y){ +// CHECK-NEXT: int r = x.eq(y); +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: int neq_test(const APInt &x,const APInt &y){ +// CHECK-NEXT: int r = x.ne(y); +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: int slt_test(const APInt &x,const APInt &y){ +// CHECK-NEXT: int r = x.slt(y); +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: int sle_test(const APInt &x,const APInt &y){ +// CHECK-NEXT: int r = x.sle(y); +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: int sgt_test(const APInt &x,const APInt &y){ +// CHECK-NEXT: int r = x.sgt(y); +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: int sge_test(const APInt &x,const APInt &y){ +// CHECK-NEXT: int r = x.sge(y); +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: int ult_test(const APInt &x,const APInt &y){ +// CHECK-NEXT: int r = x.ult(y); +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: int ule_test(const APInt &x,const APInt &y){ +// CHECK-NEXT: int r = x.ule(y); +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: int ugt_test(const APInt &x,const APInt &y){ +// CHECK-NEXT: int r = x.ugt(y); +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: int uge_test(const APInt &x,const APInt &y){ +// CHECK-NEXT: int r = x.uge(y); +// CHECK-NEXT: return r; +// CHECK-NEXT: } diff --git a/tests/filecheck/lower-to-cpp/transfer-unary-ops.mlir b/tests/filecheck/lower-to-cpp/transfer-unary-ops.mlir new file mode 100644 index 00000000..7c5a0e8e --- /dev/null +++ b/tests/filecheck/lower-to-cpp/transfer-unary-ops.mlir @@ -0,0 +1,131 @@ +// RUN: cpp-translate -i %s | filecheck %s + +"builtin.module"() ({ + "func.func"() ({ + ^0(%x : !transfer.integer): + %r = "transfer.get_bit_width"(%x) : (!transfer.integer) -> !transfer.integer + "func.return"(%r) : (!transfer.integer) -> () + }) {"sym_name" = "get_bw_test", "function_type" = (!transfer.integer) -> !transfer.integer} : () -> () + + "func.func"() ({ + ^0(%x : !transfer.integer): + %r = "transfer.countl_zero"(%x) : (!transfer.integer) -> !transfer.integer + "func.return"(%r) : (!transfer.integer) -> () + }) {"sym_name" = "countl_zero_test", "function_type" = (!transfer.integer) -> !transfer.integer} : () -> () + + "func.func"() ({ + ^0(%x : !transfer.integer): + %r = "transfer.countr_zero"(%x) : (!transfer.integer) -> !transfer.integer + "func.return"(%r) : (!transfer.integer) -> () + }) {"sym_name" = "countr_zero_test", "function_type" = (!transfer.integer) -> !transfer.integer} : () -> () + + "func.func"() ({ + ^0(%x : !transfer.integer): + %r = "transfer.countl_one"(%x) : (!transfer.integer) -> !transfer.integer + "func.return"(%r) : (!transfer.integer) -> () + }) {"sym_name" = "countl_one_test", "function_type" = (!transfer.integer) -> !transfer.integer} : () -> () + + "func.func"() ({ + ^0(%x : !transfer.integer): + %r = "transfer.countr_one"(%x) : (!transfer.integer) -> !transfer.integer + "func.return"(%r) : (!transfer.integer) -> () + }) {"sym_name" = "countr_one_test", "function_type" = (!transfer.integer) -> !transfer.integer} : () -> () + + "func.func"() ({ + ^0(%x : !transfer.integer): + %r = "transfer.neg"(%x) : (!transfer.integer) -> !transfer.integer + "func.return"(%r) : (!transfer.integer) -> () + }) {"sym_name" = "neg_test", "function_type" = (!transfer.integer) -> !transfer.integer} : () -> () + + "func.func"() ({ + ^0(%x : !transfer.integer): + %r = "transfer.clear_sign_bit"(%x) : (!transfer.integer) -> !transfer.integer + "func.return"(%r) : (!transfer.integer) -> () + }) {"sym_name" = "clear_sign_bit_test", "function_type" = (!transfer.integer) -> !transfer.integer} : () -> () + + "func.func"() ({ + ^0(%x : !transfer.integer): + %r = "transfer.set_sign_bit"(%x) : (!transfer.integer) -> !transfer.integer + "func.return"(%r) : (!transfer.integer) -> () + }) {"sym_name" = "set_sign_bit_test", "function_type" = (!transfer.integer) -> !transfer.integer} : () -> () + + "func.func"() ({ + ^0(%x : !transfer.integer): + %r = "transfer.get_all_ones"(%x) : (!transfer.integer) -> !transfer.integer + "func.return"(%r) : (!transfer.integer) -> () + }) {"sym_name" = "get_all_ones_test", "function_type" = (!transfer.integer) -> !transfer.integer} : () -> () + + "func.func"() ({ + ^0(%x : !transfer.integer): + %r = "transfer.get_signed_max_value"(%x) : (!transfer.integer) -> !transfer.integer + "func.return"(%r) : (!transfer.integer) -> () + }) {"sym_name" = "get_signed_max_value_test", "function_type" = (!transfer.integer) -> !transfer.integer} : () -> () + + "func.func"() ({ + ^0(%x : !transfer.integer): + %r = "transfer.get_signed_min_value"(%x) : (!transfer.integer) -> !transfer.integer + "func.return"(%r) : (!transfer.integer) -> () + }) {"sym_name" = "get_signed_min_value_test", "function_type" = (!transfer.integer) -> !transfer.integer} : () -> () + + "func.func"() ({ + ^0(%x : !transfer.integer): + %r = "transfer.reverse_bits"(%x) : (!transfer.integer) -> !transfer.integer + "func.return"(%r) : (!transfer.integer) -> () + }) {"sym_name" = "reverse_bits_test", "function_type" = (!transfer.integer) -> !transfer.integer} : () -> () +}) : () -> () + +// CHECK: const APInt get_bw_test(const APInt &x){ +// CHECK-NEXT: unsigned r_autocast = x.getBitWidth(); +// CHECK-NEXT: APInt r(x.getBitWidth(),r_autocast); +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: const APInt countl_zero_test(const APInt &x){ +// CHECK-NEXT: unsigned r_autocast = x.countl_zero(); +// CHECK-NEXT: APInt r(x.getBitWidth(),r_autocast); +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: const APInt countr_zero_test(const APInt &x){ +// CHECK-NEXT: unsigned r_autocast = x.countr_zero(); +// CHECK-NEXT: APInt r(x.getBitWidth(),r_autocast); +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: const APInt countl_one_test(const APInt &x){ +// CHECK-NEXT: unsigned r_autocast = x.countl_one(); +// CHECK-NEXT: APInt r(x.getBitWidth(),r_autocast); +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: const APInt countr_one_test(const APInt &x){ +// CHECK-NEXT: unsigned r_autocast = x.countr_one(); +// CHECK-NEXT: APInt r(x.getBitWidth(),r_autocast); +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: const APInt neg_test(const APInt &x){ +// CHECK-NEXT: const APInt r = ~x; +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: const APInt clear_sign_bit_test(const APInt &x){ +// CHECK-NEXT: const APInt r = x; +// CHECK-NEXT: r.clearSignBit(); +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: const APInt set_sign_bit_test(const APInt &x){ +// CHECK-NEXT: const APInt r = x; +// CHECK-NEXT: r.setSignBit(); +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: const APInt get_all_ones_test(const APInt &x){ +// CHECK-NEXT: const APInt r = APInt::getAllOnes(x.getBitWidth()); +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: const APInt get_signed_max_value_test(const APInt &x){ +// CHECK-NEXT: const APInt r = APInt::getSignedMaxValue(x.getBitWidth()); +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: const APInt get_signed_min_value_test(const APInt &x){ +// CHECK-NEXT: const APInt r = APInt::getSignedMinValue(x.getBitWidth()); +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: const APInt reverse_bits_test(const APInt &x){ +// CHECK-NEXT: const APInt r = x.reverseBits(); +// CHECK-NEXT: return r; +// CHECK-NEXT: } diff --git a/xdsl_smt/cli/cpp_translate.py b/xdsl_smt/cli/cpp_translate.py index 29b13b26..7f63f69f 100644 --- a/xdsl_smt/cli/cpp_translate.py +++ b/xdsl_smt/cli/cpp_translate.py @@ -1,128 +1,87 @@ -#!/usr/bin/env python3 - import argparse -from typing import cast import sys +from pathlib import Path from xdsl.context import Context -from xdsl.ir import Operation +from xdsl.dialects.arith import Arith +from xdsl.dialects.builtin import Builtin, ModuleOp +from xdsl.dialects.func import Func, FuncOp from xdsl.parser import Parser -from xdsl.dialects.arith import Arith -from xdsl.dialects.func import Func -from xdsl_smt.dialects.transfer import Transfer from xdsl_smt.dialects.llvm_dialect import LLVM -from xdsl_smt.passes.transfer_lower import LowerToCpp, addDispatcher, addInductionOps -from xdsl.dialects.func import FuncOp, ReturnOp -from xdsl.dialects.builtin import ( - Builtin, - ModuleOp, - IntegerAttr, - StringAttr, -) - - -def register_all_arguments(arg_parser: argparse.ArgumentParser): - arg_parser.add_argument( - "transfer_functions", type=str, nargs="?", help="path to the transfer functions" - ) - - -def parse_file(ctx: Context, file: str | None) -> Operation: - if file is None: - f = sys.stdin - file = "" - else: - f = open(file) - - parser = Parser(ctx, f.read(), file) - module = parser.parse_op() - return module - +from xdsl_smt.dialects.transfer import Transfer +from xdsl_smt.passes.transfer_lower import lower_to_cpp -def is_transfer_function(func: FuncOp) -> bool: - return "applied_to" in func.attributes +def _register_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Translate MLIR code to C++") -def is_forward(func: FuncOp) -> bool: - if "is_forward" in func.attributes: - forward = func.attributes["is_forward"] - assert isinstance(forward, IntegerAttr) - return forward.value.data == 1 - return False + parser.add_argument( + "-i", + "--input", + type=Path, + default=None, + help="Path to the input MLIR file (defaults to stdin if omitted).", + ) + parser.add_argument( + "-o", + "--output", + type=Path, + default=None, + help="Path to the output MLIR file (defaults to stdout if omitted).", + ) + parser.add_argument( + "--apint", + action="store_true", + help="Use LLVM APInts for bitvector type lowering", + ) + parser.add_argument( + "--custom_vec", + action="store_true", + help="Use custom vec class for abstract value lowering", + ) + parser.add_argument( + "--llvm_kb", + action="store_true", + help="Use LLVM KnownBits for abstract value lowering", + ) + return parser.parse_args() -def getCounterexampleFunc(func: FuncOp) -> str | None: - if "soundness_counterexample" not in func.attributes: - return None - attr = func.attributes["soundness_counterexample"] - assert isinstance(attr, StringAttr) - return attr.data +def _parse_mlir_module(p: Path | None, ctx: Context) -> ModuleOp: + code = p.read_text() if p else sys.stdin.read() + fname = p.name if p else "" + mod = Parser(ctx, code, fname).parse_op() -def checkFunctionValidity(func: FuncOp) -> bool: - if len(func.function_type.inputs) != len(func.args): - return False - for func_type_arg, arg in zip(func.function_type.inputs, func.args): - if func_type_arg != arg.type: - return False - return_op = func.body.block.last_op - if not (return_op is not None and isinstance(return_op, ReturnOp)): - return False - return return_op.operands[0].type == func.function_type.outputs.data[0] + if isinstance(mod, ModuleOp): + return mod + elif isinstance(mod, FuncOp): + return ModuleOp([mod]) + else: + raise ValueError(f"mlir in '{fname}' is neither a ModuleOp, nor a FuncOp") def main() -> None: - ctx = Context() - arg_parser = argparse.ArgumentParser() - register_all_arguments(arg_parser) - args = arg_parser.parse_args() + args = _register_args() - # Register all dialects + ctx = Context() ctx.load_dialect(Arith) ctx.load_dialect(Builtin) ctx.load_dialect(Func) ctx.load_dialect(Transfer) ctx.load_dialect(LLVM) - # Parse the files - module = parse_file(ctx, args.transfer_functions) - assert isinstance(module, ModuleOp) - - allFuncMapping: dict[str, FuncOp] = {} - forward = False - counterexampleFuncs: set[str] = set() - with open("tmp.cpp", "w") as fout: - LowerToCpp.fout = fout - for func in module.ops: - if isinstance(func, FuncOp): - if is_transfer_function(func): - forward |= is_transfer_function(func) and is_forward(func) - counterexampleFunc = getCounterexampleFunc(func) - if counterexampleFunc is not None: - counterexampleFuncs.add(counterexampleFunc) - allFuncMapping[func.sym_name.data] = func - - # check function validity - if not checkFunctionValidity(func): - print(func.sym_name) - # check function validity + module = _parse_mlir_module(args.input, ctx) + output = args.output.open("w", encoding="utf-8") if args.output else sys.stdout - for counterexample in counterexampleFuncs: - assert counterexample in allFuncMapping - allFuncMapping[counterexample].detach() - del allFuncMapping[counterexample] - for func in module.ops: - if isinstance(func, FuncOp): - allFuncMapping[func.sym_name.data] = func - # HACK: we know the pass won't check that the operation is a module - LowerToCpp(fout).apply(ctx, cast(ModuleOp, func)) - addInductionOps(fout) - addDispatcher(fout, forward) + if args.custom_vec and args.llvm_kb: + raise ValueError("Cannot lower with both custom vectors and LLVM KnownBits") - # printer = Printer(target=Printer.Target.MLIR) - # printer.print_op(module) - - -if __name__ == "__main__": - main() + lower_to_cpp( + module, + output, + use_apint=args.apint, + use_custom_vec=args.custom_vec, + use_llvm_kb=args.llvm_kb, + ) diff --git a/xdsl_smt/dialects/transfer.py b/xdsl_smt/dialects/transfer.py index 28c89f68..45a0a8a1 100644 --- a/xdsl_smt/dialects/transfer.py +++ b/xdsl_smt/dialects/transfer.py @@ -287,6 +287,16 @@ class SAddOverflowOp(PredicateOp): name = "transfer.sadd_overflow" +@irdl_op_definition +class USubOverflowOp(PredicateOp): + name = "transfer.usub_overflow" + + +@irdl_op_definition +class SSubOverflowOp(PredicateOp): + name = "transfer.ssub_overflow" + + @irdl_op_definition class AndOp(BinOp): name = "transfer.and" @@ -785,6 +795,11 @@ class GetSignedMinValueOp(UnaryOp): name = "transfer.get_signed_min_value" +@irdl_op_definition +class GetLimitedValueOp(BinOp): + name = "transfer.get_limited_value" + + Transfer = Dialect( "transfer", [ @@ -829,6 +844,8 @@ class GetSignedMinValueOp(UnaryOp): SAddOverflowOp, UShlOverflowOp, SShlOverflowOp, + USubOverflowOp, + SSubOverflowOp, SelectOp, IsPowerOf2Op, IsAllOnesOp, @@ -845,6 +862,7 @@ class GetSignedMinValueOp(UnaryOp): AddPoisonOp, RemovePoisonOp, ReverseBitsOp, + GetLimitedValueOp, ], [TransIntegerType, AbstractValueType, TupleType], ) diff --git a/xdsl_smt/passes/transfer_lower.py b/xdsl_smt/passes/transfer_lower.py index 7a4416d3..b7862647 100644 --- a/xdsl_smt/passes/transfer_lower.py +++ b/xdsl_smt/passes/transfer_lower.py @@ -1,55 +1,41 @@ -from typing import TextIO -from xdsl.dialects.func import * -from xdsl.pattern_rewriter import * -from functools import singledispatch from dataclasses import dataclass -from xdsl.passes import ModulePass +from typing import TextIO +import sys +from xdsl.dialects.builtin import ModuleOp +from xdsl.dialects.func import FuncOp from xdsl.ir import Operation -from xdsl.context import Context -from ..utils.lower_utils import ( - lowerOperation, - CPP_CLASS_KEY, - lowerDispatcher, - INDUCTION_KEY, - lowerInductionOps, -) - from xdsl.pattern_rewriter import ( - RewritePattern, + GreedyRewritePatternApplier, PatternRewriter, - op_type_rewrite_pattern, PatternRewriteWalker, - GreedyRewritePatternApplier, + RewritePattern, + op_type_rewrite_pattern, ) -from xdsl.dialects import builtin - -autogen = 0 - - -@singledispatch -def transferFunction(op, fout): - pass +from ..utils.lower_utils import ( + lowerOperation, + set_use_apint, + set_use_custom_vec, + set_use_llvm_kb, +) +autogen = 0 funcStr = "" -indent = "\t" -needDispatch: list[FuncOp] = [] -inductionOp: list[FuncOp] = [] -@transferFunction.register -def _(op: Operation, fout): - global needDispatch - global inductionOp +def transfer_func(op: Operation, fout: TextIO): if isinstance(op, ModuleOp): return - # print(lowerDispatcher(needDispatch)) - # fout.write(lowerDispatcher(needDispatch)) if len(op.results) > 0 and op.results[0].name_hint is None: global autogen op.results[0].name_hint = "autogen" + str(autogen) autogen += 1 + if isinstance(op, FuncOp): + for arg in op.args: + if arg.name_hint is None: + arg.name_hint = "autogen" + str(autogen) + autogen += 1 global funcStr funcStr += lowerOperation(op) parentOp = op.parent_op() @@ -57,46 +43,36 @@ def _(op: Operation, fout): funcStr += "}\n" fout.write(funcStr) funcStr = "" - if isinstance(op, FuncOp): - if CPP_CLASS_KEY in op.attributes: - needDispatch.append(op) - if INDUCTION_KEY in op.attributes: - inductionOp.append(op) @dataclass class LowerOperation(RewritePattern): - def __init__(self, fout): + def __init__(self, fout: TextIO): self.fout = fout @op_type_rewrite_pattern - def match_and_rewrite(self, op: Operation, rewriter: PatternRewriter): - transferFunction(op, self.fout) - - -def addInductionOps(fout: TextIO): - global inductionOp - if len(inductionOp) != 0: - fout.write(lowerInductionOps(inductionOp)) - - -def addDispatcher(fout: TextIO, is_forward: bool): - global needDispatch - if len(needDispatch) != 0: - # print(lowerDispatcher(needDispatch)) - fout.write(lowerDispatcher(needDispatch, is_forward)) - - -@dataclass(frozen=True) -class LowerToCpp(ModulePass): - name = "trans_lower" - fout: TextIO = None - - def apply(self, ctx: Context, op: builtin.ModuleOp) -> None: - walker = PatternRewriteWalker( - GreedyRewritePatternApplier([LowerOperation(self.fout)]), - walk_regions_first=False, - apply_recursively=True, - walk_reverse=False, - ) - walker.rewrite_module(op) + def match_and_rewrite(self, op: Operation, _: PatternRewriter): + transfer_func(op, self.fout) + + +def lower_to_cpp( + op: ModuleOp, + fout: TextIO = sys.stdout, + use_apint: bool = False, + use_custom_vec: bool = False, + use_llvm_kb: bool = False, +) -> None: + global autogen + autogen = 0 + + # set options + set_use_apint(use_apint) + set_use_custom_vec(use_custom_vec) + set_use_llvm_kb(use_llvm_kb) + + PatternRewriteWalker( + GreedyRewritePatternApplier([LowerOperation(fout)]), + walk_regions_first=False, + apply_recursively=False, + walk_reverse=False, + ).rewrite_module(op) diff --git a/xdsl_smt/utils/lower_utils.py b/xdsl_smt/utils/lower_utils.py index 86e1c4e5..7eebef6b 100644 --- a/xdsl_smt/utils/lower_utils.py +++ b/xdsl_smt/utils/lower_utils.py @@ -1,52 +1,63 @@ +from functools import singledispatch +from typing import Callable + +import xdsl.dialects.arith as arith +from xdsl.dialects.builtin import IndexType, IntegerType +from xdsl.dialects.func import CallOp, FuncOp, ReturnOp +from xdsl.ir import Attribute, Block, BlockArgument, Operation, SSAValue + from ..dialects.transfer import ( AbstractValueType, - GetOp, - MakeOp, - NegOp, - Constant, + AddPoisonOp, + AShrOp, + ClearHighBitsOp, + ClearLowBitsOp, + ClearSignBitOp, CmpOp, - AndOp, - OrOp, - XorOp, - AddOp, - SubOp, + ConcatOp, + Constant, + ConstRangeForOp, CountLOneOp, CountLZeroOp, CountROneOp, CountRZeroOp, - SetHighBitsOp, - SetLowBitsOp, - GetLowBitsOp, - GetBitWidthOp, - UMulOverflowOp, - SMinOp, - SMaxOp, - UMinOp, - UMaxOp, - TransIntegerType, - ShlOp, - AShrOp, - LShrOp, ExtractOp, - ConcatOp, GetAllOnesOp, - SelectOp, + GetBitWidthOp, + GetHighBitsOp, + GetLowBitsOp, + GetOp, + GetSignedMaxValueOp, + GetSignedMinValueOp, + IntersectsOp, + LShrOp, + MakeOp, + NegOp, NextLoopOp, - ConstRangeForOp, + RemovePoisonOp, RepeatOp, - IntersectsOp, - # FromArithOp, + SAddOverflowOp, + SDivOp, + SelectOp, + SetHighBitsOp, + SetLowBitsOp, + SetSignBitOp, + ShlOp, + SMaxOp, + SMinOp, + SMulOverflowOp, + SRemOp, + SShlOverflowOp, + TransIntegerType, TupleType, - AddPoisonOp, - RemovePoisonOp, + UAddOverflowOp, + UDivOp, + UMaxOp, + UMinOp, + UMulOverflowOp, + URemOp, + UShlOverflowOp, ) -from xdsl.dialects.func import FuncOp, Return, Call -from xdsl.pattern_rewriter import * -from functools import singledispatch -from typing import TypeVar, cast -from xdsl.dialects.builtin import Signedness, IntegerType, IndexType, IntegerAttr -from xdsl.ir import Operation -import xdsl.dialects.arith as arith operNameToCpp = { "transfer.and": "&", @@ -62,15 +73,29 @@ "arith.subi": "-", "transfer.neg": "~", "transfer.mul": "*", + "transfer.udiv": ".udiv", + "transfer.sdiv": ".sdiv", + "transfer.urem": ".urem", + "transfer.srem": ".srem", "transfer.umul_overflow": ".umul_ov", + "transfer.smul_overflow": ".smul_ov", + "transfer.uadd_overflow": ".uadd_ov", + "transfer.sadd_overflow": ".sadd_ov", + "transfer.ushl_overflow": ".ushl_ov", + "transfer.sshl_overflow": ".sshl_ov", "transfer.get_bit_width": ".getBitWidth", "transfer.countl_zero": ".countl_zero", "transfer.countr_zero": ".countr_zero", "transfer.countl_one": ".countl_one", "transfer.countr_one": ".countr_one", + "transfer.get_high_bits": ".getHiBits", "transfer.get_low_bits": ".getLoBits", "transfer.set_high_bits": ".setHighBits", "transfer.set_low_bits": ".setLowBits", + "transfer.clear_high_bits": ".clearHighBits", + "transfer.clear_low_bits": ".clearLowBits", + "transfer.set_sign_bit": ".setSignBit", + "transfer.clear_sign_bit": ".clearSignBit", "transfer.intersects": ".intersects", "transfer.cmp": [ ".eq", @@ -84,7 +109,6 @@ ".ugt", ".uge", ], - # "transfer.fromArith": "APInt", "transfer.make": "{{{0}}}", "transfer.get": "[{0}]", "transfer.shl": ".shl", @@ -92,218 +116,233 @@ "transfer.lshr": ".lshr", "transfer.concat": ".concat", "transfer.extract": ".extractBits", - "transfer.umin": [".ule", "?", ":"], - "transfer.smin": [".sle", "?", ":"], - "transfer.umax": [".ugt", "?", ":"], - "transfer.smax": [".sgt", "?", ":"], + "transfer.umin": "A::APIntOps::umin", + "transfer.smin": "A::APIntOps::smin", + "transfer.umax": "A::APIntOps::umax", + "transfer.smax": "A::APIntOps::smax", "func.return": "return", "transfer.constant": "APInt", "arith.select": ["?", ":"], "arith.cmpi": ["==", "!=", "<", "<=", ">", ">="], "transfer.get_all_ones": "APInt::getAllOnes", + "transfer.get_signed_max_value": "APInt::getSignedMaxValue", + "transfer.get_signed_min_value": "APInt::getSignedMinValue", "transfer.select": ["?", ":"], "transfer.reverse_bits": ".reverseBits", "transfer.add_poison": " ", "transfer.remove_poison": " ", + "comb.add": "+", + "comb.sub": "-", + "comb.mul": "*", + "comb.and": "&", + "comb.or": "|", + "comb.xor": "^", + "comb.divs": ".sdiv", + "comb.divu": ".udiv", + "comb.mods": ".srem", + "comb.modu": ".urem", + "comb.mux": ["?", ":"], + "comb.shrs": ".ashr", + "comb.shru": ".lshr", + "comb.shl": ".shl", + "comb.extract": ".extractBits", + "comb.concat": ".concat", } # transfer.constRangeLoop and NextLoop are controller operations, should be handle specially -unsignedReturnedType = { - CountLOneOp, - CountLZeroOp, - CountROneOp, - CountRZeroOp, - GetBitWidthOp, +# consts +EQ = " = " +END = ";\n" +IDNT = "\t" + +VAL_EXCEEDS_BW = "{1}.uge({1}.getBitWidth())" +RHS_IS_ZERO = "{1} == 0" +RET_ZERO = "{0} = APInt({1}.getBitWidth(), 0)" +RET_ONE = "{0} = APInt({1}.getBitWidth(), 1)" +RET_ONES = "{0} = APInt({1}.getBitWidth(), -1)" +RET_SIGN_MIN_VAL = "{0} = APInt::getSignedMinValue({1}.getBitWidth())" +RET_LHS = "{0} = {1}" + +SHIFT_ACTION = (VAL_EXCEEDS_BW, RET_ZERO) +ASHR_ACTION0 = VAL_EXCEEDS_BW + " && {0}.isSignBitSet()", RET_ONES +ASHR_ACTION1 = VAL_EXCEEDS_BW + " && {0}.isSignBitClear()", RET_ZERO +REM_ACTION = RHS_IS_ZERO, RET_LHS +DIV_ACTION = RHS_IS_ZERO, RET_ONES +SDIV_ACTION0 = ("{0}.isMinSignedValue() && {1} == -1", RET_SIGN_MIN_VAL) +SDIV_ACTION1 = (RHS_IS_ZERO + " && {0}.isNonNegative()", RET_ONES) +SDIV_ACTION2 = (RHS_IS_ZERO + " && {0}.isNegative()", RET_ONE) + +op_to_cons: dict[type[Operation], list[tuple[str, str]]] = { + ShlOp: [SHIFT_ACTION], + LShrOp: [SHIFT_ACTION], + UDivOp: [DIV_ACTION], + URemOp: [REM_ACTION], + SRemOp: [REM_ACTION], + AShrOp: [ASHR_ACTION0, ASHR_ACTION1], + SDivOp: [SDIV_ACTION0, SDIV_ACTION1, SDIV_ACTION2], } -ends = ";\n" -indent = "\t" +# lowering config +use_apint = False +use_custom_vec = False +use_llvm_kb = False + + +def set_use_apint(f: bool) -> None: + global use_apint + use_apint = f + + +def set_use_custom_vec(f: bool) -> None: + global use_custom_vec + use_custom_vec = f + + +def set_use_llvm_kb(f: bool) -> None: + global use_llvm_kb + use_llvm_kb = f + + +# helpers +def get_ret_val(op: Operation) -> str: + ret_val = op.results[0].name_hint + assert ret_val + return ret_val + + +def get_op_names(op: Operation) -> list[str]: + return [oper.name_hint for oper in op.operands if oper.name_hint] + + +def get_operand(op: Operation, idx: int) -> str: + name = op.operands[idx].name_hint + assert name + return name + +def get_op_str(op: Operation) -> str: + op_name = operNameToCpp[op.name] + assert isinstance(op_name, str) + return op_name + + +def lowerType(typ: Attribute, specialOp: Operation | Block | None = None) -> str: + unsigned_ret_type = { + CountLOneOp, + CountLZeroOp, + CountROneOp, + CountRZeroOp, + GetBitWidthOp, + } -def lowerType(typ, specialOp=None): if specialOp is not None: - for op in unsignedReturnedType: + for op in unsigned_ret_type: if isinstance(specialOp, op): return "unsigned" - if isinstance(typ, TransIntegerType): - return "APInt" + + if isinstance(typ, TransIntegerType) or ( + isinstance(typ, IntegerType) and use_apint + ): + return "const APInt" elif isinstance(typ, AbstractValueType) or isinstance(typ, TupleType): fields = typ.get_fields() typeName = lowerType(fields[0]) for i in range(1, len(fields)): assert lowerType(fields[i]) == typeName - return "std::vector<" + typeName + ">" - elif isinstance(typ, IntegerType): - return "int" - elif isinstance(typ, IndexType): - return "int" - assert False and "unsupported type" - - -CPP_CLASS_KEY = "CPPCLASS" -INDUCTION_KEY = "induction" -OPERATION_NO = "operationNo" - - -def lowerInductionOps(inductionOp: list[FuncOp]): - if len(inductionOp) > 0: - functionSignature = """ -{returnedType} {funcName}(ArrayRef<{returnedType}> operands){{ - {returnedType} result={funcName}(operands[0], operands[1]); - for(int i=2;i 0: - returnedType = needDispatch[0].function_type.outputs.data[0] - for func in needDispatch: - if func.function_type.outputs.data[0] != returnedType: - print(func) - print(func.function_type.outputs.data[0]) - assert ( - "we assume all transfer functions have the same returned type" - and False - ) - returnedType = lowerType(returnedType) - funcName = "naiveDispatcher" - # we assume all operands have the same type as expr - # User should tell the generator all operands - if is_forward: - expr = "(Operation* op, std::vector> operands)" + + if use_custom_vec: + return "Vec<" + str(len(fields)) + ">" + elif use_llvm_kb: + assert len(fields) == 2 + return "const llvm::KnownBits" else: - expr = "(Operation* op, std::vector> operands, unsigned operationNo)" - functionSignature = ( - "std::optional<" + returnedType + "> " + funcName + expr + "{{\n{0}}}\n\n" - ) - indent = "\t" - dyn_cast = ( - indent - + "if(auto castedOp=dyn_cast<{0}>(op);castedOp&&{1}){{\n{2}" - + indent - + "}}\n" - ) - return_inst = indent + indent + "return {0}({1});\n" - - def handleOneTransferFunction(func: FuncOp, operationNo: int) -> str: - blockStr = "" - for cppClass in func.attributes[CPP_CLASS_KEY]: - argStr = "" - if INDUCTION_KEY in func.attributes: - argStr = "operands" - else: - if len(func.args) > 0: - argStr = "operands[0]" - for i in range(1, len(func.args)): - argStr += ", operands[" + str(i) + "]" - ifBody = return_inst.format(func.sym_name.data, argStr) - if operationNo == -1: - operationNoStr = "true" - else: - operationNoStr = "operationNo == " + str(operationNo) - blockStr += dyn_cast.format(cppClass.data, operationNoStr, ifBody) - return blockStr - - funcBody = "" - for func in needDispatch: - if is_forward: - funcBody += handleOneTransferFunction(func) - else: - operationNo = func.attributes[OPERATION_NO] - assert isinstance(operationNo, IntegerAttr) - funcBody += handleOneTransferFunction(func, operationNo.value.data) - funcBody += indent + "return {};\n" - return functionSignature.format(funcBody) - - -def isFunctionCall(opName): - return opName[0] == "." + return "std::vector<" + typeName + ">" + elif isinstance(typ, IndexType) or isinstance(typ, IntegerType): + return "int" + raise ValueError(f"unsupported type: {type(typ)}") -def lowerToNonClassMethod(op: Operation): - returnedType = lowerType(op.results[0].type, op) - returnedValue = op.results[0].name_hint - equals = "=" - expr = "(" - if len(op.operands) > 0: - expr += op.operands[0].name_hint - for i in range(1, len(op.operands)): - expr += "," + op.operands[i].name_hint - expr += ")" - return ( - indent - + returnedType - + " " - + returnedValue - + equals - + operNameToCpp[op.name] - + expr - + ends - ) +def isFunctionCall(opName: str) -> bool: + return opName[0] == "." + + +def lowerToClassMethod( + op: Operation, + castOperand: Callable[[SSAValue | str], str] | None = None, + castResult: Callable[[Operation], str] | None = None, +) -> str: + ret_ty = lowerType(op.results[0].type, op) + ret_val = get_ret_val(op) -def lowerToClassMethod(op: Operation, castOperand=None, castResult=None): - returnedType = lowerType(op.results[0].type, op) if castResult is not None: - returnedValue = op.results[0].name_hint + "_autocast" - else: - returnedValue = op.results[0].name_hint - equals = "=" - expr = op.operands[0].name_hint + operNameToCpp[op.name] + "(" + ret_val += "_autocast" + expr = get_operand(op, 0) + get_op_str(op) + "(" + if castOperand is not None: operands = [castOperand(operand) for operand in op.operands] else: - operands = [operand.name_hint for operand in op.operands] + operands = get_op_names(op) + if len(operands) > 1: expr += operands[1] for i in range(2, len(operands)): expr += "," + operands[i] + expr += ")" - result = indent + returnedType + " " + returnedValue + equals + expr + ends + + if type(op) in op_to_cons: + conds, actions = zip(*op_to_cons[type(op)]) # type: ignore + + og_op_names = get_op_names(op) + conds: list[str] = [cond.format(*og_op_names) for cond in conds] + actions: list[str] = [act.format(ret_val, *og_op_names) for act in actions] + + if_fmt = "if ({cond}) {{\n" + IDNT + IDNT + "{act}" + END + IDNT + "}}" + + ifs = " else ".join( + [if_fmt.format(cond=c, act=a) for c, a in zip(conds, actions)] + ) + + final_else_br = IDNT + IDNT + ret_val + EQ + expr + END + + result = IDNT + ret_ty + " " + ret_val + END + result += IDNT + ifs + " else {\n" + final_else_br + IDNT + "}\n" + + else: + result = IDNT + ret_ty + " " + ret_val + EQ + expr + END + if castResult is not None: return result + castResult(op) + return result @singledispatch -def lowerOperation(op): +def lowerOperation(op: Operation) -> str: returnedType = lowerType(op.results[0].type, op) - returnedValue = op.results[0].name_hint - equals = "=" - operandsName = [oper.name_hint for oper in op.operands] - if isFunctionCall(operNameToCpp[op.name]): - expr = operandsName[0] + operNameToCpp[op.name] + "(" + returnedValue = get_ret_val(op) + operandsName = get_op_names(op) + op_str = get_op_str(op) + + if isFunctionCall(op_str): + expr = operandsName[0] + op_str + "(" if len(operandsName) > 1: expr += operandsName[1] for i in range(2, len(operandsName)): expr += "," + operandsName[i] expr += ")" else: - expr = operandsName[0] + operNameToCpp[op.name] + operandsName[1] - result = indent + returnedType + " " + returnedValue + equals + expr + ends - return result + expr = operandsName[0] + op_str + operandsName[1] + + return IDNT + returnedType + " " + returnedValue + EQ + expr + END @lowerOperation.register def _(op: CmpOp): returnedType = lowerType(op.results[0].type) - returnedValue = op.results[0].name_hint - equals = "=" - operandsName = [oper.name_hint for oper in op.operands] + returnedValue = get_ret_val(op) + operandsName = get_op_names(op) predicate = op.predicate.value.data operName = operNameToCpp[op.name][predicate] expr = operandsName[0] + operName + "(" @@ -312,220 +351,322 @@ def _(op: CmpOp): for i in range(2, len(operandsName)): expr += "," + operandsName[i] expr += ")" - return indent + returnedType + " " + returnedValue + equals + expr + ends + + return IDNT + returnedType + " " + returnedValue + EQ + expr + END @lowerOperation.register -def _(op: arith.Cmpi): +def _(op: arith.CmpiOp): returnedType = lowerType(op.results[0].type) - returnedValue = op.results[0].name_hint - equals = "=" - operandsName = [oper.name_hint for oper in op.operands] + returnedValue = get_ret_val(op) + operandsName = get_op_names(op) assert len(operandsName) == 2 predicate = op.predicate.value.data operName = operNameToCpp[op.name][predicate] - expr = "(" + operandsName[0] + operName + operandsName[1] - expr += ")" - return indent + returnedType + " " + returnedValue + equals + expr + ends + expr = "(" + operandsName[0] + operName + operandsName[1] + ")" + + return IDNT + returnedType + " " + returnedValue + EQ + expr + END @lowerOperation.register -def _(op: arith.Select): +def _(op: arith.SelectOp): returnedType = lowerType(op.operands[1].type, op) - returnedValue = op.results[0].name_hint - equals = "=" - operandsName = [oper.name_hint for oper in op.operands] + returnedValue = get_ret_val(op) + operandsName = get_op_names(op) operator = operNameToCpp[op.name] expr = "" for i in range(len(operandsName)): expr += operandsName[i] + " " if i < len(operator): expr += operator[i] + " " - return indent + returnedType + " " + returnedValue + equals + expr + ends + + return IDNT + returnedType + " " + returnedValue + EQ + expr + END @lowerOperation.register def _(op: SelectOp): returnedType = lowerType(op.operands[1].type, op) - returnedValue = op.results[0].name_hint - equals = "=" - operandsName = [oper.name_hint for oper in op.operands] + returnedValue = get_ret_val(op) + operandsName = get_op_names(op) operator = operNameToCpp[op.name] expr = "" for i in range(len(operandsName)): expr += operandsName[i] + " " if i < len(operator): expr += operator[i] + " " - return indent + returnedType + " " + returnedValue + equals + expr + ends + + return IDNT + returnedType + " " + returnedValue + EQ + expr + END @lowerOperation.register -def _(op: GetOp): +def _(op: GetOp) -> str: returnedType = lowerType(op.results[0].type) - returnedValue = op.results[0].name_hint - equals = "=" - index = op.attributes["index"].value.data - return ( - indent - + returnedType - + " " - + returnedValue - + equals - + op.operands[0].name_hint - + operNameToCpp[op.name].format(index) - + ends - ) + returnedValue = get_ret_val(op) + index = op.attributes["index"].value.data # type: ignore + + if use_llvm_kb: + return ( + IDNT + + returnedType + + " " + + returnedValue + + EQ + + get_operand(op, 0) + + (".Zero" if index == 0 else ".One") + + END + ) + + else: + return ( + IDNT + + returnedType + + " " + + returnedValue + + EQ + + get_operand(op, 0) + + get_op_str(op).format(index) # type: ignore + + END + ) @lowerOperation.register -def _(op: MakeOp): +def _(op: MakeOp) -> str: + returnedValue = get_ret_val(op) + + if use_llvm_kb and isinstance(op.results[0].type, AbstractValueType): + s = f"{IDNT}llvm::KnownBits {returnedValue}{END}" + s += f"{IDNT}{returnedValue}.Zero = {get_operand(op, 0)}{END}" + s += f"{IDNT}{returnedValue}.One = {get_operand(op, 1)}{END}" + return s + returnedType = lowerType(op.results[0].type, op) - returnedValue = op.results[0].name_hint - equals = "=" expr = "" if len(op.operands) > 0: - expr += op.operands[0].name_hint + expr += get_operand(op, 0) for i in range(1, len(op.operands)): - expr += "," + op.operands[i].name_hint + expr += "," + get_operand(op, i) + return ( - indent + IDNT + returnedType + " " + returnedValue - + equals + + EQ + returnedType - + operNameToCpp[op.name].format(expr) - + ends + + get_op_str(op).format(expr) + + END ) +def trivial_overflow_predicate(op: Operation) -> str: + returnedValue = get_ret_val(op) + varDecls = "bool " + returnedValue + END + expr = get_operand(op, 0) + get_op_str(op) + "(" + expr += get_operand(op, 1) + "," + returnedValue + ")" + result = varDecls + IDNT + expr + END + return IDNT + result + + @lowerOperation.register def _(op: UMulOverflowOp): - varDecls = "bool " + op.results[0].name_hint + ends - expr = op.operands[0].name_hint + operNameToCpp[op.name] + "(" - expr += op.operands[1].name_hint + "," + op.results[0].name_hint - expr += ")" - result = varDecls + "\t" + expr + ends - return indent + result + return trivial_overflow_predicate(op) @lowerOperation.register -def _(op: NegOp): - returnedType = lowerType(op.results[0].type) - returnedValue = op.results[0].name_hint - equals = "=" - return ( - indent - + returnedType - + " " - + returnedValue - + equals - + operNameToCpp[op.name] - + op.operands[0].name_hint - + ends - ) +def _(op: SMulOverflowOp): + return trivial_overflow_predicate(op) @lowerOperation.register -def _(op: Return): - opName = operNameToCpp[op.name] + " " - operand = op.arguments[0].name_hint - return indent + opName + operand + ends +def _(op: UAddOverflowOp): + return trivial_overflow_predicate(op) -""" @lowerOperation.register -def _(op: FromArithOp): - opTy = op.op.type - assert isinstance(opTy, IntegerType) - size = opTy.width.data - returnedType = "APInt" - returnedValue = op.results[0].name_hint - return ( - indent - + returnedType - + " " - + returnedValue - + "(" - + str(size) - + ", " - + op.op.name_hint - + ")" - + ends - ) -""" +def _(op: SAddOverflowOp): + return trivial_overflow_predicate(op) @lowerOperation.register -def _(op: arith.Constant): - value = op.value.value.data +def _(op: SShlOverflowOp): + return trivial_overflow_predicate(op) + + +@lowerOperation.register +def _(op: UShlOverflowOp): + return trivial_overflow_predicate(op) + + +@lowerOperation.register +def _(op: NegOp) -> str: + ret_type = lowerType(op.results[0].type) + ret_val = get_ret_val(op) + op_str = get_op_str(op) + operand = get_operand(op, 0) + + return IDNT + ret_type + " " + ret_val + EQ + op_str + operand + END + + +@lowerOperation.register +def _(op: ReturnOp) -> str: + opName = get_op_str(op) + " " + operand = op.arguments[0].name_hint + assert operand + + return IDNT + opName + operand + END + + +@lowerOperation.register +def _(op: arith.ConstantOp): + value = op.value.value.data # type: ignore + assert isinstance(value, int) or isinstance(value, float) assert isinstance(op.results[0].type, IntegerType) size = op.results[0].type.width.data + max_val_plus_one = 1 << size returnedType = "int" - if value > ((1 << 31) - 1): + if value >= (1 << 31): assert False and "arith constant overflow maximal int" - returnedValue = op.results[0].name_hint - return indent + returnedType + " " + returnedValue + " = " + str(value) + ends + returnedValue = get_ret_val(op) + return ( + IDNT + + returnedType + + " " + + returnedValue + + EQ + + str((value + max_val_plus_one) % max_val_plus_one) + + END + ) @lowerOperation.register def _(op: Constant): value = op.value.value.data returnedType = lowerType(op.results[0].type) - returnedValue = op.results[0].name_hint + returnedValue = get_ret_val(op) return ( - indent + IDNT + returnedType + " " + returnedValue + "(" - + op.operands[0].name_hint + + get_operand(op, 0) + ".getBitWidth()," + str(value) + ")" - + ends + + END ) @lowerOperation.register def _(op: GetAllOnesOp): + ret_type = lowerType(op.results[0].type) + ret_val = get_ret_val(op) + op_name = get_op_str(op) + + return ( + IDNT + + ret_type + + " " + + ret_val + + EQ + + op_name + + "(" + + get_operand(op, 0) + + ".getBitWidth()" + + ")" + + END + ) + + +@lowerOperation.register +def _(op: GetSignedMaxValueOp): + ret_type = lowerType(op.results[0].type) + ret_val = get_ret_val(op) + op_name = get_op_str(op) + + return ( + IDNT + + ret_type + + " " + + ret_val + + EQ + + op_name + + "(" + + get_operand(op, 0) + + ".getBitWidth()" + + ")" + + END + ) + + +@lowerOperation.register +def _(op: GetSignedMinValueOp): returnedType = lowerType(op.results[0].type) - returnedValue = op.results[0].name_hint - opName = operNameToCpp[op.name] + returnedValue = get_ret_val(op) + op_name = get_op_str(op) + return ( - indent + IDNT + returnedType + " " + returnedValue - + " = " - + opName + + EQ + + op_name + "(" - + op.operands[0].name_hint + + get_operand(op, 0) + ".getBitWidth()" + ")" - + ends + + END ) @lowerOperation.register -def _(op: Call): +def _(op: CallOp): returnedType = lowerType(op.results[0].type) - returnedValue = op.results[0].name_hint + returnedValue = get_ret_val(op) callee = op.callee.string_value() + "(" - operandsName = [oper.name_hint for oper in op.operands] + operandsName = get_op_names(op) expr = "" if len(operandsName) > 0: expr += operandsName[0] for i in range(1, len(operandsName)): expr += "," + operandsName[i] expr += ")" - return indent + returnedType + " " + returnedValue + "=" + callee + expr + ends + return IDNT + returnedType + " " + returnedValue + EQ + callee + expr + END + + +def set_clear_bits( + op: SetHighBitsOp | SetLowBitsOp | ClearHighBitsOp | ClearLowBitsOp, +) -> str: + ret_ty = lowerType(op.results[0].type, op) + ret_val = get_ret_val(op) + arg = get_operand(op, 0) + count = get_operand(op, 1) + op_str = get_op_str(op) + + set_val = f"{IDNT}{ret_ty} {ret_val} = {arg};\n" + cond = f"{count}.ule({count}.getBitWidth())" + if_br = f"{IDNT}{IDNT}{ret_val}{op_str}({count}.getZExtValue());\n" + el_br = f"{IDNT}{IDNT}{ret_val}{op_str}({count}.getBitWidth());\n" + + return f"{set_val}{IDNT}if ({cond})\n{if_br}{IDNT}else\n{el_br}" @lowerOperation.register def _(op: FuncOp): - def lowerArgs(arg): - return lowerType(arg.type) + " " + arg.name_hint + def lowerArgs(arg: BlockArgument) -> str: + global use_apint + assert arg.name_hint + s = f"{lowerType(arg.type)} {arg.name_hint}" + if ( + isinstance(arg.type, AbstractValueType) + or isinstance(arg.type, TupleType) + or isinstance(arg.type, TransIntegerType) + or (isinstance(arg.type, IntegerType) and use_apint) + ): + s = f"{lowerType(arg.type)} &{arg.name_hint}" + + return s returnedType = lowerType(op.function_type.outputs.data[0]) funcName = op.sym_name.data @@ -535,21 +676,23 @@ def lowerArgs(arg): for i in range(1, len(op.args)): expr += "," + lowerArgs(op.args[i]) expr += ")" - # return returnedType + " " + funcName + expr + "{{\n{0}}}\n\n" - return returnedType + " " + funcName + expr + "{\n" + return returnedType + " " + funcName + expr + "{\n" # } -def castToAPIntFromUnsigned(op: Operation): - lastReturn = op.results[0].name_hint + "_autocast" + +def castToAPIntFromUnsigned(op: Operation) -> str: + returnedValue = get_ret_val(op) + lastReturn = returnedValue + "_autocast" apInt = None for operand in op.operands: if isinstance(operand.type, TransIntegerType): apInt = operand.name_hint break returnedType = "APInt" - returnedValue = op.results[0].name_hint + assert apInt + return ( - indent + IDNT + returnedType + " " + returnedValue @@ -558,10 +701,30 @@ def castToAPIntFromUnsigned(op: Operation): + ".getBitWidth()," + lastReturn + ")" - + ends + + END ) +@lowerOperation.register +def _(op: SDivOp): + return lowerToClassMethod(op, None, None) + + +@lowerOperation.register +def _(op: UDivOp): + return lowerToClassMethod(op, None, None) + + +@lowerOperation.register +def _(op: SRemOp): + return lowerToClassMethod(op, None, None) + + +@lowerOperation.register +def _(op: URemOp): + return lowerToClassMethod(op, None, None) + + @lowerOperation.register def _(op: IntersectsOp): return lowerToClassMethod(op, None, None) @@ -587,38 +750,46 @@ def _(op: CountRZeroOp): return lowerToClassMethod(op, None, castToAPIntFromUnsigned) -def castToUnisgnedFromAPInt(operand): - if isinstance(operand.type, TransIntegerType): - return operand.name_hint + ".getZExtValue()" - return operand.name_hint +def castToUnisgnedFromAPInt(operand: SSAValue | str) -> str: + if isinstance(operand, str): + return "(" + operand + ").getZExtValue()" + elif isinstance(operand.type, TransIntegerType): + return f"{operand.name_hint}.getZExtValue()" + + return str(operand.name_hint) @lowerOperation.register -def _(op: SetHighBitsOp): +def _(op: SetHighBitsOp | SetLowBitsOp | ClearHighBitsOp | ClearLowBitsOp): + return set_clear_bits(op) + + +@lowerOperation.register +def _(op: SetSignBitOp): returnedType = lowerType(op.results[0].type, op) - returnedValue = op.results[0].name_hint - equals = "=" + op.operands[0].name_hint + ends + "\t" - expr = op.results[0].name_hint + operNameToCpp[op.name] + "(" - operands = op.operands[1].name_hint + ".getZExtValue()" + returnedValue = get_ret_val(op) + equals = EQ + get_operand(op, 0) + END + IDNT + expr = returnedValue + get_op_str(op) + "(" + operands = "" expr = expr + operands + ")" - result = returnedType + " " + returnedValue + equals + expr + ends - return indent + result + + return IDNT + returnedType + " " + returnedValue + equals + expr + END @lowerOperation.register -def _(op: SetLowBitsOp): +def _(op: ClearSignBitOp): returnedType = lowerType(op.results[0].type, op) - returnedValue = op.results[0].name_hint - equals = "=" + op.operands[0].name_hint + ends + "\t" - expr = op.results[0].name_hint + operNameToCpp[op.name] + "(" - operands = op.operands[1].name_hint + ".getZExtValue()" + returnedValue = get_ret_val(op) + equals = EQ + get_operand(op, 0) + END + IDNT + expr = returnedValue + get_op_str(op) + "(" + operands = "" expr = expr + operands + ")" - result = returnedType + " " + returnedValue + equals + expr + ends - return indent + result + + return IDNT + returnedType + " " + returnedValue + equals + expr + END @lowerOperation.register -def _(op: GetLowBitsOp): +def _(op: GetLowBitsOp | GetHighBitsOp): return lowerToClassMethod(op, castToUnisgnedFromAPInt) @@ -627,112 +798,24 @@ def _(op: GetBitWidthOp): return lowerToClassMethod(op, None, castToAPIntFromUnsigned) -# op1 < op2? op1: op2 @lowerOperation.register -def _(op: SMaxOp): - returnedType = lowerType(op.operands[0].type, op) - returnedValue = op.results[0].name_hint - operands = [operand.name_hint for operand in op.operands] - operator = operNameToCpp[op.name] - equals = "=" - expr = ( - operands[0] - + operator[0] - + "(" - + operands[1] - + ")" - + operator[1] - + operands[0] - + operator[2] - + operands[1] - ) - result = returnedType + " " + returnedValue + equals + expr + ends - return indent + result +def _(op: SMaxOp | SMinOp | UMaxOp | UMinOp): + return lower_min_max(op) -@lowerOperation.register -def _(op: SMinOp): +def lower_min_max(op: UMinOp | UMaxOp | SMinOp | SMaxOp) -> str: returnedType = lowerType(op.operands[0].type, op) - returnedValue = op.results[0].name_hint - operands = [operand.name_hint for operand in op.operands] - operator = operNameToCpp[op.name] - equals = "=" - expr = ( - operands[0] - + operator[0] - + "(" - + operands[1] - + ")" - + operator[1] - + operands[0] - + operator[2] - + operands[1] - ) - result = returnedType + " " + returnedValue + equals + expr + ends - return indent + result + returnedValue = get_ret_val(op) + operands = get_op_names(op) + operator = get_op_str(op) + expr = operator + "(" + operands[0] + "," + operands[1] + ")" -@lowerOperation.register -def _(op: UMaxOp): - returnedType = lowerType(op.operands[0].type, op) - returnedValue = op.results[0].name_hint - operands = [operand.name_hint for operand in op.operands] - operator = operNameToCpp[op.name] - equals = "=" - expr = ( - operands[0] - + operator[0] - + "(" - + operands[1] - + ")" - + operator[1] - + operands[0] - + operator[2] - + operands[1] - ) - result = returnedType + " " + returnedValue + equals + expr + ends - return indent + result + return IDNT + returnedType + " " + returnedValue + EQ + expr + END @lowerOperation.register -def _(op: UMinOp): - returnedType = lowerType(op.operands[0].type, op) - returnedValue = op.results[0].name_hint - operands = [operand.name_hint for operand in op.operands] - operator = operNameToCpp[op.name] - equals = "=" - expr = ( - operands[0] - + operator[0] - + "(" - + operands[1] - + ")" - + operator[1] - + operands[0] - + operator[2] - + operands[1] - ) - result = returnedType + " " + returnedValue + equals + expr + ends - return indent + result - - -@lowerOperation.register -def _(op: ShlOp): - return lowerToClassMethod(op, castToUnisgnedFromAPInt) - - -@lowerOperation.register -def _(op: AShrOp): - return lowerToClassMethod(op, castToUnisgnedFromAPInt) - - -@lowerOperation.register -def _(op: LShrOp): - return lowerToClassMethod(op, castToUnisgnedFromAPInt) - - -@lowerOperation.register -def _(op: ExtractOp): +def _(op: ShlOp | AShrOp | LShrOp | ExtractOp): return lowerToClassMethod(op, castToUnisgnedFromAPInt) @@ -751,111 +834,91 @@ def _(op: ConstRangeForOp): indvar, *block_iter_args = loopBody.args iter_args = op.iter_args - global indent loopBefore = "" for i, blk_arg in enumerate(block_iter_args): iter_type = lowerType(iter_args[i].type, iter_args[i].owner) iter_name = blk_arg.name_hint - loopBefore += ( - indent + iter_type + " " + iter_name + " = " + iter_args[i].name_hint + ends - ) + iter_arg = iter_args[i].name_hint + assert iter_name + assert iter_arg + + loopBefore += IDNT + iter_type + " " + iter_name + EQ + iter_arg + END - loopFor = indent + "for(APInt {0} = {1}; {0}.ule({2}); {0}+={3}){{\n".format( + loopFor = IDNT + "for(APInt {0} = {1}; {0}.ule({2}); {0}+={3}){{\n".format( indvar.name_hint, lowerBound, upperBound, step ) - indent += "\t" - """ - mainLoop="" - for loopOp in loopBody.ops: - mainLoop+=(indent + indent+ lowerOperation(loopOp)) - endLoopFor=indent+"}\n" - """ + return loopBefore + loopFor @lowerOperation.register -def _(op: NextLoopOp): +def _(op: NextLoopOp) -> str: loopBlock = op.parent_block() - indvar, *block_iter_args = loopBlock.args - global indent + assert loopBlock + _, *block_iter_args = loopBlock.args assignments = "" for i, arg in enumerate(op.operands): - assignments += ( - indent + block_iter_args[i].name_hint + " = " + arg.name_hint + ends - ) - indent = indent[:-1] - endLoopFor = indent + "}\n" + block_arg = block_iter_args[i].name_hint + arg_name = arg.name_hint + assert block_arg + assert arg_name + + assignments += IDNT + block_arg + EQ + arg_name + END + + endLoopFor = IDNT + "}\n" loopOp = loopBlock.parent_op() + assert loopOp + for i, res in enumerate(loopOp.results): - endLoopFor += ( - indent - + lowerType(res.type, loopOp) - + " " - + res.name_hint - + " = " - + block_iter_args[i].name_hint - + ends - ) + ty = lowerType(res.type, loopOp) + res_name = res.name_hint + block_arg = block_iter_args[i].name_hint + assert res_name + assert block_arg + + endLoopFor += IDNT + ty + " " + res_name + EQ + block_arg + END + return assignments + endLoopFor @lowerOperation.register def _(op: RepeatOp): returnedType = lowerType(op.operands[0].type, op) - returnedValue = op.results[0].name_hint - arg0_name = op.operands[0].name_hint - count = op.operands[1].name_hint - initExpr = indent + returnedType + " " + returnedValue + " = " + arg0_name + ends + returnedValue = get_ret_val(op) + arg0_name = get_operand(op, 0) + count = get_operand(op, 1) + initExpr = IDNT + returnedType + " " + returnedValue + EQ + arg0_name + END forHead = ( - indent - + "for(APInt i(" - + count - + ".getBitWidth(),1);i.ult(" - + count - + ");++i){\n" + IDNT + "for(APInt i(" + count + ".getBitWidth(),1);i.ult(" + count + ");++i){\n" ) forBody = ( - indent - + "\t" + IDNT + + IDNT + returnedValue - + " = " + + EQ + returnedValue + ".concat(" + arg0_name + ")" - + ends + + END ) - forEnd = indent + "}\n" + forEnd = IDNT + "}\n" return initExpr + forHead + forBody + forEnd @lowerOperation.register def _(op: AddPoisonOp): returnedType = lowerType(op.results[0].type) - returnedValue = op.results[0].name_hint - opName = operNameToCpp[op.name] - return ( - indent - + returnedType - + " " - + returnedValue - + " = " - + op.operands[0].name_hint - + ends - ) + returnedValue = get_ret_val(op) + operand = get_operand(op, 0) + + return IDNT + returnedType + " " + returnedValue + EQ + operand + END @lowerOperation.register -def _(op: RemovePoisonOp): +def _(op: RemovePoisonOp) -> str: returnedType = lowerType(op.results[0].type) - returnedValue = op.results[0].name_hint - opName = operNameToCpp[op.name] - return ( - indent - + returnedType - + " " - + returnedValue - + " = " - + op.operands[0].name_hint - + ends - ) + returnedValue = get_ret_val(op) + operand = get_operand(op, 0) + + return IDNT + returnedType + " " + returnedValue + EQ + operand + END