diff --git a/DatapathVerification/BitHeap/BVComb.lean b/DatapathVerification/BitHeap/BVComb.lean new file mode 100644 index 0000000..2354a4b --- /dev/null +++ b/DatapathVerification/BitHeap/BVComb.lean @@ -0,0 +1,71 @@ +import DatapathVerification.BitHeap.BitHeap +import DatapathVerification.BitHeap.Circuit +import DatapathVerification.BitHeap.DaddaTree + +open BitHeap +namespace Comb + +inductive ArithBinopKind +| add +| mul + +-- inductive ArithUnopKind +-- | neg + +-- inductive BooleanBinopKind +-- | and | or | xor + +inductive ArithCircuit + | var (width : Nat) (varIndex : Nat) + | arithbinop (kind : ArithBinopKind) (width : Nat) (l r : ArithCircuit) + -- | arithunop (kind : ArithUnopKind) (width : Nat) (arg : ArithCircuit) + -- | bvbinop (kind : BooleanBinopKind) (width : Nat) (l r : ArithCircuit) + +/-- +Convert a bitheap into a new bitheap that has a single row, +by building a Dadda tree of adders to reduce the bitheap to a single row. +-/ +def BitHeap.toSingleRow : BitHeap → CircuitVector + | bh => + let (pp1, pp2) := DaddaTree.DaddaTree bh + -- add pp1 and pp2 to get the final row + sorry + +namespace ArithCircuit +/-- +Given a bitvector (x : BV 3), but a bitheap +``` +* * * +x2 x1 x0 +``` +-/ +def bitheapOfVar (width : Nat) (varIndex : Nat) : BitHeap := + -- I want to create a bitheap that has one bit-variable per bit in the bitvector variable. + -- | We need to know that this index is unique which is a gigantic pain. + List.range width |>.foldl (fun bh i => bh.addBit i (BitHeap.Circuit.bit (varIndex * width + i))) BitHeap.empty + +def toBitHeap (c : ArithCircuit) : BitHeap := + match c with + | .var width varIndex => bitheapOfVar width varIndex + | .arithbinop kind width l r => + match kind with + | .add => (toBitHeap l).addBitHeap (toBitHeap r) + | .mul => (toBitHeap l).mulBitHeap (toBitHeap r) + -- | .arithunop kind width arg => + -- match kind with + -- | .neg => (toBitHeap arg).negBitHeap + -- | .bvbinop kind width l r => + -- match kind with + -- | .and => + -- let lRow := (l.toBitHeap).toSingleRow + -- let rRow := (r.toBitHeap).toSingleRow + -- let newRow := Array.zipWith (fun lBit rBit => Circuit.and lBit rBit) lRow rRow + -- BitHeap.fromRow newRow + +def toCircuitVector (c : ArithCircuit) : CircuitVector := + let bh := c.toBitHeap + bh.toSingleRow + +end ArithCircuit + +end Comb diff --git a/DatapathVerification/BitHeap/BitHeap.lean b/DatapathVerification/BitHeap/BitHeap.lean index 251ba5f..ee9cb84 100644 --- a/DatapathVerification/BitHeap/BitHeap.lean +++ b/DatapathVerification/BitHeap/BitHeap.lean @@ -91,6 +91,30 @@ def addBit (column : Nat) (c : Circuit) (h : BitHeap w) : BitHeap w := h.setColumn column (col.insert c) h1 else addBit (column + 1) c (h.removeBit column c) +-- TODO: make this variable size add +def addBitHeap' (h1 h2 : BitHeap w) : BitHeap w:= + let h := BitHeap.empty w + let h := h1.columns.zipIdx.foldl (fun acc (column, index) => + column.elems.toList.foldl (fun acc' c => acc'.addBit index c) acc) h + let h := h2.columns.zipIdx.foldl (fun acc (column, index) => + column.elems.toList.foldl (fun acc' c => acc'.addBit index c) acc) h + h + +def addBitHeap (bhs : List (BitHeap w)) : BitHeap w:= + let h := BitHeap.empty w + let h := bhs.foldl (fun acc heap => heap.columns.zipIdx.foldl (fun acc' (column, index) => + column.elems.toList.foldl (fun acc' c => acc'.addBit index c) acc') acc) h + h + +def mulBitHeap (h0 h1 : BitHeap w) : BitHeap (2 * w - 1) := + let h := BitHeap.empty (2 * w - 1) + let h := h0.columns.zipIdx.foldl (fun acc (column0, i0) => + h1.columns.zipIdx.foldl (fun acc' (column1, i2) => + column0.elems.toList.foldl (fun acc'' c1 => + column1.elems.toList.foldl (fun acc''' c2 => + acc'''.addBit (i0 + i2) (Circuit.binop .and c1 c2)) acc'') acc') acc) h + h + structure AdderResult (w : Nat) where heap : BitHeap w sum : Circuit diff --git a/DatapathVerification/BitHeap/Circuit.lean b/DatapathVerification/BitHeap/Circuit.lean index c34c06d..62cee0c 100644 --- a/DatapathVerification/BitHeap/Circuit.lean +++ b/DatapathVerification/BitHeap/Circuit.lean @@ -76,4 +76,16 @@ theorem eval_atLeastTwo (a b c : Circuit) (env : BitEnv) : end Circuit + +/-- A vector of circuits, used to represent symbolic BitVectors. -/ +def CircuitVector := Array Circuit + +namespace CircuitVector + +def eval (vec : CircuitVector) (env : Circuit.BitEnv) : Int := + (vec.mapIdx (fun i c => 2^i * (if c.eval env then 1 else 0))).sum + +end CircuitVector + + end BitHeap diff --git a/DatapathVerification/BitHeap/Examples/Examples.lean b/DatapathVerification/BitHeap/Examples/Examples.lean index 4702bae..186a5b8 100644 --- a/DatapathVerification/BitHeap/Examples/Examples.lean +++ b/DatapathVerification/BitHeap/Examples/Examples.lean @@ -99,6 +99,43 @@ info: 6 #eval (applyChain compressionChain fourBitsInCol1).eval (show BitEnv from fun n => n = 1 || n = 2 || n = 3) +---------------------------- + +def exampleHeap1 : BitHeap 4 := + let h := BitHeap.empty 4 + let h := h.addBit 0 (Circuit.bit 0) + let h := h.addBit 1 (Circuit.bit 1) + let h := h.addBit 2 (Circuit.bit 2) + let h := h.addBit 3 (Circuit.bit 3) + h + +def exampleHeap2 : BitHeap 4 := + let h := BitHeap.empty 4 + let h := h.addBit 0 (Circuit.bit 4) + let h := h.addBit 1 (Circuit.bit 5) + let h := h.addBit 2 (Circuit.bit 6) + let h := h.addBit 3 (Circuit.bit 7) + h + +def exampleHeap3 : BitHeap 4 := + let h := BitHeap.empty 4 + let h := h.addBit 0 (Circuit.bit 8) + let h := h.addBit 1 (Circuit.bit 9) + let h := h.addBit 2 (Circuit.bit 10) + let h := h.addBit 3 (Circuit.bit 11) + h +/-- +info: {0 ↦ [b4, b8, b0], 1 ↦ [b1, b5, b9], 2 ↦ [b2, b10, b6], 3 ↦ [b3, b11, b7]} +-/ +#guard_msgs in +#eval addBitHeap [exampleHeap1, exampleHeap2, exampleHeap3] + +/-- +info: {0 ↦ [(b0 ∧ b4)], 1 ↦ [(b0 ∧ b5), (b1 ∧ b4)], 2 ↦ [(b1 ∧ b5), (b2 ∧ b4), (b0 ∧ b6)], 3 ↦ [(b3 ∧ b4), (b0 ∧ b7), (b2 ∧ b5), (b1 ∧ b6)], 4 ↦ [(b2 ∧ b6), (b3 ∧ b5), (b1 ∧ b7)], 5 ↦ [(b3 ∧ b6), (b2 ∧ b7)], 6 ↦ [(b3 ∧ b7)]} +-/ +#guard_msgs in +#eval mulBitHeap exampleHeap1 exampleHeap2 + end Examples end BitHeap