Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions Test/Interpreter/LLVM/bitcast.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
// RUN: veir-interpret %s | filecheck %s

"builtin.module"() ({
"func.func"() <{sym_name = "main", function_type = () -> (!llvm.byte<64>, i64, i64)}> ({
%3 = "llvm.mlir.poison"() : () -> i64
%4 = "llvm.mlir.constant"() <{ "value" = 32 : i64 }> : () -> i64
%5 = "llvm.bitcast"(%3) : (i64) -> !llvm.byte<64>
%6 = "llvm.lshr"(%5, %4) : (!llvm.byte<64>, i64) -> !llvm.byte<64>
%7 = "llvm.lshr"(%6, %4) : (!llvm.byte<64>, i64) -> !llvm.byte<64>
%8 = "llvm.bitcast"(%6) : (!llvm.byte<64>) -> i64
%9 = "llvm.bitcast"(%7) : (!llvm.byte<64>) -> i64
"func.return"(%6, %8, %9) : (!llvm.byte<64>, i64, i64) -> ()
}) : () -> ()
}) : () -> ()

// CHECK: Program output: #[0b00000000000000000000000000000000????????????????????????????????#64, poison, 0x0000000000000000#64]
27 changes: 27 additions & 0 deletions Test/Interpreter/LLVM/byte.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
// RUN: veir-interpret %s | filecheck %s

"builtin.module"() ({
"func.func"() <{sym_name = "main", function_type = () -> (!llvm.byte<64>, i32, i32, !llvm.byte<64>, !llvm.byte<64>)}> ({
%1 = "llvm.mlir.constant"() <{ "value" = 1 : i64 }> : () -> i64
%2 = "llvm.mlir.constant"() <{ "value" = 17 : i64 }> : () -> i64
%3 = "llvm.mlir.poison"() : () -> i8
%4 = "llvm.alloca"(%1) <{ "elem_type" = i64 }> : (i64) -> !llvm.ptr
%5 = "llvm.mlir.constant"() <{ "value" = 4 : i64 }> : () -> i64
%6 = "llvm.getelementptr"(%4, %5) <{elem_type = i8, rawConstantIndices = array<i32: -2147483648>}> : (!llvm.ptr, i64) -> !llvm.ptr
"llvm.store"(%2, %4) : (i64, !llvm.ptr) -> ()
"llvm.store"(%3, %6) : (i8, !llvm.ptr) -> ()
%7 = "llvm.load"(%4) : (!llvm.ptr) -> !llvm.byte<64>
%8 = "llvm.mlir.constant"() <{ "value" = 32 : i64 }> : () -> i64
%9 = "llvm.lshr"(%7, %8) : (!llvm.byte<64>, i64) -> !llvm.byte<64>
%10 = "llvm.trunc"(%9) : (!llvm.byte<64>) -> !llvm.byte<32>
%11 = "llvm.trunc"(%7) : (!llvm.byte<64>) -> !llvm.byte<32>
%12 = "llvm.bitcast"(%10) : (!llvm.byte<32>) -> i32
%13 = "llvm.bitcast"(%11) : (!llvm.byte<32>) -> i32
%14 = "llvm.mlir.constant"() <{ "value" = 4 : i64 }> : () -> i64
%15 = "llvm.shl"(%7, %14) : (!llvm.byte<64>, i64) -> !llvm.byte<64>
%16 = "llvm.freeze"(%7) : (!llvm.byte<64>) -> !llvm.byte<64>
"func.return"(%7, %12, %13, %15, %16) : (!llvm.byte<64>, i32, i32, !llvm.byte<64>, !llvm.byte<64>) -> ()
}) : () -> ()
}) : () -> ()

// CHECK: Program output: #[0b000000000000000000000000????????00000000000000000000000000010001#64, poison, 0x00000011#32, 0b00000000000000000000????????000000000000000000000000000100010000#64, 0b0000000000000000000000000000000000000000000000000000000000010001#64]
29 changes: 25 additions & 4 deletions Veir/Data/LLVM/Byte/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -59,16 +59,37 @@ def trunc (x : Byte w) (w' : Nat) : Byte w' :=
simp [←BitVec.setWidth_and, x.h]
)⟩

def shl {w : Nat} (x : Byte w) (y : Int w) (nuw : Bool := false) : Byte w := Id.run do
let .val y' := y | allPoison

if y' ≥ w then
return allPoison

if nuw ∧ (x.val <<< y') >>> y' ≠ x.val then
return allPoison

if nuw ∧ (x.poison <<< y') >>> y' ≠ x.poison then
return allPoison

⟨x.val <<< y', x.poison <<< y', by simp [←BitVec.shiftLeft_and_distrib, x.h]⟩

@[veir_bv_normalize]
def lshr (x : Byte w) (y : Int w) : Byte w :=
if y.isPoison || y.getValueD ≥ w then
def lshr (x : Byte w) (y : Int w) (exact := false) : Byte w :=
let y' := y.getValueD
if y.isPoison || y' ≥ w then
allPoison
else if exact ∧ (x.val >>> y') <<< y' ≠ x.val then
allPoison
else if exact ∧ (x.poison >>> y') <<< y' ≠ x.poison then
allPoison
else
let y := y.getValueD
⟨x.val >>> y, x.poison >>> y, by (
⟨x.val >>> y', x.poison >>> y', by (
simp [←BitVec.ushiftRight_and_distrib, x.h]
)⟩

def freeze (x : Byte w) : Byte w :=
⟨x.val, 0#w, by grind⟩

def toString_rec {w : Nat} (b : Byte w) : String :=
if w = 0 then "" else
s!"{if b.poison.getMsbD 0 then "?" else ToString.toString (b.val.getMsbD 0).toNat}{(b.trunc (w - 1)).toString_rec}"
Expand Down
115 changes: 88 additions & 27 deletions Veir/Interpreter/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -359,17 +359,16 @@ def MemoryState.alloc (state : MemoryState) (size : UInt64)

/--
Store raw bytes to the given address in memory,
and unset the corresponding poison bits.
and set the corresponding poison bits as requested (by default, unset).
Yields UB if the access is out of bounds.
-/
def MemoryState.store (state : MemoryState) (addr : UInt64) (val : ByteArray)
(poison : ByteArray := ByteArray.replicate val.size 0) (h : poison.size = val.size := by grind)
: Interp MemoryState :=
if addr.toNat + val.size ≤ state.contents.size then
let poison := ByteArray.replicate val.size 0
return ⟨val.copySlice 0 state.contents addr.toNat val.size false,
poison.copySlice 0 state.poisonMask addr.toNat val.size false,
by
have h : poison.size = val.size := by grind
simp [ByteArray.copySlice_eq_append, state.consistentSize, h]
else
Expand Down Expand Up @@ -407,6 +406,7 @@ def MemoryState.llvmStore (state : MemoryState) (addr : UInt64) (val : RuntimeVa
| .int 16 (.val v) => state.store addr (UInt16.ofBitVec v).toByteArrayLE
| .int 32 (.val v) => state.store addr (UInt32.ofBitVec v).toByteArrayLE
| .int 64 (.val v) => state.store addr (UInt64.ofBitVec v).toByteArrayLE
| .byte 64 v => state.store addr (UInt64.ofBitVec v.val).toByteArrayLE (UInt64.ofBitVec v.poison).toByteArrayLE (by simp)
| .int n .poison => state.empoison addr (n / 8)
| .addr v => state.store addr v.toByteArrayLE
| _ => none
Expand All @@ -423,21 +423,31 @@ def MemoryState.load (state : MemoryState) (addr size : UInt64)
Interp.ub

/--
Check if any of the `size` bytes at the given memory address `addr` is poison.
Load bitwise poison status of the given memory address.
Yields UB if the access is out of bounds.
-/
def MemoryState.hasPoison (state : MemoryState) (addr size : UInt64)
: Interp Bool :=
if addr.toNat + size.toNat <= state.contents.size then do
let mut poison := false
for b in state.poisonMask.extract addr.toNat (addr + size).toNat do
if b ≠ 0 then
poison := true
break
return poison
def MemoryState.loadPoison (state : MemoryState) (addr size : UInt64)
: Interp ByteArray :=
if addr.toNat + size.toNat <= state.poisonMask.size then
return state.poisonMask.extract addr.toNat (addr + size).toNat
else
Interp.ub

/--
Check if any of the `size` bytes at the given memory address `addr` is poison.
Yields UB if the access is out of bounds.
-/
def MemoryState.hasPoison (state : MemoryState) (addr size : UInt64)
: Interp Bool := do
let poisonMask ← state.loadPoison addr size
let mut poison := false
for b in poisonMask do
if b ≠ 0 then
poison := true
break
return poison

set_option warn.sorry false in
/--
Load an LLVM value from the given memory address.
Yields UB if access is out of bounds or the address is 0.
Expand All @@ -462,6 +472,11 @@ def MemoryState.llvmLoad (state : MemoryState) (addr : UInt64) (type : TypeAttr)
let ba ← state.load addr 8
if ← state.hasPoison addr 8 then return .int 64 .poison
return .int 64 (.val (BitVec.ofNat 64 ba.toUInt64LE!.toNat))
| Attribute.byteType { bitwidth := 64 } =>
let ba ← state.load addr 8
let baPoison ← state.loadPoison addr 8
let poison := baPoison.toUInt64LE!.toBitVec
return .byte 64 ⟨ba.toUInt64LE!.toBitVec &&& ~~~poison, poison, by bv_decide⟩
| Attribute.llvmPointerType _ =>
let ba ← state.load addr 8
-- FIXME poison address
Expand Down Expand Up @@ -695,15 +710,30 @@ def Llvm.interpretOp' (opType : Veir.Llvm) (properties : HasDialectOpInfo.proper
if v' = 0 then Interp.ub
else return (#[.int bw (LLVM.Int.urem lhs rhs)], mem, none)
| .shl => do
let [.int bw lhs, .int bw' rhs] := operands.toList | none
if h: bw' ≠ bw then none else
let rhs := rhs.cast (by simp at h; exact h)
return (#[.int bw (LLVM.Int.shl lhs rhs properties.nsw properties.nuw)], mem, none)
let [lhs, .int bw' rhs] := operands.toList | none
match lhs with
| .int bw lhs =>
if h: bw' ≠ bw then none else
let rhs := rhs.cast (by simp at h; exact h)
return (#[.int bw (LLVM.Int.shl lhs rhs properties.nsw properties.nuw)], mem, none)
| .byte bw lhs =>
if h: bw' ≠ bw then none else
if properties.nsw then none else
let rhs := rhs.cast (by simp at h; exact h)
return (#[.byte bw (LLVM.Byte.shl lhs rhs properties.nuw)], mem, none)
| _ => none
| .lshr => do
let [.int bw lhs, .int bw' rhs] := operands.toList | none
if h: bw' ≠ bw then none else
let rhs := rhs.cast (by simp at h; exact h)
return (#[.int bw (LLVM.Int.lshr lhs rhs properties.exact)], mem, none)
let [lhs, .int bw' rhs] := operands.toList | none
match lhs with
| .int bw lhs =>
if h: bw' ≠ bw then none else
let rhs := rhs.cast (by simp at h; exact h)
return (#[.int bw (LLVM.Int.lshr lhs rhs properties.exact)], mem, none)
| .byte bw lhs =>
if h: bw' ≠ bw then none else
let rhs := rhs.cast (by simp at h; exact h)
return (#[.byte bw (LLVM.Byte.lshr lhs rhs properties.exact)], mem, none)
| _ => none
| .ashr => do
let [.int bw lhs, .int bw' rhs] := operands.toList | none
if h: bw' ≠ bw then none else
Expand Down Expand Up @@ -774,11 +804,18 @@ def Llvm.interpretOp' (opType : Veir.Llvm) (properties : HasDialectOpInfo.proper
let rhs := rhs.cast (by simp at h; exact h)
return (#[.int bw (LLVM.Int.umin lhs rhs)], mem, none)
| .trunc => do
let [.int w val] := operands.toList | none
let [val] := operands.toList | none
let some resType := resultTypes[0]? | none
let .integerType resBw := resType.val | none
if h: resBw.bitwidth >= w then none else
return (#[.int resBw.bitwidth (LLVM.Int.trunc val resBw.bitwidth properties.nsw properties.nuw (by omega))], mem, none)
match val with
| .int w val =>
let .integerType resBw := resType.val | none
if h: resBw.bitwidth >= w then none else
return (#[.int resBw.bitwidth (LLVM.Int.trunc val resBw.bitwidth properties.nsw properties.nuw (by omega))], mem, none)
| .byte w val =>
let .byteType resBw := resType.val | none
if h: resBw.bitwidth >= w then none else
return (#[.byte resBw.bitwidth (LLVM.Byte.trunc val resBw.bitwidth)], mem, none)
| _ => none
| .zext => do
let [.int w val] := operands.toList | none
let some resType := resultTypes[0]? | none
Expand Down Expand Up @@ -847,8 +884,32 @@ def Llvm.interpretOp' (opType : Veir.Llvm) (properties : HasDialectOpInfo.proper
| .val idx => return (#[.addr (ptr.toNat + idx.toNat * size).toUInt64], mem, none)
| .poison => Interp.ub
| .freeze => do
let [RuntimeValue.int w val] := operands.toList | none
return (#[RuntimeValue.int w (LLVM.Int.freeze val)], mem, none)
let [val] := operands.toList | none
match val with
| .int w val =>
return (#[.int w val.freeze], mem, none)
| .byte w val =>
return (#[.byte w val.freeze], mem, none)
| _ => none
| .bitcast => do
let [val] := operands.toList | none
let [⟨type, _⟩] := resultTypes.toList | none
let result ← do match val, type with
| .int bw1 val', .integerType ⟨bw2⟩ =>
if bw1 ≠ bw2 then none else some (.ok val)
| .int bw1 val', .byteType ⟨bw2⟩ =>
if bw1 ≠ bw2 then none else some (.ok (.byte bw1 $ LLVM.Byte.fromInt val'))
| .byte bw1 val', .byteType ⟨bw2⟩ =>
if bw1 ≠ bw2 then none else some (.ok val)
| .byte bw1 val', .integerType ⟨bw2⟩ =>
if bw1 ≠ bw2 then none else some (.ok (.int bw1 $ val'.toInt))
| .byte bw val', .llvmPointerType _ =>
if h : bw = 64 then some (.ok (.addr (val'.cast h).toUInt64)) else none
| .addr val', .llvmPointerType _ => some (.ok val)
| .addr val', .byteType ⟨bw⟩ =>
if h : bw = 64 then some (.ok (.byte 64 $ LLVM.Byte.fromUInt64 val')) else none
| _, _ => none
return (#[result], mem, none)
| _ => none

/-- Effective address of a RISC-V load/store: the base register value plus the
Expand Down
13 changes: 4 additions & 9 deletions Veir/Verifier.lean
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ def OperationPtr.verifyIntegerUnop (op : OperationPtr) (ctx : WfIRContext OpCode
op.verifyResultTypeMatches ctx operandType s!"{instrName}: Expected result type to match operand type"
pure operandType

def OperationPtr.verifyLLVMLshr (op : OperationPtr) (ctx : WfIRContext OpCode)
def OperationPtr.verifyLLVMShift (op : OperationPtr) (ctx : WfIRContext OpCode)
(opIn : op.InBounds ctx.raw) : Except String PUnit := do
op.verifyPlainOpCounts ctx opIn 2 1
let instrName := String.fromUTF8! (op.getOpType ctx.raw opIn).name
Expand Down Expand Up @@ -523,13 +523,13 @@ def OperationPtr.verifyLocalInvariants (op : OperationPtr) (ctx : WfIRContext Op
op.verifyPlainOpCounts ctx opIn 0 1
pure ()
| .llvm .and | .llvm .or | .llvm .xor | .llvm .intr__smax | .llvm .intr__smin
| .llvm .intr__umax | .llvm .intr__umin | .llvm .add | .llvm .sub | .llvm .shl
| .llvm .intr__umax | .llvm .intr__umin | .llvm .add | .llvm .sub
| .llvm .ashr | .llvm .mul | .llvm .sdiv | .llvm .udiv
| .llvm .srem | .llvm .urem => do
op.verifyIntegerBinop ctx opIn
pure ()
| .llvm .lshr => do
op.verifyLLVMLshr ctx opIn
| .llvm .lshr | .llvm .shl => do
op.verifyLLVMShift ctx opIn
pure ()
| .llvm .intr__fshl | .llvm .intr__fshr => do
op.verifyIntegerTernop ctx opIn
Expand Down Expand Up @@ -1178,11 +1178,6 @@ theorem OperationPtr.Verified.llvm_sub {op : OperationPtr} {opInBounds}
op.IsVerifiedIntegerBinop ctx := OperationPtr.Verified.integerBinop opVerify <| by
simp only [verifyLocalInvariants, ← getOpType!_eq_getOpType, opType]

theorem OperationPtr.Verified.llvm_shl {op : OperationPtr} {opInBounds}
(opVerify : op.Verified ctx opInBounds) (opType : op.getOpType! ctx.raw = .llvm .shl) :
op.IsVerifiedIntegerBinop ctx := OperationPtr.Verified.integerBinop opVerify <| by
simp only [verifyLocalInvariants, ← getOpType!_eq_getOpType, opType]

theorem OperationPtr.Verified.llvm_ashr {op : OperationPtr} {opInBounds}
(opVerify : op.Verified ctx opInBounds) (opType : op.getOpType! ctx.raw = .llvm .ashr) :
op.IsVerifiedIntegerBinop ctx := OperationPtr.Verified.integerBinop opVerify <| by
Expand Down
Loading