Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
199 changes: 97 additions & 102 deletions Test/Passes/InstructionSelection/RISCV64/fastntt.mlir

Large diffs are not rendered by default.

82 changes: 82 additions & 0 deletions Test/Passes/InstructionSelection/RISCV64/icmp_imm.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
// RUN: veir-opt %s -p=isel-sdag-riscv64 | filecheck %s

// Immediate-form comparison selection: when the rhs is a constant fitting a
// signed 12-bit immediate, `icmp` lowers directly to `slti`/`sltiu` (with an
// `xori _ 1` inversion for `>=`, and the `x <= C == x < C+1` off-by-one for
// `<=`). Verified sound by `icmp_refinement_*_imm` in Proofs.lean. The `>`
// predicates (sgt/ugt) are intentionally left to the reg-reg lowering (same
// cost, and better for `> 0` via x0), so they are not exercised here.

"builtin.module"() ({
// Predicate encoding used by the cases below, as labelled on each case:
// 3 = sle, 5 = sge, 7 = ule, 9 = uge. Every function compares its argument
// %a (lhs) against a constant %z (rhs).

// sge (predicate 5): a >=s 5 -> xori (slti a 5) 1
"func.func"() <{function_type = (i64) -> (i1)}> ({
^bb0(%a: i64):
%z = "llvm.mlir.constant"() <{value = 5 : i64}> : () -> i64
%r = "llvm.icmp"(%a, %z) <{predicate = 5 : i64}> : (i64, i64) -> i1
"func.return"(%r) : (i1) -> ()
// CHECK-LABEL: "func.func"
// CHECK: "riscv.slti"(%{{.*}}) <{"value" = 5 : i64}> : (!riscv.reg) -> !riscv.reg
// CHECK: "riscv.xori"(%{{.*}}) <{"value" = 1 : i64}> : (!riscv.reg) -> !riscv.reg
// CHECK-NOT: "llvm.icmp"
}) : () -> ()

// uge (predicate 9): a >=u 5 -> xori (sltiu a 5) 1
"func.func"() <{function_type = (i64) -> (i1)}> ({
^bb0(%a: i64):
%z = "llvm.mlir.constant"() <{value = 5 : i64}> : () -> i64
%r = "llvm.icmp"(%a, %z) <{predicate = 9 : i64}> : (i64, i64) -> i1
"func.return"(%r) : (i1) -> ()
// CHECK-LABEL: "func.func"
// CHECK: "riscv.sltiu"(%{{.*}}) <{"value" = 5 : i64}> : (!riscv.reg) -> !riscv.reg
// CHECK: "riscv.xori"(%{{.*}}) <{"value" = 1 : i64}> : (!riscv.reg) -> !riscv.reg
// CHECK-NOT: "llvm.icmp"
}) : () -> ()

// sle (predicate 3): a <=s 5 -> slti a 6 (x <= C == x < C+1); the emitted
// immediate is C+1 = 6 and no xori inversion is expected.
"func.func"() <{function_type = (i64) -> (i1)}> ({
^bb0(%a: i64):
%z = "llvm.mlir.constant"() <{value = 5 : i64}> : () -> i64
%r = "llvm.icmp"(%a, %z) <{predicate = 3 : i64}> : (i64, i64) -> i1
"func.return"(%r) : (i1) -> ()
// CHECK-LABEL: "func.func"
// CHECK: "riscv.slti"(%{{.*}}) <{"value" = 6 : i64}> : (!riscv.reg) -> !riscv.reg
// CHECK-NOT: "riscv.xori"
// CHECK-NOT: "llvm.icmp"
}) : () -> ()

// ule (predicate 7): a <=u 5 -> sltiu a 6 (same C+1 rewrite, unsigned form)
"func.func"() <{function_type = (i64) -> (i1)}> ({
^bb0(%a: i64):
%z = "llvm.mlir.constant"() <{value = 5 : i64}> : () -> i64
%r = "llvm.icmp"(%a, %z) <{predicate = 7 : i64}> : (i64, i64) -> i1
"func.return"(%r) : (i1) -> ()
// CHECK-LABEL: "func.func"
// CHECK: "riscv.sltiu"(%{{.*}}) <{"value" = 6 : i64}> : (!riscv.reg) -> !riscv.reg
// CHECK-NOT: "riscv.xori"
// CHECK-NOT: "llvm.icmp"
}) : () -> ()

// Bail: sle (predicate 3) against 2047 would need immediate 2048, which does
// not fit a signed 12-bit field, so the peephole defers to the reg-reg
// lowering and the `llvm.icmp` is left for `isel-riscv64` (not run here).
"func.func"() <{function_type = (i64) -> (i1)}> ({
^bb0(%a: i64):
%z = "llvm.mlir.constant"() <{value = 2047 : i64}> : () -> i64
%r = "llvm.icmp"(%a, %z) <{predicate = 3 : i64}> : (i64, i64) -> i1
"func.return"(%r) : (i1) -> ()
// CHECK-LABEL: "func.func"
// CHECK: "llvm.icmp"
// CHECK-NOT: "riscv.slti"
}) : () -> ()

// Bail: ule (predicate 7) against -1 (unsigned UINT_MAX) is excluded, since
// C+1 wraps to 0; the `llvm.icmp` is expected to survive this pass.
"func.func"() <{function_type = (i64) -> (i1)}> ({
^bb0(%a: i64):
%z = "llvm.mlir.constant"() <{value = -1 : i64}> : () -> i64
%r = "llvm.icmp"(%a, %z) <{predicate = 7 : i64}> : (i64, i64) -> i1
"func.return"(%r) : (i1) -> ()
// CHECK-LABEL: "func.func"
// CHECK: "llvm.icmp"
// CHECK-NOT: "riscv.sltiu"
}) : () -> ()
}) : () -> ()
31 changes: 31 additions & 0 deletions Test/Passes/InstructionSelection/RISCV64/icmp_zero.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
// RUN: veir-opt %s -p=isel-riscv64 | filecheck %s

