PennyLaneAI / catalyst

A JIT compiler for hybrid quantum programs in PennyLane
https://docs.pennylane.ai/projects/catalyst
Apache License 2.0

[Frontend] Support capturing dynamically-shaped arrays from outer scopes in Python programs (#830) #890

Closed: rauletorresc closed this 3 months ago

rauletorresc commented 3 months ago

Context: Catalyst supports JAX dynamically-shaped arrays. The current version has a notable limitation: loop bodies cannot mix captured dynamically-shaped arrays with dynamically-shaped loop arguments, even when both have the same dimension. This is because we effectively duplicated the dimension variables for the loop-body arguments, so an argument's dimension was no longer recognized as equal to the matching dimension of a captured array.
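A toy model of this limitation may help, assuming symbolic dimensions behave like variables that are only known to be equal when they are literally the same variable. `Dim`, `fresh_dim`, `mul_shape`, and `trace_loop_body` are hypothetical names; this is plain Python, not Catalyst's actual tracing machinery.

```python
from dataclasses import dataclass
from itertools import count

_next_id = count()

@dataclass(frozen=True)
class Dim:
    """A symbolic dimension variable, such as the abstracted axis 'n' (toy model)."""
    name: str
    uid: int

def fresh_dim(name):
    # Each call creates a new variable, even if the name is reused.
    return Dim(name, next(_next_id))

def mul_shape(lhs, rhs):
    # Elementwise ops need equal shapes; two symbolic dims count as equal
    # only if they are the very same variable (same name and same uid).
    if lhs != rhs:
        raise TypeError(f"cannot prove {lhs} == {rhs}")
    return lhs

def trace_loop_body(arg_shape, captured_shape, duplicate_dims):
    if duplicate_dims:
        # Old behavior (hypothetical model): the loop-body argument gets its
        # own fresh copies of the dimension variables.
        arg_shape = tuple(fresh_dim(d.name) if isinstance(d, Dim) else d for d in arg_shape)
    # The body computes `a * x`: loop argument times captured array.
    return mul_shape(arg_shape, captured_shape)

n = fresh_dim("n")
x_shape = y_shape = (1, n)  # both arrays have shape (1, n)

print(trace_loop_body(y_shape, x_shape, duplicate_dims=False))
try:
    trace_loop_body(y_shape, x_shape, duplicate_dims=True)
except TypeError as err:
    print("rejected:", err)
```

Sharing the outer dimension variable lets the product type-check; duplicating it makes two arrays of identical runtime size look incompatible.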

As an illustration, the loop body in the program below takes `a` as an argument array and captures `x` from the enclosing scope. The `experimental_preserve_dimensions` flag defaults to `True`; it is also passed explicitly below.

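```python
import jax.numpy as jnp
import catalyst
from catalyst import qjit

# Axis 1 of the arguments is abstracted as the dynamic dimension 'n'.
@qjit(abstracted_axes={1: 'n'})
def g(x, y):

    # `a` is the loop-carried argument; `x` is captured from g's scope.
    @catalyst.for_loop(0, 10, 1, experimental_preserve_dimensions=True)
    def loop(_, a):
        return a * x

    return jnp.sum(loop(y))

a = jnp.ones([1, 3], dtype=float)
b = jnp.ones([1, 3], dtype=float)
g(a, b)
```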
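For checking results, a plain NumPy reference of what `g` computes may be useful. It is a sketch assuming eager execution with no Catalyst or JAX involved, and `g_reference` is a hypothetical name: the loop starts from `y`, multiplies by the captured `x` ten times, then sums.

```python
import numpy as np

def g_reference(x, y):
    """Eager NumPy stand-in for the qjit-compiled g (no tracing, no dynamic shapes)."""
    a = y                      # loop-carried value starts as the loop's init argument
    for _ in range(0, 10, 1):  # same bounds as for_loop(0, 10, 1)
        a = a * x              # x is captured from the enclosing scope
    return float(np.sum(a))

print(g_reference(np.ones([1, 3]), np.ones([1, 3])))  # 3.0
print(g_reference(np.ones([1, 5]), np.ones([1, 5])))  # 5.0
```

With all-ones inputs every product is still all ones, so the sum simply counts the elements along the abstracted axis `n`; the second call mirrors passing a different size along that axis.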

Description of the Change:

This PR defines the following loop semantics, depending on the value of the already existing `experimental_preserve_dimensions` flag:

[sc-60521]
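To make the role of the flag concrete, here is a toy eager model, assuming (as the flag's name suggests) that preserving dimensions means a loop-carried array must keep its shape across iterations, while turning it off allows the array to resize. `run_loop` and its `preserve_dimensions` parameter are hypothetical; NumPy stands in for Catalyst, and the model does not cover how the flag interacts with captured arrays.

```python
import numpy as np

def run_loop(body, init, n_iter, preserve_dimensions=True):
    """Toy eager loop runner; not Catalyst's implementation of the flag."""
    value = init
    for i in range(n_iter):
        new_value = body(i, value)
        # Assumption: preserving dimensions means every iteration must return
        # a value with exactly the same shape as the one it received.
        if preserve_dimensions and np.shape(new_value) != np.shape(value):
            raise ValueError(
                f"iteration {i} changed shape {np.shape(value)} -> {np.shape(new_value)}"
            )
        value = new_value
    return value

x = np.ones([1, 3])

# Shape-preserving body: multiply by a captured array of the same shape.
print(run_loop(lambda _, a: a * x, np.ones([1, 3]), 10).shape)  # (1, 3)

# Resizing body: doubles axis 1 on every iteration.
grow = lambda _, a: np.concatenate([a, a], axis=1)
print(run_loop(grow, np.ones([1, 3]), 2, preserve_dimensions=False).shape)  # (1, 12)

try:
    run_loop(grow, np.ones([1, 3]), 2, preserve_dimensions=True)
except ValueError as err:
    print("rejected:", err)
```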

Benefits: Because this fixes a bug in the dynamically-shaped array support of v0.7.0, it has been cherry-picked from `main`.

Possible Drawbacks:

Related GitHub Issues:

rauletorresc commented 3 months ago

@rmoyard The original author of this PR was Sergei; I just cherry-picked it from main, where it had already been merged :(