tensorflow / probability

Probabilistic reasoning and statistical analysis in TensorFlow
https://www.tensorflow.org/probability/
Apache License 2.0

[oryx] harvest after grad #1045

Open thisiscam opened 4 years ago

thisiscam commented 4 years ago

Hi tensorflow probability team!

I started using the oryx library, which is maintained in this repo, after following a JAX issue: https://github.com/google/jax/issues/3936#issuecomment-668954244.

I'm wondering whether there is a way to harvest/reap a function with a grad nested inside it and get the gradient values instead of the primal values. For example:

from jax import grad

import oryx
reap = oryx.core.reap
sow = oryx.core.sow

def f(x):
  return sow(x ** 3, tag="test", name="x3") + 1

print(reap(grad(f), tag="test")(1.))
# out: {'x3': DeviceArray(1., dtype=float32)}
# how can we get (something similar to) {'x3': DeviceArray(3., dtype=float32)}, i.e. d(x**3)/dx at x = 1?

Would oryx.core.nest help here?

Thanks!

emilyfertig commented 4 years ago

Maybe @sharadmv can help?

sharadmv commented 4 years ago

Currently there's no "officially" supported way to do this, though I'd be happy to brainstorm a long-term, supported solution.

A simple workaround is to use custom_vjp or custom_jvp, though as you already noticed, you can't use both on the same function. Either one works on its own:

import jax
from jax import grad

import oryx
reap = oryx.core.reap
sow = oryx.core.sow

def sow_with_jvp(x, *, name, **kwargs):
  @jax.custom_jvp
  def custom_sow(x):
    return sow(x, name=name, **kwargs)

  @custom_sow.defjvp
  def custom_sow_jvp(primals, tangents):
    x, = primals
    g, = tangents
    # Sow the tangent under a separate name so that it can be reaped.
    g = sow(g, name=f'{name}_jvp', **kwargs, key=x)
    return custom_sow(x), g
  return custom_sow(x)

def sow_with_vjp(x, *, name, **kwargs):
  @jax.custom_vjp
  def custom_sow(x):
    return sow(x, name=name, **kwargs)

  def custom_sow_fwd(x):
    # Keep x as the residual for the backward pass.
    return custom_sow(x), x

  def custom_sow_bwd(x, g):
    # Sow the incoming cotangent under a separate name so that it can be reaped.
    g = sow(g, name=f'{name}_vjp', **kwargs, key=x)
    return (g,)
  custom_sow.defvjp(custom_sow_fwd, custom_sow_bwd)
  return custom_sow(x)

This isn't a complete solution: custom_vjp and custom_jvp are mutually exclusive, and I wasn't able to get jax.jacrev to work because of some apparent constant folding. In any case, the following snippet works:

def f1(x): 
  return sow_with_vjp(x ** 3, tag="test", name="x3") + 1
def f2(x): 
  return sow_with_jvp(x ** 3, tag="test", name="x3") + 1

print(reap(grad(f1), tag='test')(3.)) # ==> {'x3': 27., 'x3_vjp': 1.}
print(reap(lambda x: jax.jvp(f2, (x,), (1.,)), tag='test')(3.)) # ==> {'x3': 27., 'x3_jvp': 27.}

thisiscam commented 3 years ago

@sharadmv Sorry for the year-long delay.

I haven't had much time to look into this since I last posted (this was mostly a curiosity-driven investigation).

That said, I'm still interested in this! A couple of comments and updates:

  1. I tried your solution with custom_(jvp|vjp), and it does indeed work! jacrev didn't work because it depends on vmap, and vmap constant-folded the sow due to the batching rule for tie_all_p: https://github.com/tensorflow/probability/blob/master/spinoffs/oryx/oryx/core/primitive.py#L245. Changing that rule to:

    def _tie_all_batch_rule(batched_args, batch_dims, **params):
      # Re-bind tie_all_p on the batched values and keep their batch dimensions.
      outs = tie_all_p.bind(*batched_args, **params)
      return outs, batch_dims

    fixes the jacrev problem.

  2. I'm really interested in a more general solution, something that works for both jvp and vjp at the same time. At some point I realized that what I really want for vjp isn't simply to sow its cotangent values. Instead, I'm looking for the same result as in the jvp case: the gradient of the marked intermediate. That is, do you think there's a way to make your example always output x3_jvp, regardless of whether jax.jvp or jax.grad is applied? For example:

    print(reap(grad(f1), tag='test')(3.)) # ==> {'x3': 27., 'x3_jvp': 27.}
    print(reap(lambda x: jax.jvp(f2, (x,), (1.,)), tag='test')(3.)) # ==> {'x3': 27., 'x3_jvp': 27.}

    The problem I think I'm running into is that when vjp invokes jvp, the backward pass seems to drop whatever operations were done in the jvp, so the marked *_jvp values are lost.

thisiscam commented 3 years ago

@sharadmv

Another thing I've been trying is to get nested jvps to work. My approach is to use oryx.core.nest so that the jvps are reflected in a nested dict whose depth corresponds to the number of times the derivative was taken. Currently, this doesn't work for more than one jvp. Self-contained example:

import jax

from oryx.core import nest, sow, reap

def _sow_with_jvp(x, *, tag, name, mode='strict', key=None):
  """Sow the value and preserve its jvp."""

  @jax.custom_jvp
  def custom_sow(x):
    return sow(x, tag=tag, name=name, mode=mode, key=key)

  @custom_sow.defjvp
  def custom_jvp(primals, tangents):
    x, = primals
    g, = tangents
    g = nest(_sow_with_jvp, scope="jvp")(g,
                                         tag=tag,
                                         name=name,
                                         mode=mode,
                                         key=x)
    return custom_sow(x), g

  return custom_sow(x)

def f(x):
  return _sow_with_jvp(x**3, tag="test", name="x3") + 1

f1 = jax.jacfwd(f)
f2 = jax.jacfwd(f1)

print(reap(f1, tag="test")(1.)) # Prints: {'jvp': {'x3': DeviceArray([3.], dtype=float32)}, 'x3': DeviceArray(1., dtype=float32)}

print(reap(f2, tag="test")(1.)) # ValueError: Variable has already been reaped: jvp

The last error looks odd to me. My best guess right now is that the jacfwd calls and the reap call are somehow interleaved, causing a plant-after-reap error?

thisiscam commented 3 years ago

Some more updates:

It appears that the error "ValueError: Variable has already been reaped: jvp" occurs because nest_p also has a mode parameter, but it is always set to strict. Why is that the case? I had envisioned scopes as "namespaces" (as in C++): only the values inside them can conflict, not the namespaces themselves.
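To make the namespace idea concrete, here is a minimal pure-Python sketch of how reaped values could be stored if scopes merged like C++ namespaces. The record helper, its path argument, and the way it treats modes are hypothetical and only model the intended behaviour; this is not Oryx's nest_p or reap implementation.

```python
def record(store, path, value, mode="strict"):
    """Hypothetical model: store `value` at `path` = (scope, ..., name)."""
    *scopes, name = path
    node = store
    for scope in scopes:
        # Re-entering an existing scope is never an error, like reopening a
        # C++ namespace; the scope's dict is simply reused.
        node = node.setdefault(scope, {})
    if mode == "strict" and name in node:
        # Only leaf names inside the same scope can collide.
        raise ValueError(f"Variable has already been reaped: {name}")
    # Any non-strict mode (e.g. "clobber") overwrites the existing value.
    node[name] = value
    return store


# x**3 and its first two derivatives at x = 1, one nesting level per jvp.
store = {}
record(store, ("x3",), 1.0)
record(store, ("jvp", "x3"), 3.0)
record(store, ("jvp", "jvp", "x3"), 6.0)
print(store)  # {'x3': 1.0, 'jvp': {'x3': 3.0, 'jvp': {'x3': 6.0}}}
```

Under this model the repeated "jvp" scope is merged rather than rejected, and only a second strict write to the same leaf (for example ("jvp", "x3")) raises.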

It also appears that this custom_jvp hack doesn't work very well with nested jvps, as I sometimes get a "custom_jvp can't close over values" error. For this reason, I moved the jvp logic directly into the JVP rule of the sow_p primitive:

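from jax.interpreters import ad
from jax.interpreters.ad import Zero
from oryx.core import nest, sow

def sow_jvp_rule(primals, tangents, *, tag, name, mode, **kwargs):
  # Sow the primal values as usual.
  val_out = sow(primals, tag=tag, name=name, mode=mode)
  if all(type(tangent) is Zero for tangent in tangents):
    # Every tangent is a symbolic zero, so there is nothing to sow for the jvp.
    return val_out, tuple(Zero.from_value(v) for v in val_out)
  else:
    tangents = tuple(map(ad.instantiate_zeros, tangents))
    # Sow the tangents in a nested "jvp" scope, keyed on the primals.
    tangents = nest(sow, scope="jvp", mode="clobber")(tangents,
                                                      tag=tag,
                                                      name=name,
                                                      key=primals,
                                                      mode=mode)

    return val_out, tangents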

This appears to work. I'll gather my changes into a draft PR later.

Lastly, I'd like to share a semantics for sow/reap that I've been imagining:

One fundamental question about reap is how it should compose with other transformations such as jvp, vjp, vmap, and jit. If reap is applied before any other transformation, the answer is fairly clear. If reap is applied after some (in general, non-semantics-preserving) transformation T, such as jvp, I think a clean semantics is to return the sown values vs along with the transformed sown values T(vs) in a nested scope (perhaps with a mangled name that relates it to T, though that decision matters less). This way, users would have a predictable way to reason about the reaped values.
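A minimal sketch of this semantics for T = jvp, using plain JAX. Oryx's internals are not needed to illustrate the idea, so sow and reap here are a hypothetical toy that records sown values on a Python-side stack while the function runs or is traced (tags omitted), and reap_after_jvp is a hypothetical helper; none of these are Oryx APIs.

```python
import jax

# Hypothetical toy harvest (tags omitted): sown values are recorded on a
# Python-side stack while the function runs or is traced. Not Oryx's code.
_REAP_STACK = []


def sow(value, *, name):
    if _REAP_STACK:
        _REAP_STACK[-1][name] = value
    return value


def reap(fn):
    def wrapped(*args):
        _REAP_STACK.append({})
        try:
            out = fn(*args)
        finally:
            reaped = _REAP_STACK.pop()
        return out, reaped
    return wrapped


def f(x):
    return sow(x ** 3, name="x3") + 1


def reap_after_jvp(fn, primals, tangents):
    # Return the sown values vs, plus T(vs) for T = jvp nested under "jvp".
    (out, vs), (out_dot, vs_dot) = jax.jvp(reap(fn), primals, tangents)
    return out, out_dot, {**vs, "jvp": vs_dot}


out, out_dot, reaped = reap_after_jvp(f, (2.,), (1.,))
# reaped holds x3 = 8.0 (2**3) and jvp/x3 = 12.0 (3 * 2**2).
```

Because the sown value is just an extra output of reap(f), jax.jvp transforms it exactly like the regular output, which is the predictability argued for above.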

A problem with this semantics is redundant computation: if the values at all yield points are always returned for every active transformation, that could produce a huge number of common subexpressions. So instead of returning vs and T(vs) directly, it might be preferable to return a continuation, so that these values can be computed lazily. Imagine this API (I've omitted tag for conciseness):

# Imagined API: DelayedDict and these reap signatures are hypothetical.
def f(x):
  return sow(x ** 3, name="x3") + 1

r1 = reap(f, name="x3")    # DelayedDict(keys={"x3": *})
r1["x3"](2.)               # 8.

r2 = reap(grad(f), name="x3")(2.) # DelayedDict(keys={"x3": *, "grad": {"x3": *} })
f2 = r2["grad", "x3"]             # f2 is equivalent to grad(lambda x: x**3)
f2(2.)                            # 12.

r3 = reap(jit(vmap(grad(f))), name="x3")([2., 2.])
f3 = r3["grad", "vmap", "x3"]            # f3 is equivalent to vmap(grad(lambda x: x**3))
f4 = r3["jit", "grad", "vmap", "x3"]     # f4 is equivalent to jit(vmap(grad(lambda x: x**3)))

Now, I'm not sure this approach is feasible: my guess is that it would need to be part of JAX core, so that every transformation knows about these sow points and can prepare the transformed values accordingly.

WDYT?

sharadmv commented 3 years ago

Wow, thanks for the in-depth investigation!

"And this appears to work. I will later gather my changes and do a draft PR."

Awesome, I'd be happy to review it. I've also investigated this problem but haven't yet arrived at a complete solution. I'm glad you're interested in it too, so I'll share my half-baked brainstorm here to add to the discussion. I believe this approach is similar to the one taken by Dex, so there is some precedent for it.

linearize(reap)

Let's say that instead of the jvp transformation we use linearize. Rather than taking primals and tangents and returning out_primals and out_tangents, linearize takes primals and returns out_primals together with a JVP function that maps input tangents to output tangents. Its signature is (a -> b) -> a -> (b, a --o b), where --o denotes a linear function.
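For reference, this signature can be checked with plain JAX: jax.linearize(fun, *primals) returns the primal output and a function that applies the JVP to new tangents. The sketch below (g is an arbitrary example function) shows that this function agrees with jax.jvp and is linear in the tangent.

```python
import jax
import jax.numpy as jnp


def g(x):
    # Any differentiable scalar function works here.
    return jnp.sin(x) * x


x = 2.0
# linearize returns the primal output and a linear function t -> J(x) t.
y, g_jvp = jax.linearize(g, x)

t = 0.5
tangent_from_linearize = g_jvp(t)
# jax.jvp computes the same tangent, but evaluates primal and tangent together.
_, tangent_from_jvp = jax.jvp(g, (x,), (t,))
```

The practical difference is that g_jvp can be called repeatedly with new tangents without re-running the primal computation.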

Let's focus on reap (dropping tags for readability) and say we'd like the following invariant to hold:

reap(linearize(f)) == linearize(reap(f))

Consider the example function:

def f(x):
  sow(x ** 2., name='y')
  return x + 1.

We have:

reap(f)(x):
  return x + 1., dict(y=x ** 2)

linearize(reap(f))(x):
  return (x + 1., dict(y=x ** 2)), lambda t: (t, dict(y=2 * x * t))

In order to make reap and linearize commute, we'd need:

linearize(f)(x):
  def jvp(t):
    sow(2 * x * t, name='y')
    return t
  return x + 1, jvp

We can't directly reap(linearize(f)) because it returns a function. But if we reap the returned jvp function, we get the same jvp function as returned by linearize(reap(f)).

In this case, the JVP rule for sow simply sows the tangents under the same name. Note that we don't get the name collision we get with jvp, because linearize computes tangents in a separate function, which is not run during the forward pass that sows the primal 'y'.
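The invariant can be checked numerically with a toy harvest. The sketch below assumes a hypothetical sow/reap that records sown values on a Python-side stack during tracing (it is not Oryx's implementation, and tags are dropped). It compares the JVP function that jax.linearize builds for reap(f) with reap applied to a hand-written linearize(f) whose JVP function sows the tangent of 'y' under the same name, as the proposed JVP rule would.

```python
import jax

# Hypothetical toy harvest, NOT Oryx's implementation: sown values are
# recorded on a Python-side stack while the function runs or is traced.
_REAP_STACK = []


def sow(value, *, name):
    if _REAP_STACK:
        _REAP_STACK[-1][name] = value
    return value


def reap(fn):
    def wrapped(*args):
        _REAP_STACK.append({})
        try:
            out = fn(*args)
        finally:
            reaped = _REAP_STACK.pop()
        return out, reaped
    return wrapped


def f(x):
    sow(x ** 2., name='y')
    return x + 1.


# linearize(reap(f)): JAX linearizes a function that returns (out, dict).
x = 3.0
(out, reaped), jvp_of_reaped = jax.linearize(reap(f), x)


# Hand-written linearize(f) under the proposed JVP rule for sow: the JVP
# function itself sows the tangent of y under the same name.
def linearize_f(x):
    def jvp(t):
        sow(2 * x * t, name='y')
        return t
    return x + 1., jvp


_, jvp_f = linearize_f(x)
t = 1.0
lhs = reap(jvp_f)(t)    # reap applied to the returned JVP function
rhs = jvp_of_reaped(t)  # JVP function of linearize(reap(f))
```

Both sides give the output tangent t and a sown tangent for 'y' of 2 * x * t (6.0 here).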

It also turns out that linearize(plant) == plant(linearize) as well with this JVP rule.

vjp(reap)

Let's run through the same exercise with vjp.

reap(vjp(f)) == vjp(reap(f))

Consider the example function:

def f(x):
  sow(x ** 2., name='y')
  return x + 1.

We have:

reap(f)(x):
  return x + 1., dict(y=x ** 2)

vjp(reap(f))(x):
  def vjp_fn(ct_x, ct_y):
    return ct_x + 2 * ct_y * x
  return (x + 1., dict(y=x ** 2)), vjp_fn
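The vjp(reap(f)) pseudocode above matches what jax.vjp produces for a function that returns its sown values as an extra output: the pullback expects a cotangent for 'y' as well. This sketch builds reap(f) with a hypothetical toy sow/reap (a Python-side stack, not Oryx's implementation).

```python
import jax

# Hypothetical toy harvest, NOT Oryx's implementation.
_REAP_STACK = []


def sow(value, *, name):
    if _REAP_STACK:
        _REAP_STACK[-1][name] = value
    return value


def reap(fn):
    def wrapped(*args):
        _REAP_STACK.append({})
        try:
            out = fn(*args)
        finally:
            reaped = _REAP_STACK.pop()
        return out, reaped
    return wrapped


def f(x):
    sow(x ** 2., name='y')
    return x + 1.


x = 2.0
(out, reaped), vjp_fn = jax.vjp(reap(f), x)
# The pullback takes one cotangent per output, including the sown 'y'.
ct_x, ct_y = 1.0, 1.0
(x_grad,) = vjp_fn((ct_x, {'y': ct_y}))
# x_grad == ct_x + 2 * ct_y * x == 5.0
```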

Now here's the tricky part. In order to make reap and vjp commute, we need to add an argument to vjp_fn's signature corresponding to the cotangent for 'y'. We can't change the signature of this function as part of a transpose rule, though. It turns out that we need to use sow again, but this time it must accept a planted value.

vjp(f)(x):
  def vjp_fn(ct_x):
    ct_y = sow(<dummy value>, name='y')
    return ct_x + 2 * ct_y * x
  return x + 1, vjp_fn

Conceptually, this VJP function is meaningless unless we inject a value for ct_y, so whenever we use it we need to plant a value for ct_y.

out_primals, f_vjp = jax.vjp(f, 2.)
x_grad = plant(f_vjp)(dict(y=1.), 1.)

If you run through this exercise with plant(vjp), you should also find that you need a "write" sow in the backward pass.
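To make the "read" sow concrete, here is a hypothetical toy plant in pure Python, with JAX used only for a reference gradient. sow returns a planted value when one exists for its name; make_f_vjp is a hand-written stand-in for vjp(f) whose backward pass reads ct_y through sow, with 0.0 as the dummy value. None of this is Oryx's implementation; the call plant(f_vjp)(dict(y=1.), 1.) only mirrors the usage shown above.

```python
import jax

# Hypothetical toy plant, NOT Oryx's implementation: a stack of dicts
# holding the values to inject for each sown name.
_PLANT_STACK = []


def sow(value, *, name):
    # A "read" sow: return the planted value for `name` if there is one,
    # otherwise fall back to the given (dummy) value.
    if _PLANT_STACK and name in _PLANT_STACK[-1]:
        return _PLANT_STACK[-1][name]
    return value


def plant(fn):
    def wrapped(plants, *args):
        _PLANT_STACK.append(dict(plants))
        try:
            return fn(*args)
        finally:
            _PLANT_STACK.pop()
    return wrapped


def make_f_vjp(x):
    # Hand-written stand-in for vjp(f), where f(x) = x + 1. sows y = x ** 2.
    def vjp_fn(ct_x):
        ct_y = sow(0.0, name='y')  # 0.0 plays the role of <dummy value>
        return ct_x + 2 * ct_y * x
    return x + 1., vjp_fn


out_primal, f_vjp = make_f_vjp(2.)
x_grad = plant(f_vjp)(dict(y=1.), 1.)


# Reference: JAX's VJP of a function that returns the sown value explicitly.
def f_with_y(x):
    return x + 1., {'y': x ** 2}


_, ref_vjp = jax.vjp(f_with_y, 2.)
(ref_grad,) = ref_vjp((1., {'y': 1.}))
```

Without a planted value, f_vjp(1.) silently uses the dummy and returns 1.0, which is why the VJP function is only meaningful together with plant.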

Conclusion

My conclusion after this brainstorm is that:

  1. JVP rules are fairly straightforward if we think about jax.linearize rather than jax.jvp.
  2. VJP rules only make sense if you think of each sow as either a "write" (aka reap) or a "read" (aka plant). The transpose of a "read" sow is a "write" sow, and vice versa. The implication is that reap cannot exist without plant, and vice versa.

thisiscam commented 3 years ago

@sharadmv Thanks for sharing!

"And this appears to work."

It only appears to work: if I use jacfwd (instead of jvp), it returns the transposed jacfwd result for the sown value, so the solution isn't very robust. This is because jacfwd needs to permute the output axes (https://github.com/google/jax/blob/master/jax/_src/api.py#L977), but harvest currently does not respect vmap's out_axes parameter.
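A small sketch of why out_axes matters here: a forward-mode Jacobian can be built by pushing each standard-basis tangent through jax.jvp under vmap, with the batch (input) axis moved to the end of the output. This only illustrates the idea with public JAX APIs and is not jacfwd's actual implementation; g is an arbitrary example with a non-square Jacobian so that the transpose is visible.

```python
import jax
import jax.numpy as jnp


def g(x):
    # R^2 -> R^3, so the Jacobian has shape (3, 2).
    return jnp.stack([x[0] * x[1], x[1] ** 2, jnp.sin(x[0])])


x = jnp.array([1.0, 2.0])
jac = jax.jacfwd(g)(x)  # outputs along axis 0, inputs along the last axis


def pushforward(t):
    # Directional derivative of g at x along tangent t.
    return jax.jvp(g, (x,), (t,))[1]


basis = jnp.eye(2)  # one tangent direction per input coordinate
jac_out_last = jax.vmap(pushforward, in_axes=0, out_axes=-1)(basis)  # (3, 2)
jac_out_first = jax.vmap(pushforward, in_axes=0, out_axes=0)(basis)  # (2, 3)
```

If a sown value ignored out_axes=-1, it would come back laid out like jac_out_first, i.e. transposed relative to the Jacobian.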

More thoughts:

For any of the jac* operations, it should be possible to simply decompose the function and apply the chain rule. But this would require a custom jac* function:

For example, for a function f(x, y) = g(h(x), y), we can write a custom grad function (with respect to x) that computes f'(x, y) = g'(h(x), y) * h'(x), where g' is the derivative of g with respect to its first argument. In that implementation, one can re-sow h and h' (with h' using a different scope/name).

This should work for both the forward and the backward pass. Essentially, one could build a set of jac* routines that preserve the sow tags.
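A minimal sketch of such a custom gradient in JAX, assuming scalar inputs and example choices of h and g. grad_f_with_intermediates is a hypothetical helper, and the returned intermediates dict only stands in for re-sowing h(x) and h'(x); it does not call Oryx.

```python
import jax
import jax.numpy as jnp


def h(x):
    return x ** 3


def g(u, y):
    return jnp.sin(u) * y


def f(x, y):
    return g(h(x), y)


def grad_f_with_intermediates(x, y):
    # Forward through h, keeping both h(x) and h'(x).
    u, h_prime = jax.value_and_grad(h)(x)
    # Derivative of g with respect to its first argument, evaluated at h(x).
    g_prime = jax.grad(g, argnums=0)(u, y)
    # Hypothetical stand-in for re-sowing h and h' under different names.
    intermediates = {'h': u, 'h_grad': h_prime}
    return g_prime * h_prime, intermediates


x, y = 0.7, 2.0
df_dx, intermediates = grad_f_with_intermediates(x, y)
```

The chain-rule product matches jax.grad(f)(x, y), while h(x) and h'(x) remain available by name.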

In some sense, for jacrev, my intuition is that this is equivalent to materializing a function that returns the sown value.

thisiscam commented 3 years ago

"This is due to jacfwd needs to permute the output axes https://github.com/google/jax/blob/master/jax/_src/api.py#L977, but currently harvest does not respect the out_axes parameter of vmap."

Here is a simplified example where harvest gave me a "surprising" result:

import jax
import jax.numpy as jnp

import oryx

sow = oryx.core.sow
reap = oryx.core.reap

def f(x):
  return sow(x**3, tag="test", name="test") + 1

M, N = 3, 7

x = jnp.arange(M * N).reshape(M, N)

v1 = jax.vmap(f, in_axes=0, out_axes=-1)(x)
v2 = reap(jax.vmap(f, in_axes=0, out_axes=-1), tag="test")(x)["test"]

print(v1.shape)  # (7, 3)
print(v2.shape)  # (3, 7)

print(jnp.all(v1 == v2.T + 1))  # Currently this prints `True`

print(jnp.all(v1 == v2 + 1))  # I would argue that this is more natural
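For comparison, if the sown value is returned as an ordinary output (f_with_sown is a hypothetical rewrite that bypasses sow/reap), vmap applies out_axes to it as well, and the "more natural" equality holds. This is only a sketch of the expected semantics, not Oryx's current behaviour.

```python
import jax
import jax.numpy as jnp

M, N = 3, 7
x = jnp.arange(M * N).reshape(M, N)


def f_with_sown(x):
    # Hypothetical: expose the would-be sown value as an explicit extra output.
    sown = x ** 3
    return sown + 1, sown


# out_axes=-1 applies to every output, including the sown value.
v1, sown = jax.vmap(f_with_sown, in_axes=0, out_axes=-1)(x)
```

Here both v1 and sown have shape (7, 3), so v1 == sown + 1 holds elementwise without a transpose.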