erick-xanadu closed this 3 months ago.
Thanks @erick-xanadu, this is great! I'm curious if this would be considered a bug/solves any bugs, and if so should be part of v0.7.0-rc?
I consider it a bug, but I'm looping in @rauletorresc as he should have the final say.
I would say it is. Could you please rebase it to the v0.7.0-rc branch? Thank you
Just a note to please choose a short but descriptive title when opening a PR :)
> I would say it is. Could you please rebase it to the v0.7.0-rc branch? Thank you
Did it myself :)
> Just a note to please choose a short but descriptive title when opening a PR :)
This was done during hackweek as a proof of concept. I'm working on cleaning it up.
Something in the merge I think changed the behaviour... Never mind. It looks like JAX had an issue: _, shape = tree_flatten(func) is different from shape = tree_structure(func).
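A minimal way to check this for a particular Partial is to compute the tree definition both ways and compare them, then confirm that the leaves plus a tree definition rebuild a working Partial. This is a sketch using jax.tree_util directly (jax.Partial in this thread is taken to mean jax.tree_util.Partial); scale is a hypothetical example function, and whether the two tree definitions compare equal may depend on the JAX version and on how the Partial was constructed, so no particular result is claimed here.

```python
import jax.numpy as jnp
from jax.tree_util import Partial, tree_flatten, tree_structure, tree_unflatten


def scale(a, x):
    # Hypothetical example: `a` is bound inside the Partial; `x` is supplied at call time.
    return a * x


# The bound array argument becomes a pytree leaf of the Partial.
p = Partial(scale, jnp.array(2.0))

# Tree definition ("shape") obtained via tree_flatten; the leaves are kept for later.
leaves, shape_from_flatten = tree_flatten(p)
# Tree definition obtained directly via tree_structure.
shape_from_structure = tree_structure(p)

# Whether these two agree is what the thread is asking about.
print(shape_from_flatten == shape_from_structure)

# Round trip: the leaves plus a tree definition rebuild a callable Partial.
rebuilt = tree_unflatten(shape_from_flatten, leaves)
print(rebuilt(jnp.array(3.0)))
```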
Is that a bug? :O
Context:
catalyst.accelerate should be able to jax.jit a function of type jax.Partial. This pattern comes from using catalyst.accelerate with jax.jvp returning a jax.Partial.
Description of the Change:
jax.Partial is JAX's pytree-compatible re-implementation of Python's partial. As a recap: calling tree_flatten(partial) returns a list of leaves and a shape (the tree definition), just like any call to tree_flatten, and calling tree_unflatten(shape, data) reconstructs the jax.Partial. This means that during tracing, tree_flatten(partial) flattens to a list of tracers and a shape. We can then construct a total function that takes the parameters that would normally be embedded in the Partial as explicit arguments.
This total function is what gets jax.jit-ed.
Benefits: Support for the pattern in #852.
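The flatten-then-jit approach described above can be sketched in plain JAX. This is a minimal illustration under stated assumptions, not Catalyst's actual implementation: make_total and affine are hypothetical names, jax.Partial is taken to mean jax.tree_util.Partial, and jax.jit stands in for the compilation step performed inside catalyst.accelerate. The second half covers the case from the description where the Partial is built during tracing, so one of its leaves is a tracer that must be passed in explicitly.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import Partial, tree_flatten, tree_unflatten


def make_total(partial_fn):
    """Hypothetical helper: turn a Partial into (total_fn, leaves).

    total_fn receives the Partial's leaves as an explicit first argument,
    so values bound in the Partial are passed in rather than captured.
    """
    leaves, shape = tree_flatten(partial_fn)

    def total_fn(leaves, *args):
        # Rebuild the Partial from the (possibly traced) leaves, then call it.
        return tree_unflatten(shape, leaves)(*args)

    return total_fn, leaves


def affine(w, b, x):
    # Hypothetical example function; w and b get bound in the Partial.
    return w * x + b


# Outside any trace: the leaves are concrete arrays.
p = Partial(affine, jnp.array(2.0), jnp.array(1.0))
total_fn, leaves = make_total(p)
print(jax.jit(total_fn)(leaves, jnp.array(3.0)))


# Inside a trace: `w` is a tracer, so one of the Partial's leaves is a tracer too.
@jax.jit
def outer(w):
    q = Partial(affine, w, jnp.array(1.0))
    inner_total, inner_leaves = make_total(q)
    return jax.jit(inner_total)(inner_leaves, jnp.array(3.0))


print(outer(jnp.array(2.0)))
```

In this sketch total_fn closes over only the tree definition, which holds the static structure (including the wrapped function), while every bound value arrives as an argument of the jitted function.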
Possible Drawbacks: This could be generalized further, since a jax.Partial can also be an argument to a callback. However, we leave that as future work.
Related GitHub Issues: Closes #852
This should also be cherry-picked to main.
[sc-66839]