PennyLaneAI / catalyst

A JIT compiler for hybrid quantum programs in PennyLane
https://docs.pennylane.ai/projects/catalyst
Apache License 2.0

Indexed assignment doesn't work with dynamically-shaped arrays #906

Open · dime10 opened this issue 1 month ago

dime10 commented 1 month ago

The following program raises an `AssertionError` inside JAX's shape handling (`non_negative_dim` in `jax/_src/core.py`):

```python
import jax.numpy as jnp
from catalyst import *

@qjit
def f(n: int, m: int):
    x = jnp.ones((n, m), dtype=float)
    y = jnp.ones((n, m), dtype=float)

    @for_loop(0, n, 1, experimental_preserve_dimensions=True)
    def sum_and_multiply(i, x, y):
        x[i] = x[i] + y[i]
        y[i] = x[i] * y[i]
        return x, y

    return sum_and_multiply(x, y)

f(2, 3)
```
```
File /Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/jax/_src/core.py:2072, in non_negative_dim(d)
   2070 if is_constant_dim(d):
   2071   return max(0, d)
-> 2072 assert is_symbolic_dim(d)
   2073 try:
   2074   d_ge_0 = (d >= 0)

AssertionError:
```

The error does not occur when the indexed assignments (`x[i] = ...` and `y[i] = ...`) are removed.

josh146 commented 1 month ago

Don't you need to use `x.at[i].set(x[i] + y[i])` here?
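A minimal sketch of the point raised above, in plain JAX (outside `qjit`, with static shapes, so it says nothing about Catalyst's own handling): JAX arrays are immutable, so `x[i] = ...` raises a `TypeError`, whereas `x.at[i].set(...)` returns an updated copy and leaves the original array untouched.

```python
import jax.numpy as jnp

x = jnp.ones((2, 3))
y = jnp.ones((2, 3))

# In-place item assignment is rejected because JAX arrays are immutable.
try:
    x[0] = x[0] + y[0]
    assignment_failed = False
except TypeError:
    assignment_failed = True

# The functional form builds a new array with row 0 replaced; x is unchanged.
x_new = x.at[0].set(x[0] + y[0])
```

Because the update is functional, the result must be rebound (for example `x = x.at[i].set(...)`) for the change to be visible to later code.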

dime10 commented 1 month ago

Good point 😅, but apparently the error happens before that issue kicks in.

josh146 commented 1 month ago

Ah, so it happens at the 'get value' stage (reading `x[i]`), got it!
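For contrast, a sketch of what the loop body is meant to compute, written in plain JAX with static shapes using `jax.lax.fori_loop` and functional `.at[i].set(...)` updates. This is an assumption-laden reference, not a Catalyst fix: `sum_and_multiply_static` is a hypothetical name, and the example does not exercise `qjit`, `for_loop`, or `experimental_preserve_dimensions`. With `n` and `m` as Python ints, both the read `x[i]` and the write succeed.

```python
import jax
import jax.numpy as jnp

def sum_and_multiply_static(n: int, m: int):
    # Static shapes: n and m are plain Python ints here, not traced values.
    x = jnp.ones((n, m), dtype=float)
    y = jnp.ones((n, m), dtype=float)

    def body(i, carry):
        x, y = carry
        # Read row i (the "get value" step), then write it back functionally.
        x = x.at[i].set(x[i] + y[i])
        y = y.at[i].set(x[i] * y[i])
        return x, y

    return jax.lax.fori_loop(0, n, body, (x, y))

x, y = sum_and_multiply_static(2, 3)
```

Starting from arrays of ones, each row of `x` becomes `1 + 1 = 2` and each row of `y` becomes `2 * 1 = 2`, so both results are all twos.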