Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 71 additions & 0 deletions DatapathVerification/BitHeap/BVComb.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import DatapathVerification.BitHeap.BitHeap
import DatapathVerification.BitHeap.Circuit
import DatapathVerification.BitHeap.DaddaTree

open BitHeap
namespace Comb

inductive ArithBinopKind
| add
| mul

-- inductive ArithUnopKind
-- | neg

-- inductive BooleanBinopKind
-- | and | or | xor

inductive ArithCircuit
| var (width : Nat) (varIndex : Nat)
| arithbinop (kind : ArithBinopKind) (width : Nat) (l r : ArithCircuit)
-- | arithunop (kind : ArithUnopKind) (width : Nat) (arg : ArithCircuit)
-- | bvbinop (kind : BooleanBinopKind) (width : Nat) (l r : ArithCircuit)

/--
Convert a bitheap into a new bitheap that has a single row,
by building a Dadda tree of adders to reduce the bitheap to a single row.
-/
def BitHeap.toSingleRow : BitHeap → CircuitVector
| bh =>
let (pp1, pp2) := DaddaTree.DaddaTree bh
-- add pp1 and pp2 to get the final row
sorry

namespace ArithCircuit
/--
Given a bitvector (x : BV 3), but a bitheap
```
* * *
x2 x1 x0
```
-/
def bitheapOfVar (width : Nat) (varIndex : Nat) : BitHeap :=
-- I want to create a bitheap that has one bit-variable per bit in the bitvector variable.
-- | We need to know that this index is unique which is a gigantic pain.
List.range width |>.foldl (fun bh i => bh.addBit i (BitHeap.Circuit.bit (varIndex * width + i))) BitHeap.empty

def toBitHeap (c : ArithCircuit) : BitHeap :=
match c with
| .var width varIndex => bitheapOfVar width varIndex
| .arithbinop kind width l r =>
match kind with
| .add => (toBitHeap l).addBitHeap (toBitHeap r)
| .mul => (toBitHeap l).mulBitHeap (toBitHeap r)
-- | .arithunop kind width arg =>
-- match kind with
-- | .neg => (toBitHeap arg).negBitHeap
-- | .bvbinop kind width l r =>
-- match kind with
-- | .and =>
-- let lRow := (l.toBitHeap).toSingleRow
-- let rRow := (r.toBitHeap).toSingleRow
-- let newRow := Array.zipWith (fun lBit rBit => Circuit.and lBit rBit) lRow rRow
-- BitHeap.fromRow newRow

def toCircuitVector (c : ArithCircuit) : CircuitVector :=
let bh := c.toBitHeap
bh.toSingleRow

end ArithCircuit

end Comb
24 changes: 24 additions & 0 deletions DatapathVerification/BitHeap/BitHeap.lean
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,30 @@ def addBit (column : Nat) (c : Circuit) (h : BitHeap w) : BitHeap w :=
h.setColumn column (col.insert c) h1
else addBit (column + 1) c (h.removeBit column c)

-- TODO: make this variable size add
def addBitHeap' (h1 h2 : BitHeap w) : BitHeap w:=
let h := BitHeap.empty w
let h := h1.columns.zipIdx.foldl (fun acc (column, index) =>
column.elems.toList.foldl (fun acc' c => acc'.addBit index c) acc) h
let h := h2.columns.zipIdx.foldl (fun acc (column, index) =>
column.elems.toList.foldl (fun acc' c => acc'.addBit index c) acc) h
h

def addBitHeap (bhs : List (BitHeap w)) : BitHeap w:=
let h := BitHeap.empty w
let h := bhs.foldl (fun acc heap => heap.columns.zipIdx.foldl (fun acc' (column, index) =>
column.elems.toList.foldl (fun acc' c => acc'.addBit index c) acc') acc) h
h

def mulBitHeap (h0 h1 : BitHeap w) : BitHeap (2 * w - 1) :=
let h := BitHeap.empty (2 * w - 1)
let h := h0.columns.zipIdx.foldl (fun acc (column0, i0) =>
h1.columns.zipIdx.foldl (fun acc' (column1, i2) =>
column0.elems.toList.foldl (fun acc'' c1 =>
column1.elems.toList.foldl (fun acc''' c2 =>
acc'''.addBit (i0 + i2) (Circuit.binop .and c1 c2)) acc'') acc') acc) h
h

structure AdderResult (w : Nat) where
heap : BitHeap w
sum : Circuit
Expand Down
12 changes: 12 additions & 0 deletions DatapathVerification/BitHeap/Circuit.lean
Original file line number Diff line number Diff line change
Expand Up @@ -76,4 +76,16 @@ theorem eval_atLeastTwo (a b c : Circuit) (env : BitEnv) :

end Circuit


/-- A vector of circuits, used to represent symbolic BitVectors. -/
def CircuitVector := Array Circuit

namespace CircuitVector

def eval (vec : CircuitVector) (env : Circuit.BitEnv) : Int :=
(vec.mapIdx (fun i c => 2^i * (if c.eval env then 1 else 0))).sum

end CircuitVector


end BitHeap
37 changes: 37 additions & 0 deletions DatapathVerification/BitHeap/Examples/Examples.lean
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,43 @@ info: 6
#eval (applyChain compressionChain fourBitsInCol1).eval
(show BitEnv from fun n => n = 1 || n = 2 || n = 3)

----------------------------

def exampleHeap1 : BitHeap 4 :=
let h := BitHeap.empty 4
let h := h.addBit 0 (Circuit.bit 0)
let h := h.addBit 1 (Circuit.bit 1)
let h := h.addBit 2 (Circuit.bit 2)
let h := h.addBit 3 (Circuit.bit 3)
h

def exampleHeap2 : BitHeap 4 :=
let h := BitHeap.empty 4
let h := h.addBit 0 (Circuit.bit 4)
let h := h.addBit 1 (Circuit.bit 5)
let h := h.addBit 2 (Circuit.bit 6)
let h := h.addBit 3 (Circuit.bit 7)
h

def exampleHeap3 : BitHeap 4 :=
let h := BitHeap.empty 4
let h := h.addBit 0 (Circuit.bit 8)
let h := h.addBit 1 (Circuit.bit 9)
let h := h.addBit 2 (Circuit.bit 10)
let h := h.addBit 3 (Circuit.bit 11)
h
/--
info: {0 ↦ [b4, b8, b0], 1 ↦ [b1, b5, b9], 2 ↦ [b2, b10, b6], 3 ↦ [b3, b11, b7]}
-/
#guard_msgs in
#eval addBitHeap [exampleHeap1, exampleHeap2, exampleHeap3]

/--
info: {0 ↦ [(b0 ∧ b4)], 1 ↦ [(b0 ∧ b5), (b1 ∧ b4)], 2 ↦ [(b1 ∧ b5), (b2 ∧ b4), (b0 ∧ b6)], 3 ↦ [(b3 ∧ b4), (b0 ∧ b7), (b2 ∧ b5), (b1 ∧ b6)], 4 ↦ [(b2 ∧ b6), (b3 ∧ b5), (b1 ∧ b7)], 5 ↦ [(b3 ∧ b6), (b2 ∧ b7)], 6 ↦ [(b3 ∧ b7)]}
-/
#guard_msgs in
#eval mulBitHeap exampleHeap1 exampleHeap2

end Examples

end BitHeap