// Peephole: comparing against a constant 0 on the rhs lowers `eq`/`ne` directly
// on the other (non-constant) operand, with no `riscv.xor`. Canonicalization
// runs before isel and moves the constant to the rhs, so only the rhs case is
// handled.
// a == 0 -> riscv.sltiu a 1 (seqz)
// a != 0 -> riscv.sltu 0 a (snez)

"builtin.module"() ({
// eq (predicate 0), zero on the right: a == 0 -> riscv.sltiu a 1 (no xor)
"func.func"() <{function_type = (i64) -> (i1)}> ({
^bb0(%a: i64):
%z = "llvm.mlir.constant"() <{value = 0 : i64}> : () -> i64
%r = "llvm.icmp"(%a, %z) <{predicate = 0 : i64}> : (i64, i64) -> i1
"func.return"(%r) : (i1) -> ()
// CHECK-LABEL: "func.func"
// CHECK: "riscv.sltiu"(%{{.*}}) <{"value" = 1 : i64}> : (!riscv.reg) -> !riscv.reg
// CHECK-NOT: "riscv.xor"
}) : () -> ()

// ne (predicate 1), zero on the right: a != 0 -> riscv.sltu 0 a (no xor).
// Unlike the eq case, the expected `riscv.sltu` takes two register operands,
// so the zero is supplied as a register rather than as an immediate.
"func.func"() <{function_type = (i64) -> (i1)}> ({
^bb0(%a: i64):
%z = "llvm.mlir.constant"() <{value = 0 : i64}> : () -> i64
%r = "llvm.icmp"(%a, %z) <{predicate = 1 : i64}> : (i64, i64) -> i1
"func.return"(%r) : (i1) -> ()
// CHECK-LABEL: "func.func"
// CHECK: "riscv.sltu"(%{{.*}}, %{{.*}}) : (!riscv.reg, !riscv.reg) -> !riscv.reg
// CHECK-NOT: "riscv.xor"
}) : () -> ()
}) : () -> ()
8 changes: 6 additions & 2 deletions Test/Passes/InstructionSelection/RISCV64/slti.mlir
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
// RUN: veir-opt %s -p=isel-sdag-riscv64 | filecheck %s

