Skip to content

Commit c2b73d6

Browse files
authored
beq/bne circuit to limb (#1155)
some left-over from previous migration
1 parent 7731e3b commit c2b73d6

File tree

3 files changed

+149
-39
lines changed

3 files changed

+149
-39
lines changed

ceno_zkvm/src/instructions/riscv/branch.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
use super::RIVInstruction;
22
use ceno_emul::InsnKind;
33

4+
#[cfg(not(feature = "u16limb_circuit"))]
45
mod branch_circuit;
6+
#[cfg(feature = "u16limb_circuit")]
57
mod branch_circuit_v2;
68
#[cfg(test)]
79
mod test;
@@ -10,12 +12,18 @@ pub struct BeqOp;
1012
impl RIVInstruction for BeqOp {
1113
const INST_KIND: InsnKind = InsnKind::BEQ;
1214
}
15+
#[cfg(feature = "u16limb_circuit")]
16+
pub type BeqInstruction<E> = branch_circuit_v2::BranchCircuit<E, BeqOp>;
17+
#[cfg(not(feature = "u16limb_circuit"))]
1318
pub type BeqInstruction<E> = branch_circuit::BranchCircuit<E, BeqOp>;
1419

1520
pub struct BneOp;
1621
impl RIVInstruction for BneOp {
1722
const INST_KIND: InsnKind = InsnKind::BNE;
1823
}
24+
#[cfg(feature = "u16limb_circuit")]
25+
pub type BneInstruction<E> = branch_circuit_v2::BranchCircuit<E, BneOp>;
26+
#[cfg(not(feature = "u16limb_circuit"))]
1927
pub type BneInstruction<E> = branch_circuit::BranchCircuit<E, BneOp>;
2028

2129
pub struct BltuOp;

ceno_zkvm/src/instructions/riscv/branch/branch_circuit_v2.rs

Lines changed: 127 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,21 @@ use crate::{
66
gadgets::{UIntLimbsLT, UIntLimbsLTConfig},
77
instructions::{
88
Instruction,
9-
riscv::{RIVInstruction, b_insn::BInstructionConfig, constants::UInt},
9+
riscv::{
10+
RIVInstruction,
11+
b_insn::BInstructionConfig,
12+
constants::{UINT_LIMBS, UInt},
13+
},
1014
},
1115
structs::ProgramParams,
1216
witness::LkMultiplicity,
1317
};
1418
use ceno_emul::{InsnKind, StepRecord};
15-
use ff_ext::ExtensionField;
16-
use multilinear_extensions::Expression;
17-
use std::marker::PhantomData;
19+
use ff_ext::{ExtensionField, FieldInto};
20+
use multilinear_extensions::{Expression, ToExpr, WitIn};
21+
use p3::field::{Field, FieldAlgebra};
22+
use std::{array, marker::PhantomData};
23+
use witness::set_val;
1824

1925
pub struct BranchCircuit<E, I>(PhantomData<(E, I)>);
2026

@@ -23,7 +29,11 @@ pub struct BranchConfig<E: ExtensionField> {
2329
pub read_rs1: UInt<E>,
2430
pub read_rs2: UInt<E>,
2531

26-
pub uint_lt_config: UIntLimbsLTConfig<E>,
32+
// for non eq opcode config
33+
pub uint_lt_config: Option<UIntLimbsLTConfig<E>>,
34+
// for beq/bne
35+
pub eq_diff_inv_marker: Option<[WitIn; UINT_LIMBS]>,
36+
pub eq_branch_taken_bit: Option<WitIn>,
2737
phantom: PhantomData<E>,
2838
}
2939

@@ -41,28 +51,82 @@ impl<E: ExtensionField, I: RIVInstruction> Instruction<E> for BranchCircuit<E, I
4151
let read_rs1 = UInt::new_unchecked(|| "rs1_limbs", circuit_builder)?;
4252
let read_rs2 = UInt::new_unchecked(|| "rs2_limbs", circuit_builder)?;
4353

44-
let is_signed = matches!(I::INST_KIND, InsnKind::BLT | InsnKind::BGE);
45-
let is_ge = matches!(I::INST_KIND, InsnKind::BGEU | InsnKind::BGE);
46-
let uint_lt_config =
47-
UIntLimbsLT::<E>::construct_circuit(circuit_builder, &read_rs1, &read_rs2, is_signed)?;
48-
let branch_taken_bit = if is_ge {
49-
Expression::ONE - uint_lt_config.is_lt()
50-
} else {
51-
uint_lt_config.is_lt()
52-
};
54+
let (branch_taken_bit_expr, eq_branch_taken_bit, eq_diff_inv_marker, uint_lt_config) =
55+
if matches!(I::INST_KIND, InsnKind::BEQ | InsnKind::BNE) {
56+
let branch_taken_bit = circuit_builder.create_bit(|| "branch_taken_bit")?;
57+
let eq_diff_inv_marker = array::from_fn(|i| {
58+
circuit_builder.create_witin(|| format!("eq_diff_inv_marker_{i}"))
59+
});
60+
61+
// 1 if cmp_result indicates a and b are EQUAL, 0 otherwise
62+
let cmp_eq = match I::INST_KIND {
63+
InsnKind::BEQ => branch_taken_bit.expr(),
64+
InsnKind::BNE => Expression::ONE - branch_taken_bit.expr(),
65+
_ => unreachable!(),
66+
};
67+
let mut sum = cmp_eq.expr();
68+
69+
// For BEQ, inv_marker is used to check equality of a and b:
70+
// - If a == b, all inv_marker values must be 0 (sum = 0)
71+
// - If a != b, inv_marker contains 0s for all positions except ONE position i where a[i] !=
72+
// b[i]
73+
// - At this position, inv_marker[i] contains the multiplicative inverse of (a[i] - b[i])
74+
// - This ensures inv_marker[i] * (a[i] - b[i]) = 1, making the sum = 1
75+
// Note: There might be multiple valid inv_marker if a != b.
76+
// But as long as the trace can provide at least one, that’s sufficient to prove a != b.
77+
//
78+
// Note:
79+
// - If cmp_eq == 0, then it is impossible to have sum != 0 if a == b.
80+
// - If cmp_eq == 1, then it is impossible for a[i] - b[i] == 0 to pass for all i if a != b.
81+
#[allow(clippy::needless_range_loop)]
82+
for i in 0..UINT_LIMBS {
83+
sum += (read_rs1.limbs[i].expr() - read_rs2.limbs[i].expr())
84+
* eq_diff_inv_marker[i].expr();
85+
circuit_builder.require_zero(
86+
|| "require_zero",
87+
cmp_eq.expr() * (read_rs1.limbs[i].expr() - read_rs2.limbs[i].expr()),
88+
)?
89+
}
90+
circuit_builder.require_one(|| "sum", sum)?;
91+
92+
(
93+
branch_taken_bit.expr(),
94+
Some(branch_taken_bit),
95+
Some(eq_diff_inv_marker),
96+
None,
97+
)
98+
} else {
99+
let is_signed = matches!(I::INST_KIND, InsnKind::BLT | InsnKind::BGE);
100+
let is_ge = matches!(I::INST_KIND, InsnKind::BGEU | InsnKind::BGE);
101+
let uint_lt_config = UIntLimbsLT::<E>::construct_circuit(
102+
circuit_builder,
103+
&read_rs1,
104+
&read_rs2,
105+
is_signed,
106+
)?;
107+
let branch_taken_bit = if is_ge {
108+
Expression::ONE - uint_lt_config.is_lt()
109+
} else {
110+
uint_lt_config.is_lt()
111+
};
112+
(branch_taken_bit, None, None, Some(uint_lt_config))
113+
};
114+
53115
let b_insn = BInstructionConfig::construct_circuit(
54116
circuit_builder,
55117
I::INST_KIND,
56118
read_rs1.register_expr(),
57119
read_rs2.register_expr(),
58-
branch_taken_bit,
120+
branch_taken_bit_expr,
59121
)?;
60122

61123
Ok(BranchConfig {
62124
b_insn,
63125
read_rs1,
64126
read_rs2,
65127
uint_lt_config,
128+
eq_branch_taken_bit,
129+
eq_diff_inv_marker,
66130
phantom: Default::default(),
67131
})
68132
}
@@ -85,15 +149,54 @@ impl<E: ExtensionField, I: RIVInstruction> Instruction<E> for BranchCircuit<E, I
85149
config.read_rs1.assign_limbs(instance, rs1_limbs);
86150
config.read_rs2.assign_limbs(instance, rs2_limbs);
87151

