Skip to content

Commit 1fdee48

Browse files
authored
Minor improvements to GF arithmetic bloqs (#1705)
Some minor fixes: - Use `GF2MulViaKaratsuba` instead of `GF2Multiplication` in `GF2Inverse` to reduce the cost of inversion - Add support for symbolic gate counts to `GF2MulViaKaratsuba` - Reduce the number of GF2 multiplications in `GF2Inverse` by 1 since the last multiplication was not really needed. - Reduce the number of junk registers in `GF2Inverse` by 1 since `junk[0]` was same as the input `x`.
1 parent 3c163c0 commit 1fdee48

File tree

3 files changed

+46
-33
lines changed

3 files changed

+46
-33
lines changed

qualtran/bloqs/gf_arithmetic/gf2_inverse.py

Lines changed: 27 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
Signature,
2929
)
3030
from qualtran.bloqs.gf_arithmetic.gf2_addition import GF2Addition
31-
from qualtran.bloqs.gf_arithmetic.gf2_multiplication import GF2Multiplication
31+
from qualtran.bloqs.gf_arithmetic.gf2_multiplication import GF2MulViaKaratsuba
3232
from qualtran.bloqs.gf_arithmetic.gf2_square import GF2Square
3333
from qualtran.resource_counting.generalizers import ignore_alloc_free, ignore_split_join
3434
from qualtran.symbolics import bit_length, ceil, is_symbolic, log2, SymbolicInt
@@ -109,7 +109,7 @@ def qgf(self) -> QGF:
109109

110110
@cached_property
111111
def n_junk_regs(self) -> SymbolicInt:
112-
return 2 * bit_length(self.bitsize - 1) + self.bitsize_hamming_weight
112+
return 2 * bit_length(self.bitsize - 1) + self.bitsize_hamming_weight - 3
113113

