PennyLaneAI / catalyst

A JIT compiler for hybrid quantum programs in PennyLane
https://docs.pennylane.ai/projects/catalyst
Apache License 2.0

Support `jax.Partial` as input to `catalyst.accelerate`. #882

Closed: erick-xanadu closed this 3 months ago

erick-xanadu commented 3 months ago

Context: `catalyst.accelerate` should be able to `jax.jit` a function of type `jax.Partial`. This pattern arises when `catalyst.accelerate` is used with `jax.jvp` returning a `jax.Partial`.

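Description of the Change: `jax.Partial` (`jax.tree_util.Partial`) is JAX's pytree-compatible re-implementation of Python's `functools.partial`. As a recap, the Python documentation describes `functools.partial` as follows: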

Return a new partial object which when called will behave like func called with the positional arguments args and keyword arguments keywords.
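For readers less familiar with `functools.partial`, here is a minimal illustration of the quoted behaviour. The `power` function is a hypothetical example and does not come from the PR:

```python
from functools import partial


def power(base, exponent):
    """Hypothetical example function."""
    return base ** exponent


# Pre-bind the keyword argument `exponent`; `base` is supplied at call time.
square = partial(power, exponent=2)

print(square(5))  # 25, i.e. power(5, exponent=2)
print(square.func is power, square.args, square.keywords)  # True () {'exponent': 2}
```

`jax.tree_util.Partial` keeps the same `func`/`args`/`keywords` layout. It also registers `args` and `keywords` as pytree children, which lets JAX flatten them and trace through them.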

Calling `tree_flatten(partial)` returns a list of leaves and a shape (the tree definition), just like any other call to `tree_flatten`. Calling `tree_unflatten(shape, leaves)` reconstructs the `jax.Partial`. During tracing, therefore, `tree_flatten(partial)` flattens the Partial into a list of tracers and a shape. This lets us construct a total function that receives the parameters normally embedded in the Partial as an explicit argument:
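The following sketch shows the flatten/unflatten round trip described above. It assumes a recent JAX that exposes the class as `jax.tree_util.Partial`, and `scale_and_shift` is a hypothetical example function. The bound arguments become the leaves, and the wrapped function is stored in the shape:

```python
from jax.tree_util import Partial, tree_flatten, tree_unflatten


def scale_and_shift(scale, shift, x):
    """Hypothetical example function: scale * x + shift."""
    return scale * x + shift


# Bind `scale` and `shift`; only `x` remains to be supplied.
p = Partial(scale_and_shift, 2.0, 1.0)

# The bound positional arguments are the leaves; the function itself is
# kept in the tree structure (the "shape"), not among the leaves.
leaves, shape = tree_flatten(p)
print(leaves)  # [2.0, 1.0]

# Rebuilding with different leaves yields a new Partial of the same function.
rebuilt = tree_unflatten(shape, [3.0, 0.5])
print(rebuilt(2.0))  # 6.5 == 3.0 * 2.0 + 0.5
```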

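from jax.tree_util import tree_flatten, tree_unflatten

def total(context, *args, **kwargs):  # This is the expected signature at runtime.
    _traced_context, shape = tree_flatten(partial)  # `partial` is the jax.Partial being accelerated; only its shape is reused
    new_partial = tree_unflatten(shape, context)  # rebuild the Partial from the runtime leaves in `context`
    return new_partial(*args, **kwargs)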

This total function is what gets passed to `jax.jit`.
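Below is an end-to-end sketch of the pattern using plain `jax.jit`. It is not Catalyst's actual `accelerate` implementation, and `make_total` and `scale_and_shift` are hypothetical names introduced only for illustration. The Partial's leaves are passed as the explicit `context` argument. Under `jit` they therefore arrive as tracers, and the Partial is rebuilt inside the traced function:

```python
import jax
import jax.numpy as jnp
from jax.tree_util import Partial, tree_flatten, tree_unflatten


def scale_and_shift(scale, shift, x):
    """Hypothetical example function: scale * x + shift."""
    return scale * x + shift


def make_total(partial_fn):
    """Hypothetical helper: turn a jax Partial into a total function whose
    bound arguments are supplied explicitly as `context` (a list of leaves)."""
    _leaves, shape = tree_flatten(partial_fn)  # only the shape is captured

    def total(context, *args, **kwargs):
        # Under jax.jit, `context` is a list of tracers matching `shape`.
        new_partial = tree_unflatten(shape, context)
        return new_partial(*args, **kwargs)

    return total


p = Partial(scale_and_shift, jnp.float32(2.0), jnp.float32(1.0))
context, _ = tree_flatten(p)

jitted_total = jax.jit(make_total(p))
x = jnp.arange(3.0)

out = jitted_total(context, x)
print(out)  # [1. 3. 5.], same as p(x)

# New leaf values with the same structure, shapes and dtypes reuse the compiled function.
out2 = jitted_total([jnp.float32(3.0), jnp.float32(0.0)], x)
print(out2)  # [0. 3. 6.]
```

The closure captures only the shape, not the Partial itself. This keeps the bound values from being baked into the compiled function as constants; they remain ordinary runtime inputs.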

Benefits: Support for the pattern in #852

Possible Drawbacks: This could be generalized further, since a `jax.Partial` can also be passed as an argument to a callback. We leave that as future work.

Related GitHub Issues: Closes #852

This should also be cherry-picked to main.

[sc-66839]

josh146 commented 3 months ago

Thanks @erick-xanadu, this is great! I'm curious if this would be considered a bug/solves any bugs, and if so should be part of v0.7.0-rc?

erick-xanadu commented 3 months ago

I consider it a bug, but I'm looping in @rauletorresc, as he should have the final say.

rauletorresc commented 3 months ago

I would say it is. Could you please rebase it onto the v0.7.0-rc branch? Thank you

dime10 commented 3 months ago

Just a note to please choose a short but descriptive title when opening a PR :)

rauletorresc commented 3 months ago

Did it myself :)

erick-xanadu commented 3 months ago

This was done during hackweek as a proof of concept. I'm working on cleaning it up.

erick-xanadu commented 3 months ago

I think something in the merge changed the behaviour... Never mind. It looks like JAX had an issue where `_, shape = tree_flatten(func)` gives a different result from `shape = tree_structure(func)`.

dime10 commented 3 months ago

Is that a bug? :O