PennyLaneAI / catalyst

A JIT compiler for hybrid quantum programs in PennyLane
https://docs.pennylane.ai/projects/catalyst
Apache License 2.0

[BUG] `value_and_grad` fails on non-scalar inputs #841

Closed (paul0403 closed this issue 3 months ago)

paul0403 commented 3 months ago

Expected:

import catalyst
import jax.numpy as jnp
from catalyst import qjit

@qjit
def g(vec):
    prod = jnp.array([30, 40]) * vec
    return prod[0] + prod[1]

x = jnp.array([1., 1.])
# result = qjit(catalyst.value_and_grad(g))(x)
result = qjit(catalyst.grad(g))(x)
print(result)

>>>
[30. 40.]

but got:

import catalyst
import jax.numpy as jnp
from catalyst import qjit

@qjit
def g(vec):
    prod = jnp.array([30, 40]) * vec
    return prod[0] + prod[1]

x = jnp.array([1., 1.])
result = qjit(catalyst.value_and_grad(g))(x)
# result = qjit(catalyst.grad(g))(x)
print(result)

>>>
Traceback (most recent call last):
  File "/home/paul.wang/small_playgrounds_dump/grad_method.py", line 85, in <module>
    result = qjit(catalyst.value_and_grad(g))(x)
  File "/home/paul.wang/.local/lib/python3.10/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
  File "/home/paul.wang/catalyst/frontend/catalyst/jit.py", line 514, in __call__
    requires_promotion = self.jit_compile(args)
  File "/home/paul.wang/.local/lib/python3.10/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
  File "/home/paul.wang/catalyst/frontend/catalyst/jit.py", line 585, in jit_compile
    self.jaxpr, self.out_type, self.out_treedef, self.c_sig = self.capture(args)
  File "/home/paul.wang/catalyst/frontend/catalyst/debug/instruments.py", line 143, in wrapper
    return fn(*args, **kwargs)
  File "/home/paul.wang/.local/lib/python3.10/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
  File "/home/paul.wang/catalyst/frontend/catalyst/jit.py", line 643, in capture
    jaxpr, out_type, treedef = trace_to_jaxpr(
  File "/home/paul.wang/.local/lib/python3.10/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
  File "/home/paul.wang/catalyst/frontend/catalyst/jax_tracer.py", line 530, in trace_to_jaxpr
    jaxpr, out_type, out_treedef = make_jaxpr2(func, **make_jaxpr_kwargs)(*args, **kwargs)
  File "/home/paul.wang/catalyst/frontend/catalyst/jax_extras/tracing.py", line 530, in make_jaxpr_f
    jaxpr, out_type, consts = trace_to_jaxpr_dynamic2(f)
  File "/home/paul.wang/catalyst/frontend/catalyst/api_extensions/differentiation.py", line 608, in __call__
    results = value_and_grad_p.bind(
jax._src.source_info_util.JaxStackTraceBeforeTransformation: TypeError: ConstantOp.__init__() missing 1 required positional argument: 'value'

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/paul.wang/small_playgrounds_dump/grad_method.py", line 85, in <module>
    result = qjit(catalyst.value_and_grad(g))(x)
  File "/home/paul.wang/.local/lib/python3.10/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
  File "/home/paul.wang/catalyst/frontend/catalyst/jit.py", line 514, in __call__
    requires_promotion = self.jit_compile(args)
  File "/home/paul.wang/.local/lib/python3.10/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
  File "/home/paul.wang/catalyst/frontend/catalyst/jit.py", line 587, in jit_compile
    self.mlir_module, self.mlir = self.generate_ir()
  File "/home/paul.wang/catalyst/frontend/catalyst/debug/instruments.py", line 143, in wrapper
    return fn(*args, **kwargs)
  File "/home/paul.wang/.local/lib/python3.10/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
  File "/home/paul.wang/catalyst/frontend/catalyst/jit.py", line 658, in generate_ir
    mlir_module, ctx = lower_jaxpr_to_mlir(self.jaxpr, self.__name__)
  File "/home/paul.wang/.local/lib/python3.10/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
  File "/home/paul.wang/catalyst/frontend/catalyst/jax_tracer.py", line 556, in lower_jaxpr_to_mlir
    mlir_module, ctx = jaxpr_to_mlir(func_name, jaxpr)
  File "/home/paul.wang/.local/lib/python3.10/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
  File "/home/paul.wang/catalyst/frontend/catalyst/jax_extras/lowering.py", line 72, in jaxpr_to_mlir
    module, context = custom_lower_jaxpr_to_module(
  File "/home/paul.wang/.local/lib/python3.10/site-packages/pennylane/logging/decorators.py", line 61, in wrapper_entry
    return func(*args, **kwargs)
  File "/home/paul.wang/catalyst/frontend/catalyst/jax_extras/lowering.py", line 141, in custom_lower_jaxpr_to_module
    lower_jaxpr_to_fun(
  File "/home/paul.wang/.local/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py", line 1301, in lower_jaxpr_to_fun
    out_vals, tokens_out = jaxpr_subcomp(
  File "/home/paul.wang/.local/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py", line 1494, in jaxpr_subcomp
    ans = lower_per_platform(rule_ctx, str(eqn.primitive),
  File "/home/paul.wang/.local/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py", line 1602, in lower_per_platform
    return kept_rules[0](ctx, *rule_args, **rule_kwargs)
  File "/home/paul.wang/catalyst/frontend/catalyst/jax_primitives.py", line 490, in _value_and_grad_lowering
    constants = [
  File "/home/paul.wang/catalyst/frontend/catalyst/jax_primitives.py", line 491, in <listcomp>
    ConstantOp(ir.DenseElementsAttr.get(np.asarray(const))).results for const in jaxpr.consts
TypeError: ConstantOp.__init__() missing 1 required positional argument: 'value'
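
For reference, the value and gradient that `value_and_grad` should produce for this function can be cross-checked with plain JAX, without Catalyst. This is only a sketch of the expected numbers, assuming `jax.value_and_grad` is an acceptable stand-in for `catalyst.value_and_grad` on this purely classical function; it does not exercise the failing lowering path.

```python
import jax
import jax.numpy as jnp


def g(vec):
    # Same function as in the report: 30 * vec[0] + 40 * vec[1]
    prod = jnp.array([30, 40]) * vec
    return prod[0] + prod[1]


x = jnp.array([1.0, 1.0])

# Plain JAX returns a (value, gradient) pair for a scalar-valued function.
value, grad = jax.value_and_grad(g)(x)
print(value)  # scalar function value
print(grad)   # gradient with the same shape as x
```

The gradient matches the `grad` output shown above, and the value is the sum of the two products.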

josh146 commented 3 months ago

Nice catch @paul0403! I wonder whether this is straightforward to fix, but I guess it requires some digging first.

paul0403 commented 3 months ago

The culprit is in the lowering of `value_and_grad` to MLIR: the shape of the gradient is not computed, and only the type is passed in: https://github.com/PennyLaneAI/catalyst/blob/main/frontend/catalyst/jax_primitives.py#L537

(Compare this with the `grad` lowering, where the shape is calculated: https://github.com/PennyLaneAI/catalyst/blob/main/frontend/catalyst/jax_primitives.py#L470)
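
To make the shape point concrete, here is a small sketch in plain JAX (not Catalyst's lowering code) that inspects the abstract result types of the function from the report. It assumes `jax.eval_shape` on `jax.value_and_grad` is a fair proxy for what the lowering has to describe: the value is a scalar, while the gradient carries the shape of the non-scalar input, so a type without a shape is not enough to describe it.

```python
import jax
import jax.numpy as jnp


def g(vec):
    prod = jnp.array([30, 40]) * vec
    return prod[0] + prod[1]


x = jnp.array([1.0, 1.0])

# eval_shape traces the function abstractly and returns shape/dtype
# descriptions of the outputs without computing any values.
value_aval, grad_aval = jax.eval_shape(jax.value_and_grad(g), x)

print(value_aval.shape, value_aval.dtype)  # scalar result: empty shape
print(grad_aval.shape, grad_aval.dtype)    # gradient: same shape as x
```

With a scalar input both results would have an empty shape, which is presumably why the problem only shows up for non-scalar inputs.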

I am fixing this now.

paul0403 commented 3 months ago

Manually closing, since the PR tracks the release branch instead of main.