88-
let is_signed = matches!(step.insn().kind, InsnKind::BLT | InsnKind::BGE);
89-
UIntLimbsLT::<E>::assign(
90-
&config.uint_lt_config,
91-
instance,
92-
lk_multiplicity,
93-
rs1_limbs,
94-
rs2_limbs,
95-
is_signed,
96-
)?;
152+
if matches!(I::INST_KIND, InsnKind::BEQ | InsnKind::BNE) {
153+
// Returns (branch_taken, diff_idx, x[diff_idx] - y[diff_idx])
154+
#[inline(always)]
155+
fn run_eq<F, const NUM_LIMBS: usize>(
156+
is_beq: bool,
157+
x: &[u16],
158+
y: &[u16],
159+
) -> (bool, usize, F)
160+
where
161+
F: FieldAlgebra + Field,
162+
{
163+
for i in 0..NUM_LIMBS {
164+
if x[i] != y[i] {
165+
return (
166+
!is_beq,
167+
i,
168+
(F::from_canonical_u16(x[i]) - F::from_canonical_u16(y[i])).inverse(),
169+
);
170+
}
171+
}
172+
(is_beq, 0, F::ZERO)
173+
}
174+
let (branch_taken, diff_idx, diff_inv_val) = run_eq::<E::BaseField, UINT_LIMBS>(
175+
matches!(I::INST_KIND, InsnKind::BEQ),
176+
rs1_limbs,
177+
rs2_limbs,
178+
);
179+
set_val!(
180+
instance,
181+
config.eq_branch_taken_bit.as_ref().unwrap(),
182+
E::BaseField::from_bool(branch_taken)
183+
);
184+
set_val!(
185+
instance,
186+
config.eq_diff_inv_marker.as_ref().unwrap()[diff_idx],
187+
diff_inv_val
188+
);
189+
} else {
190+
let is_signed = matches!(step.insn().kind, InsnKind::BLT | InsnKind::BGE);
191+
UIntLimbsLT::<E>::assign(
192+
config.uint_lt_config.as_ref().unwrap(),
193+
instance,
194+
lk_multiplicity,
195+
rs1_limbs,
196+
rs2_limbs,
197+
is_signed,
198+
)?;
199+
}
97200
Ok(())
98201
}
99202
}