// icmp against a signed-12-bit constant on the right selects to slti / sltiu
// (PatGprSimm12<setlt, SLTI> / PatGprSimm12<setult, SLTIU>). Other predicates
// and out-of-range immediates fall through to the general icmp lowering in
// (PatGprSimm12<setlt, SLTI> / PatGprSimm12<setult, SLTIU>), and the `<=`/`>=`
// predicates select to the same via an `xori _ 1` inversion and/or a `+1`
// off-by-one immediate (see icmp_imm.mlir). The `>` predicates (sgt/ugt) and
// out-of-range immediates fall through to the general icmp lowering in
// isel-riscv64, so here they stay as `llvm.icmp`.

"builtin.module"() ({
Expand All @@ -25,6 +27,8 @@
}) : () -> ()

// icmp sgt: no immediate form (predicate 4 = sgt) -> stays `llvm.icmp`.
// The reg-reg lowering (slt with swapped operands) is already the same
// instruction count, and is strictly better for the `> 0` case via x0.
"func.func"() <{function_type = (i64) -> i1}> ({
^bb(%a: i64):
%c = "llvm.mlir.constant"() <{value = 7 : i64}> : () -> i64
Expand Down
36 changes: 36 additions & 0 deletions Test/Passes/RISCVCombines/li_zero_to_x0.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
// RUN: veir-opt %s -p=riscv-combine | filecheck %s

// riscv-combine replaces the result of a `riscv.li 0` with a reference to the
// hard-wired zero register `x0` (`rv64.get_register`), dropping the
// materialization. This is valid because every consumer reads it as a source
// register, and `x0` reads as 0 in any source position -- including when it is
// forwarded as a generic `!riscv.reg` block argument.

"builtin.module"() ({
// Zero case: `%z` is used twice -- as a source operand of `riscv.slt` and as
// the value the branch forwards to the block argument `%p` of ^bb1 -- so both
// kinds of use are covered by the expectations at the end of this file.
"func.func"() <{function_type = (!riscv.reg) -> ()}> ({
^bb0(%x: !riscv.reg):
%z = "riscv.li"() <{value = 0 : i64}> : () -> !riscv.reg
%s = "riscv.slt"(%z, %x) : (!riscv.reg, !riscv.reg) -> !riscv.reg
"riscv_cf.branch"(%z) [^bb1] : (!riscv.reg) -> ()
^bb1(%p: !riscv.reg):
"func.return"() : () -> ()
}) : () -> ()

// A non-zero constant must be left untouched: `riscv.li 1` and its use in
// `riscv.slt` are expected to come out of the pass unchanged.
"func.func"() <{function_type = (!riscv.reg) -> ()}> ({
^bb0(%x: !riscv.reg):
%one = "riscv.li"() <{value = 1 : i64}> : () -> !riscv.reg
%s = "riscv.slt"(%one, %x) : (!riscv.reg, !riscv.reg) -> !riscv.reg
"func.return"() : () -> ()
}) : () -> ()
}) : () -> ()

// The `li 0` is gone, replaced by a single x0 reference used everywhere.
// CHECK: [[X0:%.*]] = "rv64.get_register"() : () -> !riscv.reg<x0>
// CHECK-NOT: "riscv.li"() <{"value" = 0
// CHECK: "riscv.slt"([[X0]], %{{.*}}) : (!riscv.reg<x0>, !riscv.reg) -> !riscv.reg
// CHECK: "riscv_cf.branch"([[X0]]) [^{{.*}}] : (!riscv.reg<x0>) -> ()

// The non-zero `li 1` survives unchanged.
// CHECK: [[ONE:%.*]] = "riscv.li"() <{"value" = 1 : i64}> : () -> !riscv.reg
// CHECK: "riscv.slt"([[ONE]], %{{.*}}) : (!riscv.reg, !riscv.reg) -> !riscv.reg
35 changes: 25 additions & 10 deletions Tools/vsmith
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ DEFAULT_PASSES = "instcombine,dce,cse,dce"
RISCV_PASSES = (
"canonicalize,instcombine,canonicalize,cse,dce,"
"isel-br-riscv64,isel-sdag-riscv64,isel-riscv64,"
"canonicalize,riscv-combine,reconcile-cast,dce"
"canonicalize,reconcile-cast,riscv-combine,dce"
)

# --- RISC-V MIR execution cross-check -------------------------------------
Expand Down Expand Up @@ -506,13 +506,14 @@ def mir_execution_check(repo: Path, opt: Path, after: str, keep: bool) -> tuple[
stem = opt.name[: -len(".opt.mlir")] if opt.name.endswith(".opt.mlir") else opt.stem
d = opt.parent
mir = d / f"{stem}.mir"
lowered = d / f"{stem}.lowered.mir"
asm = d / f"{stem}.s"
func_o = d / f"{stem}.func.o"
drv_s = d / f"{stem}.driver.s"
drv_o = d / f"{stem}.driver.o"
elf = d / f"{stem}.elf"
ld_script = d / "veir_mir_link.ld"
artifacts = [mir, asm, func_o, drv_s, drv_o, elf]
artifacts = [mir, lowered, asm, func_o, drv_s, drv_o, elf]

def cleanup() -> None:
if not keep:
Expand All @@ -528,17 +529,31 @@ def mir_execution_check(repo: Path, opt: Path, after: str, keep: bool) -> tuple[
return True, "" # lowering not supported for this case; skip
mir.write_text(out)

# 2. llc: MIR -> RISC-V assembly. -O0 runs the minimal machine pipeline (no
# MachineCSE/peephole/sched/combiner), so VeIR's instruction selection passes
# through unmodified; -start-after=finalize-isel resumes from exactly the
# post-isel state our MIR represents, keeping the essential pre-RA passes
# (UnreachableMachineBlockElim, ProcessImplicitDefs, RISCVPreRAExpandPseudo).
rc, out = _mir_run([tools["llc"], str(mir), "-mtriple=riscv64",
"-O0", "-start-after=finalize-isel", "-o", str(asm)])
# 2a. llc, register-allocation stage: MIR -> fully-lowered MIR. Rather than a
# canned -O0/-O2 pipeline we drive an explicit minimal machine pipeline via
# -run-pass so VeIR's instruction selection passes through unmodified while we
# still get the greedy allocator: PHI elimination, two-address lowering,
# register coalescing (deletes isel-generated copies), greedy allocation +
# virtual-register rewriting, prologue/epilogue insertion, and post-RA / RISCV
# pseudo expansion. -run-pass emits MIR, so the asm printer runs in stage 2b.
rc, out = _mir_run([tools["llc"], str(mir), "-mtriple=riscv64", "-x", "mir",
"-run-pass=phi-node-elimination,twoaddressinstruction,"
"register-coalescer,greedy,virtregrewriter,prologepilog,"
"postrapseudos,riscv-expand-pseudo", "-o", str(lowered)])
if rc != 0:
return False, (f"llc failed to lower MIR\n"
return False, (f"llc failed to lower MIR (regalloc stage)\n"
f"optimized: {opt}\nmir: {mir}\n{out}")

# 2b. llc, emission stage: fully-lowered MIR -> RISC-V assembly.
# -start-after=riscv-expand-pseudo skips straight to the asm printer (the few
# remaining passes are no-ops here) without re-running regalloc or prologue
# insertion, which would corrupt the already-lowered MIR.
rc, out = _mir_run([tools["llc"], str(lowered), "-mtriple=riscv64", "-x", "mir",
"-start-after=riscv-expand-pseudo", "-o", str(asm)])
if rc != 0:
return False, (f"llc failed to emit assembly from lowered MIR\n"
f"optimized: {opt}\nlowered: {lowered}\n{out}")

# 3. assemble the lowered function and the comparison driver, then link
if not ld_script.exists():
ld_script.write_text(MIR_LINK_LD)
Expand Down
14 changes: 14 additions & 0 deletions Veir/MIRPrinter.lean
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,12 @@ public section

namespace Veir.MIRPrinter

/-- Map a register-type attribute to the MIR name of the physical register it
denotes: a `.registerType` whose `index` is `some n` yields `$x<n>` (e.g.
`$x0`); every other attribute yields `none`. -/
def physRegName? (attr : Attribute) : Option String :=
  match attr with
  | .registerType { index := some n } => some s!"$x{n}"
  | _ => none

/-- Virtual-register name for a value: `%v<opId>` for op results,
`%arg<blockId>_<i>` for block arguments. -/
def vreg (ctx : IRContext OpCode) (v : ValuePtr) : String :=
Expand Down Expand Up @@ -267,6 +273,14 @@ def emitRegular (ctx : IRContext OpCode) (op : OperationPtr) : IO Unit := do
match regRegMnem rop with
| some m => IO.println s!" {res} = {m} {v 0}, {v 1}"
| none => IO.println s!" ; UNHANDLED {reprStr rop}"
-- `rv64.get_register` references a physical register (e.g. the zero register
-- `$x0`). Copy it into a virtual register so every use -- including PHI
-- operands, which may not be physical registers -- is valid MIR; the register
-- allocator coalesces the copy away, leaving a direct `$x0` use in assembly.
| .rv64 .get_register =>
match (op.getResultTypes! ctx)[0]?.bind (physRegName? ·.val) with
| some name => IO.println s!" {res} = COPY {name}"
| none => IO.println s!" ; UNHANDLED op"
| _ => IO.println s!" ; UNHANDLED op"

/-- Emit a terminator operation (branch / return). `lsuccs` gives the lowered
Expand Down
65 changes: 65 additions & 0 deletions Veir/Passes/InstructionSelection/Proofs.lean
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,36 @@ theorem icmp_refinement_ne {x y : LLVM.Int 64} :
(RISCV.Reg.toInt (Data.RISCV.sltu (Data.RISCV.xor (LLVM.Int.toReg y) (LLVM.Int.toReg x)) (Data.RISCV.li 0#64)) 1) := by
veir_bv_decide

/--
Prove the correctness of the constant-zero `icmp eq` peephole, with the zero on
the right (`x == 0`). The lowering drops the `xor` and emits `sltiu x 1` (seqz)
directly on the non-constant operand `x`: unsigned `x < 1` holds exactly when
`x = 0`.

NOTE(review): `Data.RISCV.sltiu` appears to take the immediate first and the
source register second, so the term below reads as `x <u 1` -- confirm against
its definition.
-/
theorem icmp_refinement_eq_zero_rhs {x : LLVM.Int 64} :
(Data.LLVM.Int.icmp x (LLVM.Int.constant 64 0) LLVM.IntPred.eq) ⊒
(RISCV.Reg.toInt (Data.RISCV.sltiu 1#12 (LLVM.Int.toReg x)) 1) := by
veir_bv_decide

/--
Prove the correctness of the constant-zero `icmp ne` peephole, with the zero on
the right (`x != 0`). The lowering drops the `xor` and emits `sltu 0 x` (snez)
directly on the non-constant operand `x`: unsigned `0 < x` holds exactly when
`x ≠ 0`.

NOTE(review): `Data.RISCV.sltu` appears to take its operands in `rs2, rs1`
order, so the term below (with `x` first and `li 0` second) reads as `0 <u x`
-- confirm against its definition.
-/
theorem icmp_refinement_ne_zero_rhs {x : LLVM.Int 64} :
(Data.LLVM.Int.icmp x (LLVM.Int.constant 64 0) LLVM.IntPred.ne) ⊒
(RISCV.Reg.toInt (Data.RISCV.sltu (LLVM.Int.toReg x) (Data.RISCV.li 0#64)) 1) := by
veir_bv_decide

/--
Prove the correctness of the `riscv-combine` `li 0 -> x0` rewrite: materializing
the constant `0` with `li` produces exactly the value of the hard-wired zero
register `x0` (which the interpreter models as the register holding `0#64`).
Since every consumer is a pure function of its source registers' values,
substituting `x0` for the `li 0` result preserves semantics.

The statement is closed by `rfl`, i.e. the two sides are definitionally equal;
only the value equality is proved here, the "every consumer is pure" argument
above is informal.
-/
theorem li_zero_eq_x0 :
Data.RISCV.li 0#64 = RISCV.Reg.mk 0#64 := rfl

/--
Prove the correctness of the `icmp` lowering pattern with `slt`.
-/
Expand Down Expand Up @@ -203,6 +233,41 @@ theorem icmp_refinement_uge {x y : LLVM.Int 64} :
(RISCV.Reg.toInt (Data.RISCV.xori 1#12 (Data.RISCV.sltu (LLVM.Int.toReg y) (LLVM.Int.toReg x))) 1) := by
veir_bv_decide

/-! ### Immediate-constant refinements for ordered comparisons

Each theorem justifies one arm of the `slti` immediate-form selection in
`isel-sdag-riscv64`, comparing `x` against a constant that equals the
sign-extension of the 12-bit immediate `imm` actually encoded in the emitted
instruction. For the `≤` predicates the constant is `sext(imm) - 1`, capturing
the `x ≤ C == x < C+1` rewrite (the code stores `C+1` as the immediate). The
unsigned off-by-one form carries the `imm ≠ 0` hypothesis that the code
enforces by rejecting `C = -1` (else `C+1` wraps past `UINT_MAX`). -/

/-- `icmp sge x C` with `C = sext(imm)` -> `xori (slti x imm) 1`.

`x ≥s C` is the negation of `x <s C`, so the lowering computes `slti x imm` and
inverts it with `xori _ 1`. The constant on the LLVM side is written as
`imm.signExtend 64`, so the theorem covers exactly the constants representable
as a signed 12-bit immediate. -/
theorem icmp_refinement_sge_imm {x : LLVM.Int 64} (imm : BitVec 12) :
(Data.LLVM.Int.icmp x (LLVM.Int.val (imm.signExtend 64)) LLVM.IntPred.sge) ⊒
(RISCV.Reg.toInt (Data.RISCV.xori 1#12 (Data.RISCV.slti imm (LLVM.Int.toReg x))) 1) := by
veir_bv_decide

/-- `icmp uge x C` with `C = sext(imm)` -> `xori (sltiu x imm) 1`.

`x ≥u C` is the negation of `x <u C`, so the lowering computes `sltiu x imm`
and inverts it with `xori _ 1`. Note that the constant is still the
*sign*-extension of `imm` (`imm.signExtend 64`), even though the comparison
itself is unsigned. -/
theorem icmp_refinement_uge_imm {x : LLVM.Int 64} (imm : BitVec 12) :
(Data.LLVM.Int.icmp x (LLVM.Int.val (imm.signExtend 64)) LLVM.IntPred.uge) ⊒
(RISCV.Reg.toInt (Data.RISCV.xori 1#12 (Data.RISCV.sltiu imm (LLVM.Int.toReg x))) 1) := by
veir_bv_decide

/-- `icmp sle x C` with `C = sext(imm) - 1` -> `slti x imm` (i.e. `x < C+1`).

Here `imm` is the immediate actually emitted, so the compared constant is one
less than its sign-extension; e.g. `x ≤s 5` corresponds to `imm = 6`. No side
condition on `imm` is required for the signed form. -/
theorem icmp_refinement_sle_imm {x : LLVM.Int 64} (imm : BitVec 12) :
(Data.LLVM.Int.icmp x (LLVM.Int.val (imm.signExtend 64 - 1)) LLVM.IntPred.sle) ⊒
(RISCV.Reg.toInt (Data.RISCV.slti imm (LLVM.Int.toReg x)) 1) := by
veir_bv_decide

/-- `icmp ule x C` with `C = sext(imm) - 1` -> `sltiu x imm`; needs `imm ≠ 0`
(else `C = UINT_MAX` and `C+1` wraps to `0`).

Here `imm` is the immediate actually emitted, so the compared constant is one
less than its sign-extension; e.g. `x ≤u 5` corresponds to `imm = 6`. The
hypothesis `h` rules out the single case where `x ≤u C` is always true but
`x <u 0` is always false. -/
theorem icmp_refinement_ule_imm {x : LLVM.Int 64} (imm : BitVec 12) (h : imm ≠ 0) :
(Data.LLVM.Int.icmp x (LLVM.Int.val (imm.signExtend 64 - 1)) LLVM.IntPred.ule) ⊒
(RISCV.Reg.toInt (Data.RISCV.sltiu imm (LLVM.Int.toReg x)) 1) := by
veir_bv_decide

/--
Prove the correctness of the `or` lowering pattern.
-/
Expand Down
Loading
Loading