diff --git a/examples/separable_basis.py b/examples/separable_basis.py new file mode 100644 index 00000000..805cce2a --- /dev/null +++ b/examples/separable_basis.py @@ -0,0 +1,42 @@ +#!/usr/bin/env python3 +""" +Z (x) X separable basis, with per-vector phases at various angles. + +{'0p','0m','1p','1m'} is separable: qubit 0 in Z, qubit 1 in X. + +""" + +from qwerty import * + + +@qpu +def no_phase() -> bit[2]: + return '00' | {'0p', '0m', '1p', '1m'}.measure + +@qpu +def all_phase() -> bit[2]: + return '00' | {'0p'@90, '0m'@90, '1p'@90, '1m'@90}.measure + + +@qpu +def with_phase_180() -> bit[2]: + return '00' | {'0p', '0m', '1p', '1m'@180}.measure + + +@qpu +def with_phase_45() -> bit[2]: + return '00' | {'0p', '0m', '1p', '1m'@45}.measure + + +@qpu +def with_phase_60() -> bit[2]: + return '00' | {'0p', '0m', '1p', '1m'@60}.measure + + +if __name__ == '__main__': + # All four must agree (angle-independent under measurement). + print("no phase: ", histogram(no_phase(shots=512))) + print("all phase: ", histogram(all_phase(shots=512))) + print("with phase @180:", histogram(with_phase_180(shots=512))) + print("with phase @45: ", histogram(with_phase_45(shots=512))) + print("with phase @60: ", histogram(with_phase_60(shots=512))) diff --git a/qwerty_ast/src/ast/qpu.rs b/qwerty_ast/src/ast/qpu.rs index 2e4efbe7..95e6e518 100644 --- a/qwerty_ast/src/ast/qpu.rs +++ b/qwerty_ast/src/ast/qpu.rs @@ -2023,6 +2023,73 @@ impl Basis { rebuild!(Basis, self, canonicalize) } + pub fn strip_phases(self) -> Self { + match self { + Basis::BasisLiteral { vecs, dbg } => Basis::BasisLiteral { + vecs: vecs.into_iter().map(|v| v.canonicalize().normalize()).collect(), + dbg, + }, + Basis::BasisTensor { bases, dbg } => Basis::BasisTensor { + bases: bases.into_iter().map(Basis::strip_phases).collect(), + dbg, + }, + other => other, + } + } + + pub fn factor_separable(self) -> Self { + let Basis::BasisLiteral { vecs, dbg } = self else { + return self; + }; + if vecs.len() < 2 { + return Basis::BasisLiteral { vecs, dbg }; + } + + let mut rows: Vec> = Vec::with_capacity(vecs.len()); + for v in &vecs { + match v.clone().canonicalize() { + Vector::VectorTensor { qs, .. } => { rows.push(qs); } + _ => return Basis::BasisLiteral { vecs, dbg }, + } + } + + let n = rows[0].len(); + if rows.iter().any(|r| r.len() != n) { + return Basis::BasisLiteral { vecs, dbg }; + } + + let mut letters: Vec> = vec![Vec::new(); n]; + + for row in &rows { + for (i, atom) in row.iter().enumerate() { + if !letters[i].iter().any(|l| l.approx_equal(atom)) { + letters[i].push(atom.clone()); + } + } + } + + let expected: Vec> = letters + .clone() + .into_iter() + .multi_cartesian_product() + .collect(); + + let is_factorable = rows.len() == expected.len() + && rows.iter().zip(&expected).all(|(r, e)| { + r.len() == e.len() && r.iter().zip(e).all(|(a, b)| a.approx_equal(b)) + }); + + if !is_factorable { + return Basis::BasisLiteral { vecs, dbg }; + } + + let bases = letters + .into_iter() + .map(|vecs| Basis::BasisLiteral { vecs, dbg: dbg.clone() }) + .collect(); + Basis::BasisTensor { bases, dbg }.canonicalize() + } + pub(crate) fn canonicalize_rewriter(self) -> Self { match self { Basis::BasisLiteral { vecs, dbg } => { diff --git a/qwerty_ast_to_mlir/src/compile.rs b/qwerty_ast_to_mlir/src/compile.rs index 09987869..8d88ab62 100644 --- a/qwerty_ast_to_mlir/src/compile.rs +++ b/qwerty_ast_to_mlir/src/compile.rs @@ -19,6 +19,7 @@ use std::{env, fs, path::PathBuf}; const QWERTY_DEBUG_DIR: &str = "qwerty-debug"; const MLIR_DUMP_SUBDIR: &str = "mlir"; const INIT_MLIR_FILENAME: &str = "initial.mlir"; +const META_AST_FILENAME: &str = "meta_qwerty_ast.txt"; const QWERTY_AST_FILENAME: &str = "qwerty_ast.py"; const LLVM_IR_FILENAME: &str = "module.ll"; @@ -173,12 +174,18 @@ pub fn compile_meta_ast( func_name: &str, cfg: &CompileConfig, ) -> Result, CompileError> { + let dump_dir = create_debug_dump_dir(); + if cfg.dump { + let dump_path = dump_dir.join(META_AST_FILENAME); + eprintln!("Dumping MetaQwerty AST to file `{}`", dump_path.display()); + fs::write(&dump_path, format!("{:#?}", prog)).unwrap(); + } + let plain_ast = prog.lower(cfg.debug_lowering)?; plain_ast.typecheck()?; let canon_ast = plain_ast.canonicalize(); if cfg.dump { - let dump_dir = create_debug_dump_dir(); let dump_path = dump_dir.join(QWERTY_AST_FILENAME); eprintln!("Dumping Qwerty AST to file `{}`", dump_path.display()); diff --git a/qwerty_ast_to_mlir/src/lower/qpu.rs b/qwerty_ast_to_mlir/src/lower/qpu.rs index fc37465b..ef9e9e03 100644 --- a/qwerty_ast_to_mlir/src/lower/qpu.rs +++ b/qwerty_ast_to_mlir/src/lower/qpu.rs @@ -518,7 +518,8 @@ fn try_basis_as_primitive(basis_elems: &Vec) -> Option MlirBasis { - let basis_elements = basis.to_explicit().canonicalize().to_vec(); + let basis_elements = basis.to_explicit().canonicalize().to_vec().into_iter() + .flat_map(|b| b.factor_separable().to_vec()).collect(); let prim_basis = try_basis_as_primitive(&basis_elements); let (elems, phases) = if let Some(prim_basis) = prim_basis { @@ -1155,7 +1156,7 @@ fn ast_qpu_expr_to_mlir( explicit_indices, pad_indices, tgt_indices, - } = ast_basis_to_mlir(basis); + } = ast_basis_to_mlir(&basis.clone().strip_phases()); assert!(pad_indices.is_empty()); assert!(tgt_indices.is_empty()); assert_eq!(explicit_indices.len(), dim);