114114
@cached_property
115115
def bitsize_hamming_weight(self) -> SymbolicInt:
@@ -139,11 +139,11 @@ def build_composite_bloq(self, bb: 'BloqBuilder', *, x: 'Soquet') -> Dict[str, '
139139
return {'x': x, 'result': result}
140140

141141
junk = []
142-
beta = bb.allocate(dtype=self.qgf)
143-
x, beta = bb.add(GF2Addition(self.bitsize), x=x, y=beta)
142+
beta = x
144143
is_first = True
145144
bitsize_minus_one = int(self.bitsize - 1)
146-
for i in range(bitsize_minus_one.bit_length()):
145+
n_iters = bitsize_minus_one.bit_length()
146+
for i in range(n_iters):
147147
if (1 << i) & bitsize_minus_one:
148148
if is_first:
149149
beta, result = bb.add(GF2Addition(self.bitsize), x=beta, y=result)
@@ -152,21 +152,23 @@ def build_composite_bloq(self, bb: 'BloqBuilder', *, x: 'Soquet') -> Dict[str, '
152152
for j in range(2**i):
153153
result = bb.add(GF2Square(self.bitsize), x=result)
154154
beta, result, new_result = bb.add(
155-
GF2Multiplication(self.bitsize), x=beta, y=result
155+
GF2MulViaKaratsuba(self.bitsize), x=beta, y=result
156156
)
157157
junk.append(result)
158158
result = new_result
159-
beta_squared = bb.allocate(dtype=self.qgf)
160-
beta, beta_squared = bb.add(GF2Addition(self.bitsize), x=beta, y=beta_squared)
161-
for j in range(2**i):
162-
beta_squared = bb.add(GF2Square(self.bitsize), x=beta_squared)
163-
beta, beta_squared, beta_new = bb.add(
164-
GF2Multiplication(self.bitsize), x=beta, y=beta_squared
165-
)
166-
junk.extend([beta, beta_squared])
167-
beta = beta_new
159+
if i != n_iters - 1:
160+
beta_squared = bb.allocate(dtype=self.qgf)
161+
beta, beta_squared = bb.add(GF2Addition(self.bitsize), x=beta, y=beta_squared)
162+
for j in range(2**i):
163+
beta_squared = bb.add(GF2Square(self.bitsize), x=beta_squared)
164+
beta, beta_squared, beta_new = bb.add(
165+
GF2MulViaKaratsuba(self.bitsize), x=beta, y=beta_squared
166+
)
167+
junk.extend([beta, beta_squared])
168+
beta = beta_new
168169
junk.append(beta)
169170
result = bb.add(GF2Square(self.bitsize), x=result)
171+
x = junk.pop(0)
170172
assert len(junk) == self.n_junk_regs, f'{len(junk)=}, {self.n_junk_regs=}'
171173
return {'x': x, 'result': result, 'junk': np.array(junk)}
172174

@@ -179,13 +181,12 @@ def build_call_graph(
179181
if not is_symbolic(self.bitsize):
180182
n = self.bitsize - 1
181183
square_count -= n & (-n)
184+
square_count -= 1 << (n.bit_length() - 1)
185+
mul_count = ceil(log2(self.bitsize)) + self.bitsize_hamming_weight - 2
182186
return {
183-
GF2Addition(self.bitsize): 2 + ceil(log2(self.bitsize)),
187+
GF2Addition(self.bitsize): ceil(log2(self.bitsize)),
184188
GF2Square(self.bitsize): square_count,
185-
GF2Multiplication(self.bitsize): ceil(log2(self.bitsize))
186-
+ self.bitsize_hamming_weight
187-
- 1,
188-
}
189+
} | ({GF2MulViaKaratsuba(self.bitsize): mul_count} if mul_count else {})
189190

190191
def on_classical_vals(self, *, x) -> Dict[str, 'ClassicalValT']:
191192
assert isinstance(x, self.qgf.gf_type)
@@ -204,10 +205,13 @@ def on_classical_vals(self, *, x) -> Dict[str, 'ClassicalValT']:
204205
result = result**2
205206
junk.append(result)
206207
result = result * beta
207-
beta_squared = beta ** (2 ** (2**i))
208-
junk.extend([beta, beta_squared])
209-
beta = beta * beta_squared
208+
if i != bitsize_minus_one.bit_length() - 1:
209+
beta_squared = beta ** (2 ** (2**i))
210+
junk.extend([beta, beta_squared])
211+
beta = beta * beta_squared
210212
junk.append(beta)
213+
assert x == junk[0]
214+
junk = junk[1:]
211215
return {'x': x, 'result': x ** (-1) if x else self.qgf.gf_type(0), 'junk': np.array(junk)}
212216

213217

qualtran/bloqs/gf_arithmetic/gf2_inverse_test.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,9 @@ def test_gf2_inverse_symbolic(bloq_autotester):
3838
def test_gf2_inverse_symbolic_toffoli_complexity():
3939
bloq = _gf2_inverse_symbolic.make()
4040
m = bloq.bitsize
41-
expected_expr = m**2 * (2 * ceil(log2(m)) - 1)
41+
expected_expr = m ** log2(3) * (2 * ceil(log2(m)) - 2)
4242
assert get_cost_value(bloq, QECGatesCost()).total_toffoli_only() - expected_expr == 0
43-
expected_expr = m * (3 * ceil(log2(m)) + 2)
43+
expected_expr = m * (3 * ceil(log2(m)) - 1)
4444
assert isinstance(expected_expr, sympy.Expr)
4545
assert sympy.simplify(get_cost_value(bloq, QubitCount()) - expected_expr) == 0
4646

@@ -60,7 +60,7 @@ def test_gf2_inverse_classical_sim(m):
6060
assert_consistent_classical_action(bloq, x=GFM.elements)
6161

6262

63-
@pytest.mark.parametrize('m', [*range(1, 12)])
63+
@pytest.mark.parametrize('m', [*range(1, 17)])
6464
def test_gf2_equivalent_bloq_counts(m):
6565
bloq = GF2Inverse(m)
6666
assert_equivalent_bloq_counts(bloq, generalizer=[ignore_split_join, ignore_alloc_free])

qualtran/bloqs/gf_arithmetic/gf2_multiplication.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,9 @@ def build_call_graph(
9494
def adjoint(self) -> 'SynthesizeLRCircuit':
9595
return attrs.evolve(self, is_adjoint=not self.is_adjoint)
9696

97+
def __str__(self):
98+
return f'{self.__class__.__name__}†' if self.is_adjoint else f'{self.__class__.__name__}'
99+
97100

98101
def _qgf_converter(x: Union[QGF, int, Poly, SymbolicInt, Sequence[int]]) -> QGF:
99102
if isinstance(x, QGF):
@@ -289,8 +292,8 @@ def m_x(self) -> Poly:
289292
return self.dtype.gf_type.irreducible_poly
290293

291294
@cached_property
292-
def n(self) -> int:
293-
return self.m_x.degree
295+
def n(self) -> SymbolicInt:
296+
return self.dtype.bitsize
294297

295298
@cached_property
296299
def qgf(self) -> QGF:
@@ -315,7 +318,7 @@ def lup(self) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
315318
by a full rank matrix that can be decomposed into PLU where L and U are lower
316319
and upper traingular matricies and P is a permutation matrix.
317320
"""
318-
n = self.n
321+
n = int(self.n)
319322
matrix = np.zeros((n, n), dtype=int)
320323
for i in range(n):
321324
p = self._const * self.galois_field(2**i)
@@ -970,12 +973,12 @@ def m_x(self):
970973
return self.dtype.gf_type.irreducible_poly
971974

972975
def __attrs_post_init__(self):
973-
if self.m_x.degree < 2:
976+
if not is_symbolic(self.dtype) and self.m_x.degree < 2:
974977
raise ValueError(f'GF2MulViaKaratsuba is not supported for {self.m_x}')
975978

976979
@cached_property
977-
def n(self):
978-
return int(self.m_x.degrees.max())
980+
def n(self) -> SymbolicInt:
981+
return self.dtype.bitsize
979982

980983
@cached_property
981984
def gf(self):
@@ -988,6 +991,9 @@ def qgf(self):
988991
def adjoint(self) -> 'GF2MulViaKaratsuba':
989992
return attrs.evolve(self, uncompute=not self.uncompute)
990993

994+
def __str__(self):
995+
return f'{self.__class__.__name__}†' if self.uncompute else f'{self.__class__.__name__}'
996+
991997
@cached_property
992998
def signature(self) -> 'Signature':
993999
# C is directional
@@ -1036,9 +1042,12 @@ def build_composite_bloq(
10361042
def build_call_graph(
10371043
self, ssa: 'SympySymbolAllocator'
10381044
) -> Union['BloqCountDictT', Set['BloqCountT']]:
1045+
if is_symbolic(self.n):
1046+
return {Toffoli(): self.n ** (log2(3)), CNOT(): self.n**2}
1047+
10391048
if self.n == 1:
10401049
return {Toffoli(): 1}
1041-
if not is_symbolic(self.n) and 2 * self.k == self.n:
1050+
if 2 * self.k == self.n:
10421051
return {
10431052
CNOT(): 4 * (self.n - self.k),
10441053
BinaryPolynomialMultiplication(self.k): 3,

0 commit comments

Comments
 (0)