thisiscam opened this issue 4 years ago
Maybe @sharadmv can help?
Currently there's no "officially" supported way to do this, though I'd be happy to brainstorm a long-term, supported solution.
A simple workaround is to use `custom_vjp`/`custom_jvp`, though I see you've already noticed that you can't use both at once. In any case, you can get either one working:
```python
import jax
from jax import grad, jvp
import oryx

reap = oryx.core.reap
sow = oryx.core.sow


def sow_with_jvp(x, *, name, **kwargs):
    @jax.custom_jvp
    def custom_sow(x):
        return sow(x, name=name, **kwargs)

    @custom_sow.defjvp
    def custom_sow_jvp(primals, tangents):
        x, = primals
        g, = tangents
        # Sow the tangent under a separate name so it doesn't collide with the primal.
        g = sow(g, name=f'{name}_jvp', **kwargs, key=x)
        return custom_sow(x), g

    return custom_sow(x)


def sow_with_vjp(x, *, name, **kwargs):
    @jax.custom_vjp
    def custom_sow(x):
        return sow(x, name=name, **kwargs)

    def custom_sow_fwd(x):
        # Save x as the residual for the backward pass.
        return custom_sow(x), x

    def custom_sow_bwd(x, g):
        # Sow the incoming cotangent under a separate name.
        g = sow(g, name=f'{name}_vjp', **kwargs, key=x)
        return (g,)

    custom_sow.defvjp(custom_sow_fwd, custom_sow_bwd)
    return custom_sow(x)
```
This doesn't work completely because of the mutual exclusivity of `custom_vjp`/`custom_jvp`, and I wasn't able to get `jax.jacrev` to work because of some apparent constant folding. In any case, the following snippet works:
```python
def f1(x):
    return sow_with_vjp(x ** 3, tag="test", name="x3") + 1

def f2(x):
    return sow_with_jvp(x ** 3, tag="test", name="x3") + 1

print(reap(grad(f1), tag='test')(3.))  # ==> {'x3': 27., 'x3_vjp': 1.}
print(reap(lambda x: jax.jvp(f2, (x,), (1.,)), tag='test')(3.))  # ==> {'x3': 27., 'x3_jvp': 27.}
```
@sharadmv Sorry for the year-long delay.
I haven't had much time to look into this since I last posted (this was mostly a curiosity-driven investigation).
That being said, I'm still interested in this! So, a couple of comments/updates:
I tried your solution with `custom_(jvp|vjp)`, and it indeed works! The reason `jacrev` didn't work was its dependence on `vmap`, which constant-folded the `sow` because of the batch rule for `tie_all_p`: https://github.com/tensorflow/probability/blob/master/spinoffs/oryx/oryx/core/primitive.py#L245. Changing it to:

```python
def _tie_all_batch_rule(batched_args, batch_dims, **params):
    # Bind tie_all_p on the batched values and pass the batch dims through unchanged.
    outs = tie_all_p.bind(*batched_args, **params)
    return outs, batch_dims
```

fixes the `jacrev` problem.
I'm really interested in a more general solution, something that works for both `jvp` and `vjp` at the same time.
I realized at some point that what I really want for `vjp` isn't simply sowing its cotangent values. Instead, I'm looking for the same result as in the `jvp` case: basically, I'd like the gradient of the marked intermediate.
That is, do you think there's a way to make your example always output `x3_jvp`, regardless of whether `jax.jvp` or `jax.grad` is applied? e.g.
```python
print(reap(grad(f1), tag='test')(3.))  # ==> {'x3': 27., 'x3_jvp': 27.}
print(reap(lambda x: jax.jvp(f2, (x,), (1.,)), tag='test')(3.))  # ==> {'x3': 27., 'x3_jvp': 27.}
```
The problem I think I'm encountering is that once `vjp` invokes `jvp`, the backward pass seems to drop whatever operations were done in the `jvp`, thereby losing the marked `*_jvp` values.
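A minimal sketch of one way to get that derivative in either mode, in plain JAX without oryx: assume the marked intermediate can be exposed as an extra output of the function (roughly what `reap(f)` returns), and differentiate that output directly. `f_with_sown` is a hypothetical hand-written stand-in, not oryx code; with it, both `jax.jacfwd` and `jax.jacrev` report d(x3)/dx = 27 at x = 3.

```python
import jax


def f_with_sown(x):
    # Hypothetical stand-in for reap(f): f's output plus the marked intermediate.
    x3 = x ** 3
    return x3 + 1.0, {"x3": x3}


# Forward- and reverse-mode Jacobians of (output, {"x3": x3}) with respect to x.
jac_fwd = jax.jacfwd(f_with_sown)(3.0)
jac_rev = jax.jacrev(f_with_sown)(3.0)
print(jac_fwd[1]["x3"], jac_rev[1]["x3"])
```

The catch is that this changes the function's signature to return the intermediate, which is exactly what `sow`/`reap` is meant to avoid.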
@sharadmv
Another thing I've been trying is to get multiple nested `jvp`s to work.
The approach I'm taking is to use `oryx.core.nest`, so that JVPs are reflected in a nested dict whose depth corresponds to the number of times the derivative was taken.
Currently, this doesn't work for more than one `jvp`.
Self-contained example:
```python
import jax
from oryx.core import nest, sow, reap


def _sow_with_jvp(x, *, tag, name, mode='strict', key=None):
    """Sow `x` and preserve its JVP."""

    @jax.custom_jvp
    def custom_sow(x):
        return sow(x, tag=tag, name=name, mode=mode, key=key)

    @custom_sow.defjvp
    def custom_jvp(primals, tangents):
        x, = primals
        g, = tangents
        # Recursively sow the tangent inside a "jvp" scope, so each extra
        # derivative adds one level of nesting.
        g = nest(_sow_with_jvp, scope="jvp")(g, tag=tag, name=name, mode=mode, key=x)
        return custom_sow(x), g

    return custom_sow(x)


def f(x):
    return _sow_with_jvp(x**3, tag="test", name="x3") + 1


f1 = jax.jacfwd(f)
f2 = jax.jacfwd(f1)
print(reap(f1, tag="test")(1.))  # Prints: {'jvp': {'x3': DeviceArray([3.], dtype=float32)}, 'x3': DeviceArray(1., dtype=float32)}
print(reap(f2, tag="test")(1.))  # ValueError: Variable has already been reaped: jvp
```
The last error looks weird to me. My best guess right now is that the `jacfwd` calls and the `reap` call are somehow interleaved, causing a plant-after-reap error?
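For reference, the nested layout described above (one `"jvp"` level per derivative) can be produced without oryx by recursing on `jax.jvp`. This is a minimal sketch assuming a scalar input and a unit tangent; `nested_derivatives` is a hypothetical helper, and it recomputes lower-order derivatives at every level instead of sharing work.

```python
import jax
import jax.numpy as jnp


def intermediate(x):
    # The value that f marks with sow(..., name="x3").
    return x ** 3


def nested_derivatives(fn, x, depth):
    """Hypothetical helper: {'x3': fn(x), 'jvp': {...}} nested `depth` levels deep.

    Level k holds the k-th derivative of fn at x (scalar x, unit tangent).
    """
    out = {"x3": fn(x)}
    if depth > 0:
        # Derivative of fn, obtained with a forward-mode JVP and tangent 1.
        dfn = lambda y: jax.jvp(fn, (y,), (jnp.ones_like(y),))[1]
        out["jvp"] = nested_derivatives(dfn, x, depth - 1)
    return out


result = nested_derivatives(intermediate, jnp.float32(1.0), depth=2)
# result["x3"] is 1, result["jvp"]["x3"] is 3, result["jvp"]["jvp"]["x3"] is 6.
print(result)
```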
Some more updates:
It appears that the error `ValueError: Variable has already been reaped: jvp` is due to `nest_p` also having a `mode` parameter, which is always set to `strict`. I wonder why that is the case? I was envisioning scopes as "namespaces" (as in C++), so that only values inside them can conflict, not the namespaces themselves.
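To make the "namespaces" idea concrete, here is a toy model on plain Python dicts (not oryx's implementation): scopes with the same name are merged recursively, and only a collision on a leaf value is an error in strict mode. `merge_scoped` is a hypothetical helper, and its error message merely mirrors the one quoted above.

```python
def merge_scoped(dst, src, mode="strict"):
    """Hypothetical helper: merge reaped values `src` into `dst`, treating dicts as scopes."""
    for key, value in src.items():
        if isinstance(value, dict) and isinstance(dst.get(key), dict):
            # Two scopes with the same name: merge their contents instead of failing.
            merge_scoped(dst[key], value, mode)
        elif key in dst and mode == "strict":
            # A leaf collides with an existing entry: only this case is an error.
            raise ValueError(f"Variable has already been reaped: {key}")
        else:
            # New name, or "clobber" mode: overwrite.
            dst[key] = value
    return dst


reaped = {"x3": 1.0, "jvp": {"x3": 3.0}}
# A second-order JVP reaps into the existing "jvp" scope without conflicting with it.
merge_scoped(reaped, {"jvp": {"jvp": {"x3": 6.0}}})
print(reaped)
```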
It also appears that this `custom_jvp` hack doesn't work very well with nested `jvp`s, as I sometimes get a `custom_jvp can't close over values` error. For this reason, I baked the JVP logic further into the `sow_p` primitive:
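```python
from jax.interpreters import ad
from jax.interpreters.ad import Zero  # import path as of the JAX version at the time
from oryx.core import nest, sow


def sow_jvp_rule(primals, tangents, *, tag, name, mode, **kwargs):
    val_out = sow(primals, tag=tag, name=name, mode=mode)
    if all(type(tangent) is Zero for tangent in tangents):
        # All tangents are symbolic zeros: there is nothing to sow for the JVP.
        return val_out, Zero.from_value(val_out)
    else:
        tangents = tuple(map(ad.instantiate_zeros, tangents))
        # Sow the tangents in a nested "jvp" scope. Passing `mode` to `nest`
        # assumes a locally modified `nest`; stock oryx `nest` takes only `scope`.
        tangents = nest(sow, scope="jvp", mode="clobber")(tangents,
                                                          tag=tag,
                                                          name=name,
                                                          key=primals,
                                                          mode=mode)
        return val_out, tangents
```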
And this appears to work. I will later gather my changes and do a draft PR.
Lastly, I'd like to share a semantics for `sow`/`reap` that I've been imagining:
One fundamental question about `reap` is how it should compose with other transformations such as `jvp`, `vjp`, `vmap`, and `jit`.
If `reap` is applied before any other transformation, its semantics are less questionable.
If `reap` is applied after some (in general, non-semantics-preserving) transformation `T`, such as `jvp`, I think a clean semantics is probably to return the sown values `vs` along with the transformed sown values `T(vs)` in a nested scope (perhaps with a mangled name that relates it to `T`, but that decision is less important). This way, users would have a predictable way to reason about the reaped values.
A problem with this semantics is redundant computation: if all yield points' values are always returned for all active transformations, that could lead to a huge number of common subexpressions. So instead of returning `vs` and `T(vs)` directly, it might be desirable to return a continuation, so that these values can be computed lazily.
So imagine this API (I've omitted `tag` for conciseness):
```python
# Hypothetical API sketch, not runnable as-is: DelayedDict and reap(..., name=...) do not exist.
def f(x):
    return sow(x ** 3, name="x3") + 1

r1 = reap(f, name="x3")  # DelayedDict(keys={"x3": *})
r1["x3"](2.)  # 8.
r2 = reap(grad(f), name="x3")(2.)  # DelayedDict(keys={"x3": *, "grad": {"x3": *}})
f2 = r2["grad", "x3"]  # f2 is equivalent to grad(lambda x: x**3)
f2(2.)  # 12.
r3 = reap(jit(vmap(grad(f))), name="x3")([2., 2.])
f3 = r3["grad", "vmap", "x3"]  # f3 is equivalent to vmap(grad(lambda x: x**3))
f4 = r3["jit", "grad", "vmap", "x3"]  # f4 is equivalent to jit(vmap(grad(lambda x: x**3)))
```
Now, I'm not sure about the feasibility of this approach. My guess is that it would need to be part of JAX core, so that every transformation knows about the existence of these `sow` points and can prepare the transformed values accordingly.
WDYT?
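A toy model of the "return a continuation" idea, in plain JAX: each key path maps to a function that is only built when it is looked up, and only evaluated when it is called. `SOWN`, `TRANSFORMS`, and `lookup` are hypothetical names; unlike the proposal, the sown function is written by hand here instead of being extracted from `f`, and paths list the outermost transformation first.

```python
import jax
import jax.numpy as jnp

# Hypothetical registry of sown intermediates (written by hand for this sketch).
SOWN = {"x3": lambda x: x ** 3}
# Transformations that a key path may mention.
TRANSFORMS = {"grad": jax.grad, "vmap": jax.vmap, "jit": jax.jit}


def lookup(path):
    """Build the transformed sown function for `path`, e.g. ("jit", "vmap", "grad", "x3")."""
    *transforms, name = path
    fn = SOWN[name]
    # Apply the innermost transformation (the one next to the name) first.
    for t in reversed(transforms):
        fn = TRANSFORMS[t](fn)
    return fn


f2 = lookup(("grad", "x3"))                 # grad(lambda x: x**3)
f4 = lookup(("jit", "vmap", "grad", "x3"))  # jit(vmap(grad(lambda x: x**3)))
print(f2(2.0), f4(jnp.array([2.0, 2.0])))
```

Nothing is traced or computed until a path is looked up and called, which is the property that would avoid the redundant computation mentioned above.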
Wow, thanks for the in-depth investigation!
> And this appears to work. I will later gather my changes and do a draft PR.

Awesome, I'd be happy to review it. This is a problem that I've also investigated, but I haven't yet arrived at a complete solution. Since you're also interested in it, I'll share my half-baked brainstorm here to add to the discussion. I believe this approach is similar to the one taken by Dex, so there is some precedent for it.
linearize(reap)
Let's say that instead of the `jvp` transformation, we're using `linearize`, which, instead of taking in `primals, tangents` and returning `out_primals, out_tangents`, takes in `primals` and returns `out_primals` along with a JVP function that maps input tangents to output tangents. Its signature is `(a -> b) -> a -> (b, a --o b)`, where `--o` indicates a linear function.
Let's focus on `reap` (dropping tags for readability), and let's say we'd like this invariant to hold:
`reap(linearize(f)) == linearize(reap(f))`
Consider the example function:

```python
def f(x):
    sow(x ** 2., name='y')
    return x + 1.
```

We have:

```
reap(f)(x):
    return x + 1., dict(y=x ** 2)

linearize(reap(f))(x):
    return (x + 1., dict(y=x ** 2)), lambda t: (t, dict(y=2 * x * t))
```
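The `linearize(reap(f))` line can be checked in plain JAX by writing `reap(f)` by hand as a function that returns the sown dict as an extra output; `f_reaped` is that hypothetical stand-in, not oryx's `reap`. The tangent for `'y'` comes out as `2 * x * t`, i.e. 6 at `x = 3`, `t = 1`.

```python
import jax
import jax.numpy as jnp


def f_reaped(x):
    # Hypothetical stand-in for reap(f): f's output plus the sown dict.
    return x + 1.0, dict(y=x ** 2)


x = jnp.float32(3.0)
primals_out, jvp_fn = jax.linearize(f_reaped, x)
# jvp_fn is linear in t and returns (t, {"y": 2 * x * t}).
t_out, t_sown = jvp_fn(jnp.float32(1.0))
print(primals_out, t_out, t_sown)
```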
In order to make `reap` and `linearize` commute, we'd need:

```
linearize(f)(x):
    def jvp(t):
        sow(2 * x * t, name='y')
        return t
    return x + 1, jvp
```
We can't directly `reap(linearize(f))`, because it returns a function. But if we `reap` the returned `jvp` function, we get the same `jvp` function as the one returned by `linearize(reap(f))`.
In this case, the JVP rule for `sow` directly `sow`s the tangents under the identical name. Note that we don't get a name collision here like we do when using `jvp`, because `linearize` computes tangents in a separate function, so the tangent `sow` of `'y'` doesn't happen during the forward pass.
It also turns out that `linearize(plant) == plant(linearize)` holds with this JVP rule as well.
vjp(reap)
Let's run through the same exercise with `vjp`, this time aiming for:
`reap(vjp(f)) == vjp(reap(f))`
Consider the example function:

```python
def f(x):
    sow(x ** 2., name='y')
    return x + 1.
```

We have:

```
reap(f)(x):
    return x + 1., dict(y=x ** 2)

vjp(reap(f))(x):
    def vjp_fn(ct_x, ct_y):
        return ct_x + 2 * ct_y * x
    return (x + 1., dict(y=x ** 2)), vjp_fn
```
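The `vjp(reap(f))` case can be checked the same way in plain JAX. With the sown dict as an extra output (`f_reaped` is a hypothetical hand-written stand-in for `reap(f)`), `jax.vjp` returns a function whose cotangent input has a slot for `'y'`, which is exactly the extra argument that the plain `vjp(f)` lacks. At `x = 3` with both cotangents set to 1, the result is `1 + 2 * 1 * 3 = 7`.

```python
import jax
import jax.numpy as jnp


def f_reaped(x):
    # Hypothetical stand-in for reap(f): f's output plus the sown dict.
    return x + 1.0, dict(y=x ** 2)


x = jnp.float32(3.0)
out, vjp_fn = jax.vjp(f_reaped, x)
# The cotangent mirrors the output's structure, so it includes ct_y for 'y'.
(ct_in,) = vjp_fn((jnp.float32(1.0), dict(y=jnp.float32(1.0))))
print(ct_in)  # ct_x + 2 * ct_y * x
```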
Now here's the tricky part. To make `reap` and `vjp` commute, we need to add an argument to `vjp_fn`'s signature corresponding to the cotangent for `'y'`. We can't change the signature of this function as part of a transpose rule, though. It turns out we need to use `sow` again, but this time it needs to accept a planted value:
```
vjp(f)(x):
    def vjp_fn(ct_x):
        ct_y = sow(<dummy value>, name='y')
        return ct_x + 2 * ct_y * x
    return x + 1, vjp_fn
```
Conceptually, this VJP function is meaningless unless we inject a value for `ct_y`, so when we use it, we need to `plant` a value for `ct_y`:

```python
# Under the proposed semantics (tags omitted); jax.vjp takes the primal positionally.
out_primals, f_vjp = jax.vjp(f, 2.)
x_grad = plant(f_vjp)(dict(y=1.), 1.)
```
If you run through this exercise with `plant(vjp)`, you should also find that you need a "write" `sow` in the backward pass.
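One way to see the read/write duality in plain JAX is `jax.linear_transpose`: the linearized `reap(f)` "writes" a tangent for `'y'` as an extra output, and its transpose "reads" a cotangent for `'y'` as an extra input. This sketch assumes `reap(f)` is written by hand as `f_reaped`, a hypothetical stand-in that returns the sown dict as an extra output.

```python
import jax
import jax.numpy as jnp


def f_reaped(x):
    # Hypothetical stand-in for reap(f): f's output plus the sown dict.
    return x + 1.0, dict(y=x ** 2)


x = jnp.float32(3.0)
# Linearized reap(f): t -> (t, {"y": 2 * x * t}); the 'y' entry is a "write".
_, jvp_fn = jax.linearize(f_reaped, x)
# Its transpose takes a cotangent for every output, including 'y': a "read".
transpose_fn = jax.linear_transpose(jvp_fn, x)
(ct_x,) = transpose_fn((jnp.float32(1.0), dict(y=jnp.float32(1.0))))
print(ct_x)  # ct_out + 2 * x * ct_y
```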
My conclusion after this brainstorm is that:
1. We should build on `jax.linearize` rather than `jax.jvp`.
2. We should think of `sow` as either being "writes" (aka `reap`) or "reads" (aka `plant`). The transpose of a "read" `sow` is a "write" `sow`, and vice versa. The implication is that `reap` cannot exist without `plant`, and vice versa.

@sharadmv Thanks for sharing!
> And this appears to work.

It only appears to work: if I use `jacfwd` (instead of `jvp`), it returns the transposed result of `jacfwd` for the sown value. So the solution isn't very robust.
This is because `jacfwd` needs to permute the output axes (https://github.com/google/jax/blob/master/jax/_src/api.py#L977), but `harvest` currently does not respect the `out_axes` parameter of `vmap`.
More thoughts:
For any of the `jac*` operations, it should be possible to simply decompose the function and apply the chain rule, but this would require a customized `jac*` function.
For a function `f(x, y) = g(h(x), y)`, we can write a custom `grad` function that computes `f'(x, y) = g'(h(x), y) * h'(x)` (where `g'` is the derivative of `g` with respect to its first argument), and in that implementation one can re-sow `h` and `h'` (where `h'` uses a different scope/name).
This should work regardless of whether the forward or backward pass is used. Essentially, one can create a set of `jac*` routines that preserve the `sow` tags.
In some sense, for `jacrev`, my intuition is that this is equivalent to materializing a function that returns the sown value.
> This is due to jacfwd needs to permute the output axes https://github.com/google/jax/blob/master/jax/_src/api.py#L977, but currently harvest does not respect the out_axes parameter of vmap.
To add a simplified example where `harvest` gave me a "surprising" result:
```python
import jax
import jax.numpy as jnp
import oryx

sow = oryx.core.sow
reap = oryx.core.reap

def f(x):
    return sow(x**3, tag="test", name="test") + 1

M, N = 3, 7
x = jnp.arange(M * N).reshape(M, N)
v1 = jax.vmap(f, in_axes=0, out_axes=-1)(x)
v2 = reap(jax.vmap(f, in_axes=0, out_axes=-1), tag="test")(x)["test"]
print(v1.shape)  # (7, 3)
print(v2.shape)  # (3, 7)
print(jnp.all(v1 == v2.T + 1))  # Currently this prints `True`
print(jnp.all(v1 == v2 + 1))  # I would argue that this is more natural
```
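A plain-JAX sketch (no oryx) of what respecting `out_axes` would mean for the reaped value: `vmap` with `out_axes=-1` gives the same array as `out_axes=0` followed by moving the batch axis to the end with `jnp.moveaxis`. This assumes the reaped value is stacked along axis 0, as the `(3, 7)` shape above suggests; `cube` is a stand-in for the sown expression.

```python
import jax
import jax.numpy as jnp


def cube(x):
    # Stand-in for the sown expression x**3.
    return x ** 3


M, N = 3, 7
x = jnp.arange(M * N).reshape(M, N)
v_last = jax.vmap(cube, in_axes=0, out_axes=-1)(x)  # batch axis placed last
v_first = jax.vmap(cube, in_axes=0, out_axes=0)(x)  # batch axis kept first
# Honoring out_axes=-1 amounts to moving the batch axis from position 0 to -1.
v_moved = jnp.moveaxis(v_first, 0, -1)
print(v_last.shape, v_first.shape, v_moved.shape)
```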
Hi TensorFlow Probability team!
I was using the oryx library, which is maintained in this repo, following an issue from JAX: https://github.com/google/jax/issues/3936#issuecomment-668954244.
I'm wondering whether there is a way to harvest/reap with a `grad` nested inside and get the gradient values instead of the primal values, e.g.
Would `nest` help here?
Thanks!