Open issue, opened by mehrdad2m 3 months ago
Thanks for catching this @mehrdad2m! How involved would you say the fix is -- is it straightforward, or would it require some exploration?
Hi @josh146, it is pretty straightforward. The fix should be made in `trace_quantum_measurements`, which is where the output pytree is modified. Put simply, the problem reduces to writing a function `replace_child_tree(tree, i, subtree)` that receives a pytree and replaces the i-th node of `tree` visited in a DFS with `subtree`. The only tricky part is working with pytrees :)
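To make the "i-th node visited in a DFS" semantics concrete, here is a minimal pure-Python sketch of `replace_child_tree`. It assumes that "node" means leaf, and it models pytrees as nested tuples, lists and dicts only: no custom pytree nodes, and, unlike JAX, `None` is treated as a leaf here. Dict entries are visited in sorted-key order to mirror how `jax.tree_util` flattens dicts. This is an illustration of the idea, not Catalyst code.

```python
from typing import Any


def replace_child_tree(tree: Any, i: int, subtree: Any) -> Any:
    """Return a copy of `tree` whose i-th leaf (DFS order) is replaced by `subtree`.

    Simplified model: tuples, lists and dicts are internal nodes, everything
    else (including None, unlike JAX) is a leaf. Dict keys are visited in
    sorted order, matching jax.tree_util's flattening order.
    """
    counter = [0]  # number of leaves visited so far

    def visit(node: Any) -> Any:
        if isinstance(node, tuple):
            return tuple(visit(child) for child in node)
        if isinstance(node, list):
            return [visit(child) for child in node]
        if isinstance(node, dict):
            return {key: visit(node[key]) for key in sorted(node)}
        # Leaf: swap it out if it is the i-th leaf reached by the DFS.
        index = counter[0]
        counter[0] += 1
        return subtree if index == i else node

    result = visit(tree)
    if not 0 <= i < counter[0]:
        raise IndexError(f"tree has {counter[0]} leaves, got index {i}")
    return result


# Leaf order here is 1, 2, 4 (key "a"), 3 (key "b"), 5, so index 2 is the "a" entry.
print(replace_child_tree(((1, 2), {"b": 3, "a": 4}, [5]), 2, ("keys", "counts")))
```

The inserted `subtree` is returned as-is rather than traversed, so its own leaves do not shift the count for the rest of the walk.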
Context
When using `qml.counts()` in the output of a quantum circuit with `qjit`, the output pytree is modified so that the element corresponding to `qml.counts` is replaced with `tree_structure(("keys", "counts"))`. However, this transformation is buggy: it works for simple cases but incorrectly transforms more complex patterns.
An example that works fine:
The result is as expected:
PyTreeDef({'1': (*, *)})
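For reference, plain `jax.tree_util` (independent of Catalyst) shows both the leaf order that an index such as `i` would have to follow and what the structure above means: `PyTreeDef({'1': (*, *)})` is the structure of a dict whose `'1'` entry is a pair of leaves. The sample trees are made up for illustration.

```python
from jax.tree_util import tree_flatten, tree_structure

# A nested output mixing a tuple, a dict and a list (illustrative values).
tree = ((1, 2), {"b": 3, "a": 4}, [5])

# tree_flatten returns leaves in depth-first, left-to-right order;
# dict entries are visited in sorted-key order ("a" before "b").
leaves, treedef = tree_flatten(tree)
print(leaves)  # [1, 2, 4, 3, 5]
print(treedef.num_leaves)  # 5

# A dict whose "1" entry is a 2-tuple: one outer slot expanded into two leaves.
expanded = tree_structure({"1": ("keys", "counts")})
print(expanded)
```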
The following examples show two patterns that result in the wrong output pytree:
1.
results in:
instead of the expected pytree of:
2.
results in:
while the expected pytree is:
A possible solution would update `trace_quantum_measurements()`, which is where the output pytree is modified. You could write a function `replace_child_tree(tree, i, subtree)` that receives a pytree and replaces the `i`th node of the tree visited in a DFS with `subtree`.
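Since the fix concerns output pytree structures, one way to sketch it is directly on `PyTreeDef` objects with `jax.tree_util`. The helper name `replace_child_treedef` is hypothetical (it is not part of Catalyst or JAX), it assumes "node" means leaf indexed in JAX's flattening order, and whether `trace_quantum_measurements()` can use it in exactly this form is an assumption, not something confirmed here.

```python
from jax.tree_util import tree_structure, tree_unflatten


def replace_child_treedef(treedef, i, sub_treedef):
    """Return a new PyTreeDef in which the i-th leaf of `treedef` is replaced
    by the structure `sub_treedef`.

    Hypothetical helper: leaves are indexed in jax.tree_util's flattening
    order, i.e. a depth-first traversal with dict keys sorted.
    """
    if not 0 <= i < treedef.num_leaves:
        raise IndexError(f"treedef has {treedef.num_leaves} leaves, got index {i}")
    # Materialize the sub-structure with dummy leaves so it becomes a value.
    sub_value = tree_unflatten(sub_treedef, [0] * sub_treedef.num_leaves)
    # Dummy leaves for the outer structure; slot i receives the whole sub-tree.
    leaves = [0] * treedef.num_leaves
    leaves[i] = sub_value
    # tree_unflatten places any object at a leaf position; re-deriving the
    # structure then flattens the inserted sub-tree into the outer one.
    return tree_structure(tree_unflatten(treedef, leaves))


counts_def = tree_structure(("keys", "counts"))

# Simple case: a dict holding a single counts result.
simple = replace_child_treedef(tree_structure({"1": 0}), 0, counts_def)
print(simple)

# Nested case: leaf 3 in DFS order is the value under key "b".
nested_def = tree_structure((0, {"b": 0, "a": [0, 0]}))
nested = replace_child_treedef(nested_def, 3, counts_def)
print(nested)
```

The simple case reproduces the `PyTreeDef({'1': (*, *)})` structure of the working example. Going through `tree_unflatten`/`tree_structure` avoids walking `PyTreeDef` internals by hand, at the cost of building throwaway dummy values.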