@@ -6,15 +6,21 @@ use crate::{
66 gadgets:: { UIntLimbsLT , UIntLimbsLTConfig } ,
77 instructions:: {
88 Instruction ,
9- riscv:: { RIVInstruction , b_insn:: BInstructionConfig , constants:: UInt } ,
9+ riscv:: {
10+ RIVInstruction ,
11+ b_insn:: BInstructionConfig ,
12+ constants:: { UINT_LIMBS , UInt } ,
13+ } ,
1014 } ,
1115 structs:: ProgramParams ,
1216 witness:: LkMultiplicity ,
1317} ;
1418use ceno_emul:: { InsnKind , StepRecord } ;
15- use ff_ext:: ExtensionField ;
16- use multilinear_extensions:: Expression ;
17- use std:: marker:: PhantomData ;
19+ use ff_ext:: { ExtensionField , FieldInto } ;
20+ use multilinear_extensions:: { Expression , ToExpr , WitIn } ;
21+ use p3:: field:: { Field , FieldAlgebra } ;
22+ use std:: { array, marker:: PhantomData } ;
23+ use witness:: set_val;
1824
1925pub struct BranchCircuit < E , I > ( PhantomData < ( E , I ) > ) ;
2026
@@ -23,7 +29,11 @@ pub struct BranchConfig<E: ExtensionField> {
2329 pub read_rs1 : UInt < E > ,
2430 pub read_rs2 : UInt < E > ,
2531
26- pub uint_lt_config : UIntLimbsLTConfig < E > ,
32+ // for non eq opcode config
33+ pub uint_lt_config : Option < UIntLimbsLTConfig < E > > ,
34+ // for beq/bne
35+ pub eq_diff_inv_marker : Option < [ WitIn ; UINT_LIMBS ] > ,
36+ pub eq_branch_taken_bit : Option < WitIn > ,
2737 phantom : PhantomData < E > ,
2838}
2939
@@ -41,28 +51,82 @@ impl<E: ExtensionField, I: RIVInstruction> Instruction<E> for BranchCircuit<E, I
4151 let read_rs1 = UInt :: new_unchecked ( || "rs1_limbs" , circuit_builder) ?;
4252 let read_rs2 = UInt :: new_unchecked ( || "rs2_limbs" , circuit_builder) ?;
4353
44- let is_signed = matches ! ( I :: INST_KIND , InsnKind :: BLT | InsnKind :: BGE ) ;
45- let is_ge = matches ! ( I :: INST_KIND , InsnKind :: BGEU | InsnKind :: BGE ) ;
46- let uint_lt_config =
47- UIntLimbsLT :: < E > :: construct_circuit ( circuit_builder, & read_rs1, & read_rs2, is_signed) ?;
48- let branch_taken_bit = if is_ge {
49- Expression :: ONE - uint_lt_config. is_lt ( )
50- } else {
51- uint_lt_config. is_lt ( )
52- } ;
54+ let ( branch_taken_bit_expr, eq_branch_taken_bit, eq_diff_inv_marker, uint_lt_config) =
55+ if matches ! ( I :: INST_KIND , InsnKind :: BEQ | InsnKind :: BNE ) {
56+ let branch_taken_bit = circuit_builder. create_bit ( || "branch_taken_bit" ) ?;
57+ let eq_diff_inv_marker = array:: from_fn ( |i| {
58+ circuit_builder. create_witin ( || format ! ( "eq_diff_inv_marker_{i}" ) )
59+ } ) ;
60+
61+ // 1 if cmp_result indicates a and b are EQUAL, 0 otherwise
62+ let cmp_eq = match I :: INST_KIND {
63+ InsnKind :: BEQ => branch_taken_bit. expr ( ) ,
64+ InsnKind :: BNE => Expression :: ONE - branch_taken_bit. expr ( ) ,
65+ _ => unreachable ! ( ) ,
66+ } ;
67+ let mut sum = cmp_eq. expr ( ) ;
68+
69+ // For BEQ, inv_marker is used to check equality of a and b:
70+ // - If a == b, all inv_marker values must be 0 (sum = 0)
71+ // - If a != b, inv_marker contains 0s for all positions except ONE position i where a[i] !=
72+ // b[i]
73+ // - At this position, inv_marker[i] contains the multiplicative inverse of (a[i] - b[i])
74+ // - This ensures inv_marker[i] * (a[i] - b[i]) = 1, making the sum = 1
75+ // Note: There might be multiple valid inv_marker if a != b.
76+ // But as long as the trace can provide at least one, that’s sufficient to prove a != b.
77+ //
78+ // Note:
79+ // - If cmp_eq == 0, then it is impossible to have sum != 0 if a == b.
80+ // - If cmp_eq == 1, then it is impossible for a[i] - b[i] == 0 to pass for all i if a != b.
81+ #[ allow( clippy:: needless_range_loop) ]
82+ for i in 0 ..UINT_LIMBS {
83+ sum += ( read_rs1. limbs [ i] . expr ( ) - read_rs2. limbs [ i] . expr ( ) )
84+ * eq_diff_inv_marker[ i] . expr ( ) ;
85+ circuit_builder. require_zero (
86+ || "require_zero" ,
87+ cmp_eq. expr ( ) * ( read_rs1. limbs [ i] . expr ( ) - read_rs2. limbs [ i] . expr ( ) ) ,
88+ ) ?
89+ }
90+ circuit_builder. require_one ( || "sum" , sum) ?;
91+
92+ (
93+ branch_taken_bit. expr ( ) ,
94+ Some ( branch_taken_bit) ,
95+ Some ( eq_diff_inv_marker) ,
96+ None ,
97+ )
98+ } else {
99+ let is_signed = matches ! ( I :: INST_KIND , InsnKind :: BLT | InsnKind :: BGE ) ;
100+ let is_ge = matches ! ( I :: INST_KIND , InsnKind :: BGEU | InsnKind :: BGE ) ;
101+ let uint_lt_config = UIntLimbsLT :: < E > :: construct_circuit (
102+ circuit_builder,
103+ & read_rs1,
104+ & read_rs2,
105+ is_signed,
106+ ) ?;
107+ let branch_taken_bit = if is_ge {
108+ Expression :: ONE - uint_lt_config. is_lt ( )
109+ } else {
110+ uint_lt_config. is_lt ( )
111+ } ;
112+ ( branch_taken_bit, None , None , Some ( uint_lt_config) )
113+ } ;
114+
53115 let b_insn = BInstructionConfig :: construct_circuit (
54116 circuit_builder,
55117 I :: INST_KIND ,
56118 read_rs1. register_expr ( ) ,
57119 read_rs2. register_expr ( ) ,
58- branch_taken_bit ,
120+ branch_taken_bit_expr ,
59121 ) ?;
60122
61123 Ok ( BranchConfig {
62124 b_insn,
63125 read_rs1,
64126 read_rs2,
65127 uint_lt_config,
128+ eq_branch_taken_bit,
129+ eq_diff_inv_marker,
66130 phantom : Default :: default ( ) ,
67131 } )
68132 }
@@ -85,15 +149,54 @@ impl<E: ExtensionField, I: RIVInstruction> Instruction<E> for BranchCircuit<E, I
85149 config. read_rs1 . assign_limbs ( instance, rs1_limbs) ;
86150 config. read_rs2 . assign_limbs ( instance, rs2_limbs) ;
87151
88- let is_signed = matches ! ( step. insn( ) . kind, InsnKind :: BLT | InsnKind :: BGE ) ;
89- UIntLimbsLT :: < E > :: assign (
90- & config. uint_lt_config ,
91- instance,
92- lk_multiplicity,
93- rs1_limbs,
94- rs2_limbs,
95- is_signed,
96- ) ?;
152+ if matches ! ( I :: INST_KIND , InsnKind :: BEQ | InsnKind :: BNE ) {
153+ // Returns (branch_taken, diff_idx, x[diff_idx] - y[diff_idx])
154+ #[ inline( always) ]
155+ fn run_eq < F , const NUM_LIMBS : usize > (
156+ is_beq : bool ,
157+ x : & [ u16 ] ,
158+ y : & [ u16 ] ,
159+ ) -> ( bool , usize , F )
160+ where
161+ F : FieldAlgebra + Field ,
162+ {
163+ for i in 0 ..NUM_LIMBS {
164+ if x[ i] != y[ i] {
165+ return (
166+ !is_beq,
167+ i,
168+ ( F :: from_canonical_u16 ( x[ i] ) - F :: from_canonical_u16 ( y[ i] ) ) . inverse ( ) ,
169+ ) ;
170+ }
171+ }
172+ ( is_beq, 0 , F :: ZERO )
173+ }
174+ let ( branch_taken, diff_idx, diff_inv_val) = run_eq :: < E :: BaseField , UINT_LIMBS > (
175+ matches ! ( I :: INST_KIND , InsnKind :: BEQ ) ,
176+ rs1_limbs,
177+ rs2_limbs,
178+ ) ;
179+ set_val ! (
180+ instance,
181+ config. eq_branch_taken_bit. as_ref( ) . unwrap( ) ,
182+ E :: BaseField :: from_bool( branch_taken)
183+ ) ;
184+ set_val ! (
185+ instance,
186+ config. eq_diff_inv_marker. as_ref( ) . unwrap( ) [ diff_idx] ,
187+ diff_inv_val
188+ ) ;
189+ } else {
190+ let is_signed = matches ! ( step. insn( ) . kind, InsnKind :: BLT | InsnKind :: BGE ) ;
191+ UIntLimbsLT :: < E > :: assign (
192+ config. uint_lt_config . as_ref ( ) . unwrap ( ) ,
193+ instance,
194+ lk_multiplicity,
195+ rs1_limbs,
196+ rs2_limbs,
197+ is_signed,
198+ ) ?;
199+ }
97200 Ok ( ( ) )
98201 }
99202}
0 commit comments