Context: Catalyst supports JAX dynamically-shaped arrays. The current version has a notable limitation: loop bodies cannot mix captured dynamically-shaped arrays with dynamically-shaped loop arguments, even when they have the same dimensions. This is because we effectively duplicated the dimension variables for the loop-body arguments, so a captured array and a loop argument end up with distinct dimension variables even though they describe the same size.
As an illustration, the loop body of the program below takes a as an argument array and x as a captured array. The experimental_preserve_dimensions flag has the default value of True; here it is also set explicitly.
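import jax.numpy as jnp
import catalyst
from catalyst import qjit

@qjit(abstracted_axes={1: 'n'})
def g(x, y):
    # `a` is the loop-carried argument; `x` is captured from the enclosing scope.
    @catalyst.for_loop(0, 10, 1, experimental_preserve_dimensions=True)
    def loop(_, a):
        return a * x

    return jnp.sum(loop(y))

a = jnp.ones([1, 3], dtype=float)
b = jnp.ones([1, 3], dtype=float)
g(a, b)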
Description of the Change:
This PR defines the following loop semantics, depending on the value of the existing experimental_preserve_dimensions flag:
- True (the default): all dynamic dimension variables are handled as Jaxpr constants. As a result:
  - Mixing argument and captured arrays is possible.
  - Dimensions cannot be modified inside loops.
- False: all dimensions are handled as Jaxpr implicit arguments. As a result:
  - Mixing argument and captured arrays is not possible.
  - Dimensions may be modified within loops.
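To make the trade-off between the two modes concrete, here is a toy model in plain Python. It is not Catalyst's tracer; it assumes a simplified view in which an abstract shape is a tuple of static sizes and dimension variables, and two shapes are compatible only if they use the very same variables, whatever sizes those variables take at run time. DimVar, mul, and trace_loop are hypothetical names introduced only for this sketch.

```python
from dataclasses import dataclass
from typing import Callable, Tuple, Union


# Hypothetical toy model, not Catalyst's implementation. eq=False makes each
# DimVar equal only to itself, mimicking distinct symbolic dimension variables.
@dataclass(frozen=True, eq=False)
class DimVar:
    name: str


Shape = Tuple[Union[int, DimVar], ...]


def mul(lhs: Shape, rhs: Shape) -> Shape:
    """Elementwise product: shapes must match entry by entry (same variables)."""
    if lhs != rhs:
        raise TypeError(f"incompatible shapes {lhs} and {rhs}")
    return lhs


def trace_loop(body: Callable[[Shape], Shape], init: Shape,
               preserve_dimensions: bool) -> Shape:
    """Abstractly trace one loop body and return the shape it produces."""
    if preserve_dimensions:
        # Dimension variables act as constants: the body sees the outer ones.
        arg = init
    else:
        # Each dynamic dimension becomes a fresh implicit argument of the body.
        arg = tuple(DimVar(d.name + "'") if isinstance(d, DimVar) else d
                    for d in init)
    out = body(arg)
    if preserve_dimensions and out != init:
        raise TypeError("dimension modification is not allowed inside the loop")
    return out
```

With preserve_dimensions=True, the loop argument reuses the outer variable n, so a body computing a * x against a captured (1, n) array type-checks, while a body that returns a new dimension is rejected. With preserve_dimensions=False, the argument receives a fresh copy of n, which reverses both outcomes.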
[sc-60521]
Benefits: Since this fixes a bug in dynamically-shaped array support in v0.7.0, it has been cherry-picked from 'main'.
…pes in Python programs (#830)
Possible Drawbacks:
Related GitHub Issues: