PennyLaneAI / catalyst

A JIT compiler for hybrid quantum programs in PennyLane
https://docs.pennylane.ai/projects/catalyst
Apache License 2.0
122 stars, 27 forks

Feature-pass-static-argnums-to-qnode #932

Closed: mehrdad2m closed this 1 month ago

mehrdad2m commented 1 month ago

Context: `static_argnums` is not correctly passed through to the QNode. Currently, if the `@qjit(static_argnums=(1,))` decorator is placed directly above `@qml.qnode(dev)`, the static argument is still traced inside the QNode.

Description of the Change: Passed `static_argnums` from the compile options through the QNode `__call__()` method, where `deduce_avals` now uses it to avoid tracing the static arguments. Also improved the verification and preparation of `static_argnums`.

Benefits: Static arguments can now be used inside QNodes.
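As a rough, framework-free illustration of the idea (not Catalyst's actual code), the sketch below validates and normalizes `static_argnums`, then splits the call arguments so that only the dynamic ones would be handed to something like `deduce_avals` for tracing. Both helper names, `prepare_static_argnums` and `split_static_args`, are hypothetical.

```python
from typing import Any, Dict, Iterable, List, Sequence, Tuple, Union


def prepare_static_argnums(
    static_argnums: Union[int, Iterable[int], None], num_args: int
) -> Tuple[int, ...]:
    """Hypothetical helper: validate static_argnums and normalize it to a
    sorted tuple of non-negative indices."""
    if static_argnums is None:
        return ()
    if isinstance(static_argnums, int):
        static_argnums = (static_argnums,)
    normalized = set()
    for idx in static_argnums:
        # bool is a subclass of int, so reject it explicitly
        if not isinstance(idx, int) or isinstance(idx, bool):
            raise TypeError(f"static_argnums must contain integers, got {idx!r}")
        if not -num_args <= idx < num_args:
            raise ValueError(
                f"static_argnums index {idx} is out of range for {num_args} arguments"
            )
        normalized.add(idx % num_args)  # map negative indices to positive ones
    return tuple(sorted(normalized))


def split_static_args(
    args: Sequence[Any], static_argnums: Tuple[int, ...]
) -> Tuple[Dict[int, Any], List[Any]]:
    """Hypothetical helper: keep static arguments as concrete Python values and
    return the remaining (dynamic) arguments, the only ones that get traced."""
    static = {i: args[i] for i in static_argnums}
    dynamic = [arg for i, arg in enumerate(args) if i not in static]
    return static, dynamic


# Mirrors @qjit(static_argnums=(1,)) on a function f(x, c) called as f(0.5, 0.5)
argnums = prepare_static_argnums(1, num_args=2)
static, dynamic = split_static_args((0.5, 0.5), argnums)
print(argnums, static, dynamic)  # (1,) {1: 0.5} [0.5]
```

Normalizing early (ints to tuples, negative to positive indices) means the later tracing step only has to check membership in a small sorted tuple.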

Possible Drawbacks: This may confuse JAX users, because in JAX itself `static_argnums` does not propagate through nested calls to `jax.jit()`. For example:

```python
from functools import partial

import jax

@partial(jax.jit, static_argnums=(1,))
@jax.jit
def foo(x, c):
    print("Inside QNode:", c)
    return x + c
```

```pycon
>>> foo(0.5, 0.5)
Inside QNode: Traced<ShapedArray(float64[], weak_type=True)>with<DynamicJaxprTrace(level=2/0)>
```

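This means the parameter `c` is still traced inside `foo`: the inner `jax.jit` declares no `static_argnums` of its own, so it traces every argument it receives.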
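For comparison, in plain JAX the inner function keeps `c` static only if the inner `jax.jit` also declares it static. The sketch below assumes a recent JAX release; the `seen_static` / `seen_traced` lists are hypothetical probes added only to record the type of `c` at trace time for both variants.

```python
from functools import partial

import jax

seen_static = []
seen_traced = []


# Both jit layers mark argument 1 as static, so `c` stays a Python float.
@partial(jax.jit, static_argnums=(1,))
@partial(jax.jit, static_argnums=(1,))
def foo_static(x, c):
    seen_static.append(type(c).__name__)  # runs at trace time, not at every call
    return x + c


# Only the outer layer marks `c` static; the inner jit traces it again.
@partial(jax.jit, static_argnums=(1,))
@jax.jit
def foo_traced(x, c):
    seen_traced.append(type(c).__name__)
    return x + c


foo_static(0.5, 0.5)
foo_traced(0.5, 0.5)
print(seen_static, seen_traced)
```

In the first variant the recorded type is `float`; in the second it is a JAX tracer type, matching the output shown above.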

Related GitHub Issues: https://github.com/PennyLaneAI/catalyst/issues/902

[sc-67808]

codecov[bot] commented 1 month ago

Codecov Report

All modified and coverable lines are covered by tests :white_check_mark:

Project coverage is 97.92%. Comparing base (d3b8cd4) to head (2a8dc11). Report is 1 commit behind head on main.

Additional details and impacted files

```diff
@@           Coverage Diff           @@
##             main     #932   +/-   ##
=======================================
  Coverage   97.92%   97.92%
=======================================
  Files          71       71
  Lines       10255    10281   +26
  Branches     1163     1169    +6
=======================================
+ Hits        10042    10068   +26
  Misses        169      169
  Partials       44       44
```


erick-xanadu commented 1 month ago

I came up with another test that produces an error. I am not sure whether we should consider it out of scope; it may be good to bring @josh146 and @dime10 into the loop:

```diff
+    def test_qnode_nested_not_qnode(self, capsys):
+        """Test if QJIT static arguments pass through QNode correctly when parameters are switched."""
+        dev = qml.device("lightning.qubit", wires=1)
+
+        @qjit(static_argnums=(0,))
+        def circuit(c, x):
+            return 2 * c
+
+        @qjit(static_argnums=(1,))
+        def wrapper(x, c):
+            return circuit(c, x)
+
+        wrapper(0.5, 0.5)
+        print(wrapper.mlir)
+
```

I tried adding `static_argnums` to a non-QNode function and got the following error:

```
E           TypeError: TestStaticArguments.test_qnode_nested_not_qnode.<locals>.circuit() got an unexpected keyword argument 'static_argnums'
```

Getting this to work might be out of scope, but I think we will want it at some point. Other than that, I think it looks good!
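The error message suggests that the `static_argnums` keyword ends up being forwarded to `circuit`, a plain Python function that does not accept it. One possible guard, sketched here with a hypothetical helper `call_with_optional_static_argnums` (an assumption for illustration, not the fix that was committed), is to forward the keyword only when the callee's signature declares it:

```python
import inspect
from typing import Any, Callable, Sequence, Tuple


def call_with_optional_static_argnums(
    fn: Callable[..., Any], args: Sequence[Any], static_argnums: Tuple[int, ...]
) -> Any:
    """Hypothetical helper: pass static_argnums only to callables that accept it."""
    params = inspect.signature(fn).parameters.values()
    accepts_kwarg = any(
        p.name == "static_argnums" or p.kind is inspect.Parameter.VAR_KEYWORD
        for p in params
    )
    if accepts_kwarg:
        return fn(*args, static_argnums=static_argnums)
    return fn(*args)  # plain Python function: drop the extra keyword


def circuit(c, x):
    # Plain function, like `circuit` in the test above
    return 2 * c


def qnode_like(x, c, static_argnums=()):
    # Hypothetical stand-in for a callable whose call accepts static_argnums
    return static_argnums


print(call_with_optional_static_argnums(circuit, (0.5, 0.5), (0,)))     # 1.0
print(call_with_optional_static_argnums(qnode_like, (0.5, 0.5), (1,)))  # (1,)
```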

mehrdad2m commented 1 month ago

I think there should be a simple fix for this. Let me check.

erick-xanadu commented 1 month ago

@dime10: @josh146, @mehrdad2m, and I had a discussion offline. @josh146 says that if it is too difficult, he is happy to leave it out of scope, although @mehrdad2m may already have a fix :)

mehrdad2m commented 1 month ago

My last commit should fix this issue.

dime10 commented 1 month ago

Wow, nice work! I would have thought that, since we completely ignore inner qjit decorators, the inner `static_argnums` would be ignored as well, but it seems you managed to make it work :)

There is a larger issue here about compile options provided to nested qjit functions; I'm guessing most of them are generally ignored. I think the eventual fix should be to reuse nested compiled functions rather than tracing through them again.
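A toy, pure-Python sketch of what reusing nested compiled functions could look like: a hypothetical `CompileCache` wrapper "compiles" (here, just closes over the static values) once per combination of static values and dynamic argument types, and reuses the result on later calls instead of retracing. This is not Catalyst's design; the `compile_count` attribute exists only to make the reuse visible.

```python
from typing import Any, Callable, Dict, Hashable, Sequence, Tuple


class CompileCache:
    """Hypothetical sketch: build a specialized callable once per
    (static values, dynamic argument types) key and reuse it afterwards.
    Assumes static_argnums holds non-negative indices."""

    def __init__(self, fn: Callable[..., Any], static_argnums: Sequence[int] = ()):
        self.fn = fn
        self.static_argnums = tuple(static_argnums)
        self.cache: Dict[Hashable, Callable[..., Any]] = {}
        self.compile_count = 0  # only here to make reuse observable

    def _key(self, args: Sequence[Any]) -> Tuple[Tuple[Any, ...], Tuple[str, ...]]:
        static_values = tuple(args[i] for i in self.static_argnums)
        dynamic_types = tuple(
            type(a).__name__ for i, a in enumerate(args) if i not in self.static_argnums
        )
        return static_values, dynamic_types

    def _compile(self, static_values: Tuple[Any, ...]) -> Callable[..., Any]:
        # Stand-in for tracing + compilation: bake the static values into a closure.
        self.compile_count += 1
        static_map = dict(zip(self.static_argnums, static_values))

        def compiled(*dynamic: Any) -> Any:
            remaining = iter(dynamic)
            total = len(static_map) + len(dynamic)
            full_args = [
                static_map[i] if i in static_map else next(remaining) for i in range(total)
            ]
            return self.fn(*full_args)

        return compiled

    def __call__(self, *args: Any) -> Any:
        key = self._key(args)
        if key not in self.cache:
            self.cache[key] = self._compile(key[0])
        dynamic = [a for i, a in enumerate(args) if i not in self.static_argnums]
        return self.cache[key](*dynamic)


add = CompileCache(lambda x, c: x + c, static_argnums=(1,))
print(add(0.5, 1.0), add(2.0, 1.0), add.compile_count)  # 1.5 3.0 1
print(add(0.5, 2.0), add.compile_count)                 # 2.5 2
```

Keying on the static values means a nested function called twice with the same static argument is built once, while a new static value triggers a new specialization, which is the same trade-off `static_argnums` makes at the top level.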

mehrdad2m commented 1 month ago

Just correcting the test: I think both work, but this is the last change. Thanks!

Ah, sorry, I forgot to remove them at the end. Thanks for catching this.