Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions examples/separable_basis.py

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a great test, can you change it to being an integration test?

Here is an example test you can imitate:

  1. program itself (equivalent to this file you've written, except the if __name__ == '__main__' code is inside a function called test() instead (i usually make it take a shots argument): https://github.com/gt-tinker/qwerty/blob/main/qwerty_pyrt/python/qwerty/tests/integ/meta/float_expr.py
  2. the "unit" test that calls the code and verifies the output: https://github.com/gt-tinker/qwerty/blob/main/qwerty_pyrt/python/qwerty/tests/integration_tests.py#L429-L436

Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
#!/usr/bin/env python3
"""
Z (x) X separable basis, with per-vector phases at various angles.

{'0p','0m','1p','1m'} is separable: qubit 0 in Z, qubit 1 in X.

"""

from qwerty import *


@qpu
def no_phase() -> bit[2]:
return '00' | {'0p', '0m', '1p', '1m'}.measure

@qpu
def all_phase() -> bit[2]:
return '00' | {'0p'@90, '0m'@90, '1p'@90, '1m'@90}.measure


@qpu
def with_phase_180() -> bit[2]:
return '00' | {'0p', '0m', '1p', '1m'@180}.measure


@qpu
def with_phase_45() -> bit[2]:
return '00' | {'0p', '0m', '1p', '1m'@45}.measure


@qpu
def with_phase_60() -> bit[2]:
return '00' | {'0p', '0m', '1p', '1m'@60}.measure


if __name__ == '__main__':
# All four must agree (angle-independent under measurement).
print("no phase: ", histogram(no_phase(shots=512)))
print("all phase: ", histogram(all_phase(shots=512)))
print("with phase @180:", histogram(with_phase_180(shots=512)))
print("with phase @45: ", histogram(with_phase_45(shots=512)))
print("with phase @60: ", histogram(with_phase_60(shots=512)))
67 changes: 67 additions & 0 deletions qwerty_ast/src/ast/qpu.rs

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For any recursive code you add, can you write a TODO saying e.g. // TODO: use gen_rebuild to this non-recursively? I don't want to make you worry about it right now, but tldr recursive code can cause compile-time stack overflows for deep ASTs. We have crazy macros to work around that, but it's not worth your time atm considering the qwerty_ast_to_mlir code is already comically recursive

Original file line number Diff line number Diff line change
Expand Up @@ -2023,6 +2023,73 @@ impl Basis {
rebuild!(Basis, self, canonicalize)
}

pub fn strip_phases(self) -> Self {

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I may be missing something, but I don't see why we should remove vector phases.

I added code to remove (outer) phases for span checking, since that information is redundant with respect to span ($e^{i \theta}\ket{\mathrm{bv}}$ has the same span as $\ket{\mathrm{bv}}$), but for basis translation synthesis, the phases are actually important. For example, {'1'} >> {'1'} is an identity, whereas {'1'} >> {-'1'} is a Z gate

match self {
Basis::BasisLiteral { vecs, dbg } => Basis::BasisLiteral {
vecs: vecs.into_iter().map(|v| v.canonicalize().normalize()).collect(),
dbg,
},
Basis::BasisTensor { bases, dbg } => Basis::BasisTensor {
bases: bases.into_iter().map(Basis::strip_phases).collect(),
dbg,
},
other => other,
}
}

pub fn factor_separable(self) -> Self {

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add a doc comment for this?

let Basis::BasisLiteral { vecs, dbg } = self else {

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That is cool. I have never used this syntax before, respect.

return self;
};
if vecs.len() < 2 {
return Basis::BasisLiteral { vecs, dbg };
}

let mut rows: Vec<Vec<Vector>> = Vec::with_capacity(vecs.len());
for v in &vecs {
match v.clone().canonicalize() {

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we mandate that canonicalization happens on the vectors before this (in the doc comment for this method), then we can skip this clone & canonicalization, since the vectors will already be canonicalized

Vector::VectorTensor { qs, .. } => { rows.push(qs); }
_ => return Basis::BasisLiteral { vecs, dbg },
}
}

let n = rows[0].len();
if rows.iter().any(|r| r.len() != n) {
return Basis::BasisLiteral { vecs, dbg };
}

let mut letters: Vec<Vec<Vector>> = vec![Vec::new(); n];

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be helpful to leave a comment with an example describing what this variable is holding


for row in &rows {
for (i, atom) in row.iter().enumerate() {
if !letters[i].iter().any(|l| l.approx_equal(atom)) {
letters[i].push(atom.clone());
Comment on lines +2065 to +2066

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Dang, I am trying pretty hard, but I can't figure out what this is doing. Maybe we should discuss in a call

}
}
}

let expected: Vec<Vec<Vector>> = letters
.clone()
.into_iter()
Comment on lines +2072 to +2073

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it possible to delay the clone here, i.e., letters.iter().multi_cartesian_product().cloned().collect()? https://stackoverflow.com/q/35354716/321301

.multi_cartesian_product()

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should probably add some kind of short-circuit check somewhere to bypass this optimization if letters is too long, since this going to be quite large and expensive.

Do we really need to allocate something this big?

.collect();

let is_factorable = rows.len() == expected.len()
&& rows.iter().zip(&expected).all(|(r, e)| {
r.len() == e.len() && r.iter().zip(e).all(|(a, b)| a.approx_equal(b))

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I almost get it I think, but I just can't figure out what this does. We should talk on a call I think

});

if !is_factorable {
return Basis::BasisLiteral { vecs, dbg };
}

let bases = letters
.into_iter()
.map(|vecs| Basis::BasisLiteral { vecs, dbg: dbg.clone() })
.collect();
Basis::BasisTensor { bases, dbg }.canonicalize()
}

pub(crate) fn canonicalize_rewriter(self) -> Self {
match self {
Basis::BasisLiteral { vecs, dbg } => {
Expand Down
9 changes: 8 additions & 1 deletion qwerty_ast_to_mlir/src/compile.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ use std::{env, fs, path::PathBuf};
const QWERTY_DEBUG_DIR: &str = "qwerty-debug";
const MLIR_DUMP_SUBDIR: &str = "mlir";
const INIT_MLIR_FILENAME: &str = "initial.mlir";
const META_AST_FILENAME: &str = "meta_qwerty_ast.txt";
const QWERTY_AST_FILENAME: &str = "qwerty_ast.py";
const LLVM_IR_FILENAME: &str = "module.ll";

Expand Down Expand Up @@ -173,12 +174,18 @@ pub fn compile_meta_ast(
func_name: &str,
cfg: &CompileConfig,
) -> Result<Module<'static>, CompileError> {
let dump_dir = create_debug_dump_dir();
if cfg.dump {
let dump_path = dump_dir.join(META_AST_FILENAME);
eprintln!("Dumping MetaQwerty AST to file `{}`", dump_path.display());
fs::write(&dump_path, format!("{:#?}", prog)).unwrap();
}

let plain_ast = prog.lower(cfg.debug_lowering)?;
plain_ast.typecheck()?;
let canon_ast = plain_ast.canonicalize();

if cfg.dump {
let dump_dir = create_debug_dump_dir();
let dump_path = dump_dir.join(QWERTY_AST_FILENAME);
eprintln!("Dumping Qwerty AST to file `{}`", dump_path.display());

Expand Down
5 changes: 3 additions & 2 deletions qwerty_ast_to_mlir/src/lower/qpu.rs
Original file line number Diff line number Diff line change
Expand Up @@ -518,7 +518,8 @@ fn try_basis_as_primitive(basis_elems: &Vec<Basis>) -> Option<qwerty::PrimitiveB
/// of phases which correspond one-to-one with any vectors that have
/// hasPhase==true.
fn ast_basis_to_mlir(basis: &Basis) -> MlirBasis {
let basis_elements = basis.to_explicit().canonicalize().to_vec();
let basis_elements = basis.to_explicit().canonicalize().to_vec().into_iter()
.flat_map(|b| b.factor_separable().to_vec()).collect();

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there any reason we shouldn't do this factoring as part of canonicalization? Here's what we mean by "canonicalization", by the way: https://sunfishcode.github.io/blog/2018/10/22/Canonicalization.html

Do you see what I mean? If not let me know

The only problem I can think of is that it might confuse the type checker when doing span checking, but hopefully in that case, we just fix the span checker. (I can do that if it happens)


let prim_basis = try_basis_as_primitive(&basis_elements);
let (elems, phases) = if let Some(prim_basis) = prim_basis {
Expand Down Expand Up @@ -1155,7 +1156,7 @@ fn ast_qpu_expr_to_mlir(
explicit_indices,
pad_indices,
tgt_indices,
} = ast_basis_to_mlir(basis);
} = ast_basis_to_mlir(&basis.clone().strip_phases());
assert!(pad_indices.is_empty());
assert!(tgt_indices.is_empty());
assert_eq!(explicit_indices.len(), dim);
Expand Down