2828 Signature ,
2929)
3030from qualtran .bloqs .gf_arithmetic .gf2_addition import GF2Addition
31- from qualtran .bloqs .gf_arithmetic .gf2_multiplication import GF2Multiplication
31+ from qualtran .bloqs .gf_arithmetic .gf2_multiplication import GF2MulViaKaratsuba
3232from qualtran .bloqs .gf_arithmetic .gf2_square import GF2Square
3333from qualtran .resource_counting .generalizers import ignore_alloc_free , ignore_split_join
3434from qualtran .symbolics import bit_length , ceil , is_symbolic , log2 , SymbolicInt
@@ -109,7 +109,7 @@ def qgf(self) -> QGF:
109109
110110 @cached_property
111111 def n_junk_regs (self ) -> SymbolicInt :
112- return 2 * bit_length (self .bitsize - 1 ) + self .bitsize_hamming_weight
112+ return 2 * bit_length (self .bitsize - 1 ) + self .bitsize_hamming_weight - 3
113113
114114 @cached_property
115115 def bitsize_hamming_weight (self ) -> SymbolicInt :
@@ -139,11 +139,11 @@ def build_composite_bloq(self, bb: 'BloqBuilder', *, x: 'Soquet') -> Dict[str, '
139139 return {'x' : x , 'result' : result }
140140
141141 junk = []
142- beta = bb .allocate (dtype = self .qgf )
143- x , beta = bb .add (GF2Addition (self .bitsize ), x = x , y = beta )
142+ beta = x
144143 is_first = True
145144 bitsize_minus_one = int (self .bitsize - 1 )
146- for i in range (bitsize_minus_one .bit_length ()):
145+ n_iters = bitsize_minus_one .bit_length ()
146+ for i in range (n_iters ):
147147 if (1 << i ) & bitsize_minus_one :
148148 if is_first :
149149 beta , result = bb .add (GF2Addition (self .bitsize ), x = beta , y = result )
@@ -152,21 +152,23 @@ def build_composite_bloq(self, bb: 'BloqBuilder', *, x: 'Soquet') -> Dict[str, '
152152 for j in range (2 ** i ):
153153 result = bb .add (GF2Square (self .bitsize ), x = result )
154154 beta , result , new_result = bb .add (
155- GF2Multiplication (self .bitsize ), x = beta , y = result
155+ GF2MulViaKaratsuba (self .bitsize ), x = beta , y = result
156156 )
157157 junk .append (result )
158158 result = new_result
159- beta_squared = bb .allocate (dtype = self .qgf )
160- beta , beta_squared = bb .add (GF2Addition (self .bitsize ), x = beta , y = beta_squared )
161- for j in range (2 ** i ):
162- beta_squared = bb .add (GF2Square (self .bitsize ), x = beta_squared )
163- beta , beta_squared , beta_new = bb .add (
164- GF2Multiplication (self .bitsize ), x = beta , y = beta_squared
165- )
166- junk .extend ([beta , beta_squared ])
167- beta = beta_new
159+ if i != n_iters - 1 :
160+ beta_squared = bb .allocate (dtype = self .qgf )
161+ beta , beta_squared = bb .add (GF2Addition (self .bitsize ), x = beta , y = beta_squared )
162+ for j in range (2 ** i ):
163+ beta_squared = bb .add (GF2Square (self .bitsize ), x = beta_squared )
164+ beta , beta_squared , beta_new = bb .add (
165+ GF2MulViaKaratsuba (self .bitsize ), x = beta , y = beta_squared
166+ )
167+ junk .extend ([beta , beta_squared ])
168+ beta = beta_new
168169 junk .append (beta )
169170 result = bb .add (GF2Square (self .bitsize ), x = result )
171+ x = junk .pop (0 )
170172 assert len (junk ) == self .n_junk_regs , f'{ len (junk )= } , { self .n_junk_regs = } '
171173 return {'x' : x , 'result' : result , 'junk' : np .array (junk )}
172174
@@ -179,13 +181,12 @@ def build_call_graph(
179181 if not is_symbolic (self .bitsize ):
180182 n = self .bitsize - 1
181183 square_count -= n & (- n )
184+ square_count -= 1 << (n .bit_length () - 1 )
185+ mul_count = ceil (log2 (self .bitsize )) + self .bitsize_hamming_weight - 2
182186 return {
183- GF2Addition (self .bitsize ): 2 + ceil (log2 (self .bitsize )),
187+ GF2Addition (self .bitsize ): ceil (log2 (self .bitsize )),
184188 GF2Square (self .bitsize ): square_count ,
185- GF2Multiplication (self .bitsize ): ceil (log2 (self .bitsize ))
186- + self .bitsize_hamming_weight
187- - 1 ,
188- }
189+ } | ({GF2MulViaKaratsuba (self .bitsize ): mul_count } if mul_count else {})
189190
190191 def on_classical_vals (self , * , x ) -> Dict [str , 'ClassicalValT' ]:
191192 assert isinstance (x , self .qgf .gf_type )
@@ -204,10 +205,13 @@ def on_classical_vals(self, *, x) -> Dict[str, 'ClassicalValT']:
204205 result = result ** 2
205206 junk .append (result )
206207 result = result * beta
207- beta_squared = beta ** (2 ** (2 ** i ))
208- junk .extend ([beta , beta_squared ])
209- beta = beta * beta_squared
208+ if i != bitsize_minus_one .bit_length () - 1 :
209+ beta_squared = beta ** (2 ** (2 ** i ))
210+ junk .extend ([beta , beta_squared ])
211+ beta = beta * beta_squared
210212 junk .append (beta )
213+ assert x == junk [0 ]
214+ junk = junk [1 :]
211215 return {'x' : x , 'result' : x ** (- 1 ) if x else self .qgf .gf_type (0 ), 'junk' : np .array (junk )}
212216
213217
0 commit comments