ceno_zkvm/src/instructions/riscv/branch/test.rs

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -13,16 +13,14 @@ use crate::{
1313
structs::ProgramParams,
1414
};
1515

16-
const A: Word = 0xbead1010;
17-
const B: Word = 0xef552020;
18-
1916
#[test]
2017
fn test_opcode_beq() {
21-
impl_opcode_beq(false);
22-
impl_opcode_beq(true);
18+
impl_opcode_beq(false, 0xbead1010, 0xef552020);
19+
impl_opcode_beq(true, 0xef552020, 0xef552020);
20+
impl_opcode_beq(true, 0xffffffff, 0xffffffff);
2321
}
2422

25-
fn impl_opcode_beq(equal: bool) {
23+
fn impl_opcode_beq(take_branch: bool, a: u32, b: u32) {
2624
let mut cs = ConstraintSystem::<GoldilocksExt2>::new(|| "riscv");
2725
let mut cb = CircuitBuilder::new(&mut cs);
2826
let config = cb
@@ -37,7 +35,7 @@ fn impl_opcode_beq(equal: bool) {
3735
.unwrap();
3836

3937
let insn_code = encode_rv32(InsnKind::BEQ, 2, 3, 0, 8);
40-
let pc_offset = if equal { 8 } else { PC_STEP_SIZE };
38+
let pc_offset = if take_branch { 8 } else { PC_STEP_SIZE };
4139
let (raw_witin, lkm) = BeqInstruction::assign_instances(
4240
&config,
4341
&mut ShardContext::default(),
@@ -47,8 +45,8 @@ fn impl_opcode_beq(equal: bool) {
4745
3,
4846
Change::new(MOCK_PC_START, MOCK_PC_START + pc_offset),
4947
insn_code,
50-
A,
51-
if equal { A } else { B },
48+
a as Word,
49+
b as Word,
5250
0,
5351
)],
5452
)
@@ -59,11 +57,12 @@ fn impl_opcode_beq(equal: bool) {
5957

6058
#[test]
6159
fn test_opcode_bne() {
62-
impl_opcode_bne(false);
63-
impl_opcode_bne(true);
60+
impl_opcode_bne(true, 0xbead1010, 0xef552020);
61+
impl_opcode_bne(false, 0xef552020, 0xef552020);
62+
impl_opcode_bne(false, 0xffffffff, 0xffffffff);
6463
}
6564

66-
fn impl_opcode_bne(equal: bool) {
65+
fn impl_opcode_bne(take_branch: bool, a: u32, b: u32) {
6766
let mut cs = ConstraintSystem::<GoldilocksExt2>::new(|| "riscv");
6867
let mut cb = CircuitBuilder::new(&mut cs);
6968
let config = cb
@@ -78,7 +77,7 @@ fn impl_opcode_bne(equal: bool) {
7877
.unwrap();
7978

8079
let insn_code = encode_rv32(InsnKind::BNE, 2, 3, 0, 8);
81-
let pc_offset = if equal { PC_STEP_SIZE } else { 8 };
80+
let pc_offset = if take_branch { 8 } else { PC_STEP_SIZE };
8281
let (raw_witin, lkm) = BneInstruction::assign_instances(
8382
&config,
8483
&mut ShardContext::default(),
@@ -88,8 +87,8 @@ fn impl_opcode_bne(equal: bool) {
8887
3,
8988
Change::new(MOCK_PC_START, MOCK_PC_START + pc_offset),
9089
insn_code,
91-
A,
92-
if equal { A } else { B },
90+
a as Word,
91+
b as Word,
9392
0,
9493
)],
9594
)

0 commit comments

Comments
 (0)