PennyLaneAI / catalyst

A JIT compiler for hybrid quantum programs in PennyLane
https://docs.pennylane.ai/projects/catalyst
Apache License 2.0

Autograph fails when iterating over non-Jax Pytree types #896

Open tzunghanjuang opened 4 days ago

tzunghanjuang commented 4 days ago

Issue description

Using the @qjit(autograph=True) decorator to loop over pytree types that are not jax.numpy arrays (such as Python built-in lists and tuples, or NumPy arrays) fails, and AutoGraph does not convert the loop. Converting the non-JAX values into JAX arrays manually works.
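For comparison, here is a minimal plain-JAX sketch of the manual workaround. It assumes that converting the list with `jnp.asarray` before the loop is enough. `jax.lax.fori_loop` stands in for the loop that AutoGraph would generate. That is an illustrative assumption, not Catalyst's actual lowering, and `update_array` is a hypothetical array-based counterpart of `updateList`.

```python
import jax
import jax.numpy as jnp
from jax import lax

def update_array(x):
    # Same update as updateList, but applied to a JAX array instead of a Python list
    return x + jnp.array([1, 2])

@jax.jit
def fn(x):
    # Run four iterations, carrying the array x from one iteration to the next
    return lax.fori_loop(0, 4, lambda i, x: update_array(x), x)

result = fn(jnp.asarray([1, 2]))  # manual conversion from a list to a JAX array
print(result)  # [ 5 10]
```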

Source code and tracebacks

Source code:

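from catalyst import qjit

def updateList(x):
    # Returns a new Python list, so the loop-carried value stays a list
    return [x[0]+1, x[1]+2]

@qjit(autograph=True)
def fn(x):
    for i in range(4):  # AutoGraph attempts to convert this loop
        x = updateList(x)
    return x

fn([1, 2])  # x starts out as a Python list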

Trace:

[~/catalyst/frontend/catalyst/autograph/ag_primitives.py:347]: UserWarning: Tracing of an AutoGraph converted for loop failed with an exception:
  AutoGraphError:    The variable 'x' was initialized with type <class 'list'>, which is not compatible with JAX. Typically, this is the case for non-numeric values.
    You may still use such a variable as a constant inside a loop, but it cannot be updated from one iteration to the next, or accessed outside the loop scope if it was defined inside of it.

The error ocurred within the body of the following for loop statement:
  File "/tmp/ipykernel_165289/1920264522.py", line 8, in fn
    for i in range(4):

If you intended for the conversion to happen, make sure that the (now dynamic) loop variable is not used in tracing-incompatible ways, for instance by indexing a Python list with it. In that case, the list should be wrapped into an array.
To understand different types of JAX tracing errors, please refer to the guide at: https://jax.readthedocs.io/en/latest/errors.html

If you did not intend for the conversion to happen, you may safely ignore this warning.
  warnings.warn(
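The warning's last suggestion is to wrap a Python list into an array before indexing it with the loop variable. The minimal plain-JAX sketch below illustrates it. The names `weights` and `total` are hypothetical. Inside `jax.lax.fori_loop` the loop index is a traced value, so it can index a `jnp` array but not the original Python list.

```python
import jax
import jax.numpy as jnp
from jax import lax

weights = [0.5, 1.5, 2.0]  # hypothetical data held in a plain Python list

@jax.jit
def total(n):
    table = jnp.asarray(weights)  # wrap the list into an array first
    def body(i, acc):
        # i is traced here: table[i] works, whereas weights[i] would fail
        return acc + table[i]
    return lax.fori_loop(0, n, body, jnp.float32(0.0))

print(total(3))  # 4.0
```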

The error is triggered here: https://github.com/PennyLaneAI/catalyst/blob/140cbd31262b9c4fd70c760c1ac4efecc487cf3f/frontend/catalyst/autograph/ag_primitives.py#L159-L169

dime10 commented 4 days ago

Thanks for reporting this! I think the missing piece here is letting pytree types through the AutoGraph process; arbitrary types would be out of scope.
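One possible shape for that change is sketched below in plain JAX, and it is not Catalyst's implementation. Instead of rejecting any value that is not itself a JAX type, the check looks at the leaves of the pytree and lets JAX's loop primitives carry the container. The helper `is_jax_compatible_pytree` is hypothetical. The sketch only shows that `jax.lax.fori_loop` already accepts a Python list as its loop carry when every leaf is numeric.

```python
import jax
import numpy as np
from jax import lax

def _is_numeric_leaf(leaf):
    # JAX arrays (including tracers) are always acceptable loop carries
    if isinstance(leaf, jax.Array):
        return True
    # Any other leaf must convert to a numeric or boolean NumPy dtype
    dtype = np.asarray(leaf).dtype
    return np.issubdtype(dtype, np.number) or np.issubdtype(dtype, np.bool_)

def is_jax_compatible_pytree(value):
    """Hypothetical check: accept any pytree whose leaves are all numeric."""
    return all(_is_numeric_leaf(leaf) for leaf in jax.tree_util.tree_leaves(value))

def update_list(x):
    return [x[0] + 1, x[1] + 2]

@jax.jit
def fn(x):
    # The Python list is carried through the loop as a pytree of two scalars
    return lax.fori_loop(0, 4, lambda i, x: update_list(x), x)

print(is_jax_compatible_pytree([1, 2]))      # True
print(is_jax_compatible_pytree(["a", "b"]))  # False
print([int(v) for v in fn([1, 2])])          # [5, 10]
```

A string leaf converts to a Unicode dtype rather than a numeric one, so such containers would still be rejected, which keeps arbitrary types out of scope.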