[BUG] Incorrect output pytree when using qml.counts() in specific output patterns #1016

Open
mehrdad2m opened this issue Aug 13, 2024 · 2 comments
Labels
bug Something isn't working


@mehrdad2m
Contributor

Context

When qml.counts() is used in the output of a quantum circuit compiled with qjit, the output pytree is modified so that the element corresponding to qml.counts() is replaced with tree_structure(("keys", "counts")). However, this transformation is buggy: it works in simple cases but produces the wrong pytree for more complex output patterns.

An example that works correctly:

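import pennylane as qml
from catalyst import qjit
from jax.tree_util import tree_flatten

dev = qml.device("lightning.qubit", wires=1, shots=20)
@qjit
@qml.qnode(dev)
def circuit(x):
    qml.RX(x, wires=0)
    return {"1": qml.counts()}  # counts is returned as a (keys, counts) pair

result = circuit(0.5)
_, result_tree = tree_flatten(result)
print(result_tree)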

The result is as expected:

PyTreeDef({'1': (*, *)})

The following two patterns, however, produce an incorrect output pytree:

dev = qml.device("lightning.qubit", wires=1, shots=20)
@qjit
@qml.qnode(dev)
def circuit(x):
    qml.RX(x, wires=0)
    return {"1": qml.counts()}, {"2": qml.expval(qml.Z(0))}

result = circuit(0.5)
_, result_tree = tree_flatten(result)
print(result_tree)

This results in:

PyTreeDef(((*, *), {'2': *}))

The expected pytree is:

PyTreeDef(({'1': (*, *)}, {'2': *}))

dev = qml.device("lightning.qubit", wires=1, shots=20)
@qjit
@qml.qnode(dev)
def circuit(x):
    qml.RX(x, wires=0)
    return [{"1": qml.expval(qml.Z(0))}, {"2": qml.counts()}], {"3": qml.expval(qml.Z(0))}

result = circuit(0.5)
_, result_tree = tree_flatten(result)
print(result_tree)

This results in:

PyTreeDef(([{'1': *}, {'2': *}], (*, *)))

The expected pytree is:

PyTreeDef(([{'1': *}, {'2': (*, *)}], {'3': *}))
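Both incorrect results are consistent with the counts subtree being written into the i-th top-level child of the output pytree, where i is the index of the counts measurement among the flattened leaves. The sketch below is a hypothetical helper written for illustration, not Catalyst's code; it assumes the return value is a top-level tuple and uses only jax.tree_util (tree_structure, tree_unflatten, PyTreeDef.num_leaves) to reproduce both wrong pytrees:

```python
from jax.tree_util import tree_structure, tree_unflatten

# Structure assumed for a qml.counts() result after qjit: a (keys, counts) pair.
COUNTS_TREE = tree_structure(("keys", "counts"))


def top_level_replace(tree, i, subtree):
    """Hypothetical reproduction of the faulty pattern (not Catalyst's code).

    Writes `subtree` into the i-th top-level child of `tree`, although i is
    the leaf index of the counts measurement. Assumes a tuple at the top.
    """
    # Materialize both treedefs with integer placeholder leaves
    # (None would not work: JAX treats None as an empty node, not a leaf).
    indexed = tree_unflatten(tree, list(range(tree.num_leaves)))
    replacement = tree_unflatten(subtree, [0] * subtree.num_leaves)
    children = list(indexed)
    children[i] = replacement
    return tree_structure(tuple(children))


# Second example: counts is leaf 0, so top-level child 0 ({'1': *}) is overwritten.
print(top_level_replace(tree_structure(({"1": 0}, {"2": 0})), 0, COUNTS_TREE))
# Third example: counts is leaf 1, so top-level child 1 ({'3': *}) is overwritten.
print(top_level_replace(tree_structure(([{"1": 0}, {"2": 0}], {"3": 0})), 1, COUNTS_TREE))
```

In the simple single-dict case the leaf index and the top-level child index happen to coincide, which would explain why that pattern looks correct.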

mehrdad2m added the bug label on Aug 14, 2024
@josh146
Member

josh146 commented Aug 18, 2024

Thanks for catching this @mehrdad2m! How involved would you say the fix is -- is it straightforward, or would it require some exploration?

@mehrdad2m
Contributor Author


Hi @josh146, it is pretty straightforward. The fix belongs in trace_quantum_measurements, which is where the output pytree is modified. In its simplest form, the problem is to write a function replace_child_tree(tree, i, subtree) that receives a pytree and replaces the i-th node of tree visited in a depth-first search (DFS) with subtree. The only tricky part is working with pytrees :)
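A minimal sketch of the proposed replace_child_tree, assuming that "i-th node" means the i-th leaf in JAX's flatten order (depth-first, dict keys sorted), which is the order in which measurements appear in the flattened output, and that tree and subtree are jax.tree_util PyTreeDef objects. This is a hypothetical helper, not Catalyst's implementation:

```python
from jax.tree_util import tree_map, tree_structure, tree_unflatten

# Structure assumed for a qml.counts() result after qjit: a (keys, counts) pair.
COUNTS_TREE = tree_structure(("keys", "counts"))


def replace_child_tree(tree, i, subtree):
    """Hypothetical helper: return a copy of the PyTreeDef `tree` whose i-th
    leaf (in flatten order) is replaced by the PyTreeDef `subtree`."""
    # Materialize `tree` with each leaf set to its own flatten-order index.
    indexed = tree_unflatten(tree, list(range(tree.num_leaves)))
    # Materialize `subtree` with integer placeholders (None is not a leaf in JAX).
    replacement = tree_unflatten(subtree, [0] * subtree.num_leaves)
    # Swap leaf i for the replacement; every other leaf is left unchanged.
    spliced = tree_map(lambda leaf: replacement if leaf == i else leaf, indexed)
    # Only the structure of the spliced object is needed.
    return tree_structure(spliced)


# Expected pytrees from the two failing examples in this issue:
print(replace_child_tree(tree_structure(({"1": 0}, {"2": 0})), 0, COUNTS_TREE))
print(replace_child_tree(tree_structure(([{"1": 0}, {"2": 0}], {"3": 0})), 1, COUNTS_TREE))
```

With several qml.counts() measurements, applying the helper in descending leaf order keeps the remaining indices valid, because each replacement only shifts the indices of leaves after position i.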
