jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

inside pmap, grad(lambda x: psum(loss(x))) inconsistent with jvp(lambda x: psum(loss(x))) #3970

Open sharadmv opened 4 years ago

sharadmv commented 4 years ago

Running this on a machine with >=2 devices:

import jax.numpy as jnp
from jax import vmap, pmap, grad, lax

def loss(w, d):
  return w * d

def local(w, data):
  def local_loss(w):
    return vmap(lambda d: loss(w, d))(data).sum()
  return local_loss(w), grad(local_loss)(w)
print(local(1., jnp.arange(2.))) # ==> (1., 1.)

def distributed(w, data):
  def agg_loss(w):
    return lax.psum(loss(w, data), 'batch')
  return agg_loss(w), grad(agg_loss)(w)
print(pmap(distributed, in_axes=(None, 0), axis_name='batch')(1., jnp.arange(2.))) # ==> ([1., 1.], [0., 2.])

The losses are correct, but the gradients are incorrect and differ across shards.

What should the intended behavior be here?

mattjj commented 4 years ago

This is confusing! It seems to be an issue about how we define transposes of SPMD functions (i.e. it's not just a dumb software bug). No conclusions yet, but wanted to dump some thoughts here.

Here are some test programs:

import jax.numpy as jnp
from jax import vmap, pmap, grad, lax, jvp

def distributed(w, data):
  def agg_loss(w):
    return lax.psum(w * data, 'batch')
  return jvp(agg_loss, (w,), (1.,))[1]
print(pmap(distributed, axis_name='batch')(jnp.ones(2), jnp.arange(2)))
# prints [1. 1.]

def distributed(w, data):
  def agg_loss(w):
    return lax.psum(w * data, 'batch')
  return (agg_loss(w+1e-3) - agg_loss(w)) / 1e-3
print(pmap(distributed, axis_name='batch')(jnp.ones(2), jnp.arange(2)))
# prints [1.0000466 1.0000466]

def distributed(w, data):
  def agg_loss(w):
    return lax.psum(w * data, 'batch')
  return grad(agg_loss)(w)
print(pmap(distributed, axis_name='batch')(jnp.ones(2), jnp.arange(2)))
# prints [0. 2.]

The jaxpr being transposed in the last one is

{ lambda c ; b.
  let d = mul b c
      e = psum[ axis_index_groups=None
                axis_name=batch ] d
  in (e,) }

Operationally, the transpose evaluation proceeds like this: grad feeds in a cotangent for e of 1.0. Then we compute a cotangent for d of 2.0, reflecting the fact that if we perturb the value of d by epsilon (on each replica) then the value of e changes by 2 epsilon. Then we compute a cotangent for b by multiplying by 0. on the first replica and 1. on the second replica, leading to the final result [0. 2.].
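
To make that bookkeeping concrete, here is the same per-replica cotangent arithmetic written out by hand in plain Python (just a sketch with made-up variable names; the values match the [0. 2.] result above):

e_bar = 1.0                          # cotangent for e supplied by grad on each replica
d_bar = 2.0                          # transpose of psum at HEAD: psum of e_bar over the 2 replicas
b_bar = [0.0 * d_bar, 1.0 * d_bar]   # multiply by the mapped value of data on each replica
print(b_bar)                         # [0.0, 2.0]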

Beyond the numbers, we have a symmetry issue. The first two versions produce a symmetric-across-replicas result, which makes sense because agg_loss's last operation is a psum and so its output must be symmetric across replicas. But the reverse-mode version can't produce a symmetric result because it multiplies by the mapped value data at the end. (The symmetric result makes the most sense because agg_loss always produces a symmetric result given a symmetric input.)

Hmmmm...

jekbradbury commented 4 years ago

IMO, we need to keep track of replication to do a good job here. I see two reasonable behaviors:

Behavior A: grad of a replicated thing throws an error, and grad of a non-replicated thing computes a separate grad for each thread.

Behavior B: grad of a non-replicated thing throws an error (the idea being that grad requires a scalar function, and an SPMD scalar function isn't a scalar function), and grad of a replicated thing computes a single (replicated) summed grad.

Or I guess both could work?

("replicated" means guaranteed-symmetric, out_axes=None is allowed)

mattjj commented 4 years ago

Can you define 'replicated thing' and 'non-replicated thing'? Do you mean a function (agg_loss in this case), and if so, what does it mean for it to be replicated (maybe: it has collectives in it)?

jekbradbury commented 4 years ago

By "replicated thing" I mean a guaranteed-symmetric value, i.e. one where out_axes=None would be allowed if it were a return value of the map. In terms of functions, I mean those whose return values are replicated.

mattjj commented 4 years ago

Thanks! One possible desideratum is not to do something different on guaranteed-symmetric values versus just-happens-to-be-symmetric values. Notice that in this example, if data just happens to be symmetric, we get the right answer. Also, if we noticed (say in backward_pass) that the function to be transposed is symmetric (in the sense that when it's given a symmetric-across-replicas input it produces a symmetric output), and in that case always symmetrized the cotangents of the jaxpr inputs, then we'd compute the right answer in both the guaranteed-symmetric (new!) and just-happens-to-be-symmetric (already got that one) cases.

EDIT: I think that, minus the "throws an error" parts, this may be what you meant by "both". Is that right?

mattjj commented 4 years ago

I'm starting to agree with James that we might need to track SPMD symmetry in jaxprs, and reverse mode should do something interesting with that information. It's tricky: if we have a jaxpr to transpose with both symmetric and non-symmetric outputs, we may need to run two backward passes, one for the symmetric outputs (where we symmetrize the resulting input cotangents) and one for the non-symmetric outputs (where we don't).

john-m-jumper commented 4 years ago

I have thought about this issue a bit in another context, and I think you need the notion/type of an SPMD-global scalar to make grad make sense with an internal pmap. The API difficulty with pmap is that it has no way to return something of rank < 1, so semantically the program never has a function on which grad should be defined, since the output is never a scalar. Having a ReplicatedArray be a tensor in some contexts and a scalar in others seems like it could lead to semantic problems.

jekbradbury commented 4 years ago

I think we shouldn't necessarily expect invariants that are true in ordinary code to also hold for SPMD code; we often assume that SPMD code will behave like a "batch of (independent) programs", but the presence of collectives means that can be a bad analogy! Instead, I like to think of SPMD programming as a different perspective on bulk array programming (the named axis is a real axis that just doesn't have a position, and collectives are just bulk primitives, like reductions, that operate on that axis).

That means that a scalar -> scalar SPMD program is not actually a scalar -> scalar program, nor is it a batch of scalar -> scalar programs. It's a vector -> vector program (one that acts elementwise/has a diagonal Jacobian if there aren't any collectives, and doesn't if there are). This is a particularly important distinction for autodiff, since a single application of forward-mode autodiff or finite differences can only give the derivative of a scalar -> anything function, while a single application of reverse-mode autodiff can only give the derivative of an anything -> scalar function (with an exception in both cases if the Jacobian is known to be diagonal).

In bulk array terms, Matt's three examples are:

import jax.numpy as jnp
from jax import vmap, pmap, grad, lax, jvp

def f_jvp(w, data):
  def agg_loss(w):
    return jnp.sum(w * data, 0)
  return jvp(agg_loss, (w,), (jnp.ones_like(w),))[1]
print(f_jvp(jnp.ones(2), jnp.arange(2)))
# prints 1.0

def f_fd(w, data):
  def agg_loss(w):
    return jnp.sum(w * data, 0)
  return (agg_loss(w+1e-3) - agg_loss(w)) / 1e-3
print(f_fd(jnp.ones(2), jnp.arange(2)))
# prints 1.0000467

def f_grad(w, data):
  def agg_loss(w):
    return jnp.sum(w * data, 0)
  return grad(agg_loss)(w)
print(f_grad(jnp.ones(2), jnp.arange(2)))
# prints [0. 1.]

These results are a little less surprising! Here agg_loss is an ordinary vector -> scalar function and f_grad is computing the gradient, while f_jvp and f_fd are both computing something else (the product of the Jacobian with a vector of ones).

In this bulk array context, another way to compute that same Jacobian-vector product is to take the grad of a scalar -> scalar function that broadcasts its input (expressed with the same agg_loss Python code through NumPy rank polymorphism/implicit broadcasting):

def f_grad_bcast(w, data):
  def agg_loss(w):
    return jnp.sum(w * data, 0)
  return grad(agg_loss)(w)
print(f_grad_bcast(1., jnp.arange(2)))
# prints 1.0

Here the forward pass performs a broadcast, multiply, then sum, so the transpose performs a broadcast, multiply, then sum too (broadcast is fan-out and its transpose is fan-in-sum). How to make sure the transpose contains a sum when the forward pass actually needed to broadcast, but doesn't when it didn't, is the implicit broadcasting problem in autodiff systems for dynamically-ranked languages like TF graphs and TorchScript.
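
(As an aside, the fan-in-sum already shows up when the broadcast is written explicitly in plain JAX; a small illustration, not from the original comment:)

import jax.numpy as jnp
from jax import vjp

data = jnp.arange(2.)
y, f_vjp = vjp(lambda w: jnp.broadcast_to(w, (2,)) * data, 1.0)
print(f_vjp(jnp.ones(2)))  # cotangent for w is 1.0: the transpose of the broadcast sums [0., 1.]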

My claim is that all of these concerns apply in the SPMD case, too. An SPMD program cannot run except with all of its named axes bound by corresponding maps, and the resulting program is a bulk array program with well-defined semantics. We need to treat those semantics as the source of truth for thinking about SPMD programs, rather than intuitions about batches of independent programs, because those intuitions break down when collectives are involved.

In bulk array programs, a particular logical axis can either be present or absent in each of the arguments and return values. In the context of bulk array programs where that axis was created by mapping an SPMD axis using JAX's map APIs, this corresponds to whether the in_axes/out_axes setting for each argument/return is an integer or None. Here's what I think that means for the SPMD versions of our examples:

import jax.numpy as jnp
from jax import vmap, pmap, grad, lax, jvp

def f_jvp(w, data):
  def agg_loss(w):
    return lax.psum(w * data, 'batch')
  return jvp(agg_loss, (w,), (1.,))[1]
print(pmap(f_jvp, axis_name='batch')(jnp.ones(2), jnp.arange(2)))
# prints [1. 1.]

def f_fd(w, data):
  def agg_loss(w):
    return lax.psum(w * data, 'batch')
  return (agg_loss(w+1e-3) - agg_loss(w)) / 1e-3
print(pmap(f_fd, axis_name='batch')(jnp.ones(2), jnp.arange(2)))
# prints [1.0000466 1.0000466]

def f_grad(w, data):
  def agg_loss(w):
    return lax.psum(w * data, 'batch')
  return grad(agg_loss)(w)
print(pmap(f_grad, axis_name='batch')(jnp.ones(2), jnp.arange(2)))
# should print [0. 1.]

def f_grad_bcast(w, data):
  def agg_loss(w):
    return lax.psum(w * data, 'batch')
  return grad(agg_loss)(w)
print(pmap(f_grad_bcast, axis_name='batch', in_axes=(None, 0))(1., jnp.arange(2)))
# should print [1. 1.], or just 1.0 with out_axes=None

(That last one is Sharad's code.)

We don't currently produce the desired output in the last case, because we include a spurious psum at the beginning of the transpose (from [1. 1.] to [2. 2.]) and don't include a necessary one at the end. There might be multiple ways to fix the system and get the right answer, but my understanding is that the best way forward looks something like this:

First, we should distinguish between "scalars" and "batched scalars" with respect to a particular SPMD axis, and introduce a pbroadcast to take a scalar to a batched scalar with the same value (while psum should take a batched scalar and produce a scalar). Then, the transpose of pbroadcast should be psum and vice versa, and the cotangent of a scalar should always be a scalar and the cotangent of a batched scalar should always be a batched scalar. This means that (just like in the rank-polymorphic bulk case!) we should insert a pbroadcast when a scalar needs to be promoted to a batched scalar (as in w * data) and the backward pass should have a psum there.
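
Here is a rough sketch of those proposed transpose rules, emulated with jax.custom_vjp; pbroadcast and psum_scalar below are hypothetical stand-ins defined only for this illustration, not existing JAX APIs:

from functools import partial
import jax
import jax.numpy as jnp
from jax import grad, lax, pmap

# Hypothetical pbroadcast: identity on the forward pass (the value is already
# replicated on every device); its backward pass psums the cotangent over the
# named axis, i.e. pbroadcast transposes to psum.
@partial(jax.custom_vjp, nondiff_argnums=(1,))
def pbroadcast(x, axis_name):
  return x

def pbroadcast_fwd(x, axis_name):
  return x, None

def pbroadcast_bwd(axis_name, _, ct):
  return (lax.psum(ct, axis_name),)

pbroadcast.defvjp(pbroadcast_fwd, pbroadcast_bwd)

# Hypothetical reduce-sum-only psum_scalar: the forward pass is an ordinary
# psum, but the cotangent passes through unchanged, i.e. this psum transposes
# to pbroadcast.
@partial(jax.custom_vjp, nondiff_argnums=(1,))
def psum_scalar(x, axis_name):
  return lax.psum(x, axis_name)

def psum_scalar_fwd(x, axis_name):
  return lax.psum(x, axis_name), None

def psum_scalar_bwd(axis_name, _, ct):
  return (ct,)

psum_scalar.defvjp(psum_scalar_fwd, psum_scalar_bwd)

def f_grad_bcast(w, data):
  def agg_loss(w):
    return psum_scalar(pbroadcast(w, 'batch') * data, 'batch')  # promote w before mixing with mapped data
  return grad(agg_loss)(w)
print(pmap(f_grad_bcast, axis_name='batch', in_axes=(None, 0))(1., jnp.arange(2.)))
# prints [1. 1.] on a 2-device machine, matching the desired output above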

Ideally we would also match the bulk-array error semantics, and constrain the return value of the function passed to grad to be a scalar, not a batched scalar; we should also support out_axes=None in maps so that we can return these scalars out to the bulk world and make the semantics less obscure.

mattjj commented 4 years ago

@jekbradbury that's brilliant! Let's do it.

mattjj commented 4 years ago

I changed the issue title because I think the current (i.e. at HEAD) semantics make sense, but can make grad(f)(x) and jvp(f, (x,), (1.,)) differ for scalar-input scalar-output f when inside a pmap (as in my comment above), which seems worth revising. The current semantics just take a different correspondence to bulk array programs than the one in @jekbradbury's previous comment.

The original intention with psum was to correspond to a bulk array operation that included broadcasting, not just a reduce-sum as in James's comment above:

pmap(lambda x: psum(x, 'i'), axis_name='i')
==
lambda x: lax.broadcast(x.sum(0), (num_devices,))
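
(A quick sanity check of that correspondence on a 2-device machine; jax.device_count() stands in for num_devices:)

import jax
import jax.numpy as jnp
from jax import lax, pmap

n = jax.device_count()  # 2 in these examples
x = jnp.arange(float(n))
print(pmap(lambda x: lax.psum(x, 'i'), axis_name='i')(x))  # [1. 1.]
print(lax.broadcast(x.sum(0), (n,)))                       # [1. 1.]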

So given this pmap code:

def f(w, data):
  def agg_loss(w):
    return lax.psum(w * data, 'batch')
  return grad(agg_loss)(w)
print(pmap(f, axis_name='batch')(jnp.ones(2), jnp.arange(2)))

we should expect it to behave like this bulk array version:

import jax.numpy as jnp
from jax import grad
from jax import lax

def f(w, data):
  def agg_loss(w):
    return lax.broadcast((w * data).sum(), (2,))
  return grad(agg_loss)(w)
print(f(jnp.ones(2), jnp.arange(2)))

But that code triggers an error, because in api.py we check to ensure grad is only being applied to scalar-output functions. Just for defining semantics, let's sidestep that error check:

import jax.numpy as jnp
from jax import vjp
from jax import lax

def grad(f):
  def gradfun(x):
    ans, f_vjp = vjp(f, x)
    x_bar, = f_vjp(jnp.ones_like(ans))
    return x_bar
  return gradfun

def f(w, data):
  def agg_loss(w):
    return lax.broadcast((w * data).sum(), (2,))  # notice the broadcast!
  return grad(agg_loss)(w)
print(f(jnp.ones(2), jnp.arange(2)))
# prints [0. 2.]

These SPMD semantics explain the reverse-mode answer in my first comment above.

The trouble with this definition, though, is that it can make grad disagree with jvp and numerical differences on what appears locally to be a scalar-input scalar-output function when inside a pmap, as in the examples in my first comment. I want to write out an explanation for what's going on, but I've got to step away for a moment and wanted to send this comment first. To be continued!

In any case, I think @jekbradbury 's proposal for revising the semantics is likely to be better. I just want to pin down both the old and new semantics as best we can.

mattjj commented 4 years ago

Okay, back to it!

Notice the Jacobian of the agg_loss function written with the broadcast is [[0, 1], [0, 1]]:

import jax.numpy as jnp
from jax import lax
from jax import jacfwd

def f(w, data):
  def agg_loss(w):
    return lax.broadcast((w * data).sum(0), (2,))
  return jacfwd(agg_loss)(w)
print(f(jnp.ones(2), jnp.arange(2)))
# [[0. 1.]
#  [0. 1.]]

So, while keeping the current (at HEAD) bulk array definition of psum as a reduce-sum followed by a broadcast, the SPMD AD semantics is consistent so long as we take grad to mean "compute the VJP against a ones vector broadcast along all named axes":

import jax.numpy as jnp
from jax import vjp, jvp, pmap, grad
from jax import lax

### reverse-mode

# At HEAD, we define this SPMD program:
def f(w, data):
  def agg_loss(w):
    return lax.psum(w * data, 'batch')
  return grad(agg_loss)(w)
print(pmap(f, axis_name='batch')(jnp.ones(2), jnp.arange(2)))
# [0. 2.]

# To mean the same as this bulk array vjp-with-ones program (`grad` is always
# defined as vjp-with-ones plus an error check for scalar outputs that we don't
# include in the definition of SPMD semantics):
def grad2(f):
  def gradfun(x):
    ans, f_vjp = vjp(f, x)
    x_bar, = f_vjp(jnp.ones_like(ans))
    return x_bar
  return gradfun

def f(w, data):
  def agg_loss(w):
    return lax.broadcast((w * data).sum(), (2,))  # bulk array version of psum
  return grad2(agg_loss)(w)
print(f(jnp.ones(2), jnp.arange(2)))
# [0. 2.]

### forward-mode

# At HEAD, we define this SPMD program:
def f(w, data):
  def agg_loss(w):
    return lax.psum(w * data, 'batch')
  return jvp(agg_loss, (w,), (1.,))[1]
print(pmap(f, axis_name='batch')(jnp.ones(2), jnp.arange(2)))
# [1. 1.]

# To mean the same as this bulk array jvp-with-ones program:
def f(w, data):
  def agg_loss(w):
    return lax.broadcast((w * data).sum(), (2,))  # bulk array version of psum
  return jvp(agg_loss, (w,), (jnp.ones_like(w),))[1]
print(f(jnp.ones(2), jnp.arange(2)))
# [1. 1.]

(In Autograd, like TF today, we used to define grad as vjp-with-ones, but exactly this sort of confusion is why we made it raise an error for non-scalar outputs. Yet we didn't make that check work with SPMD functions, in the sense that grad will in effect happily allow broadcasting along named mapped axes!)

If the semantics at HEAD are self-consistent, modulo the error semantics for grad, do we need to change anything at all, other than perhaps making the grad error semantics consistent between the positional and SPMD worlds to avoid this potential confusion?

Maybe yes. One problem with the current semantics is that if we made @sharadmv's use of grad here an error (rather than a vjp-with-ones), not only would that have been surprising to him, it would also break pretty much all existing SPMD neural net training; users would have to write vjp-with-ones themselves, e.g. by defining grad2 as above. Even then, the answers can be surprising: within the SPMD function it looks like we're calling grad/grad2 on a scalar-input scalar-output function (but for the closed-over value of data, which is different on each device) with the same primal input value on every device, yet we get different grad2 results on different devices (perhaps not noticing that the primal output value also differs across devices, which might make the differing gradients less surprising). That is, in the expression grad(agg_loss)(w) in @sharadmv's original code, agg_loss is a different function on each device because it closes over the mapped data, which is why we should expect to get different answers.
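
(To see the closed-over-data point in isolation, drop the psum entirely; the per-device functions already differ, so their grads differ even with identical w. A minimal illustration, not from the original comment:)

import jax.numpy as jnp
from jax import grad, pmap

def f(w, data):
  return grad(lambda w: w * data)(w)  # agg_loss-style closure over the mapped data
print(pmap(f)(jnp.ones(2), jnp.arange(2.)))
# prints [0. 1.]: each device differentiates a different closed-over function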

In any case, while I now think the semantics at HEAD are actually consistent (modulo error semantics for grad) and defensible, this example has clearly shown that they can be confusing, especially when differentiating functions that close over mapped values. The main difference with @jekbradbury 's proposed semantics (as he's pointed out too) is whether psum corresponds in bulk array land to a reduce-sum, or a reduce-sum-then-broadcast.

jekbradbury commented 4 years ago

As Matt points out, the semantics of SPMD programs at head are more consistent than I thought—they just correspond to different (and in my view less useful) bulk array semantics. Adopting my preferred approach would be breaking, and would need to be sequenced carefully with other enhancements we're planning to make.

In particular: at head, values in SPMD code always "contain" the pmapped axis (by which I mean that their bulk array counterparts always contain the corresponding logical axis). This has a few consequences:

  1. pmap has to insert pbroadcast of in_axis=None or closed-over values at the beginning of the mapped region, rather than later on/when needed, because a value inside a pmap that hasn't been pbroadcasted isn't representable
  2. psum has to have reduce-sum-broadcast semantics, because the output of reduce-sum without a broadcast isn't representable
  3. grad inside a pmap has to correspond to bulk vjp-with-ones because a bulk scalar isn't representable

These are essentially the things that surprised Sharad. (1) meant that the output of the gradient function wasn't psummed, because the input of the forward function had already been broadcasted; (2) meant that the values flowing back through the VJP were 2 rather than 1, because those values had been psummed; and (3) meant that the output values computed were no longer gradients, since the bulk version of the forward function had a non-scalar output.

Collectively they mean that a pattern that would otherwise be natural and useful, and represents the SPMD counterpart of standard bulk code for NN training—taking the gradient of a psummed loss with respect to replicated parameters—can't be expressed in SPMD JAX. There are two alternatives: using grad of pmap, which has significant runtime overheads, and moving the psum outside the loss function, which is the most common approach in data-parallel JAX today, but represents an unfortunate loss of expressiveness and a gap between SPMD code and bulk code.
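
(For reference, here is a schematic of that most common approach, with placeholder names: the loss contains no collectives, and the psum is applied to the per-device gradients afterward.)

import jax.numpy as jnp
from jax import grad, lax, pmap

def local_loss(w, data):
  return (w * data).sum()        # no collectives inside the loss

def grad_step(w, data):
  g = grad(local_loss)(w, data)  # per-device gradient of the local loss
  return lax.psum(g, 'batch')    # aggregate the gradients across devices
print(pmap(grad_step, axis_name='batch', in_axes=(None, 0))(1., jnp.arange(2.)))
# prints [1. 1.]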

These things and others would become easier, and the system more aligned with Sharad's (and my) mental model, if SPMD JAX were extended with the ability to represent values that don't contain the mapped axis. This is a strict increase in expressiveness, because it increases the set of bulk programs that have SPMD counterparts. It would not require a change in the semantics of existing programs, but it would open up the possibility of certain changes (and wouldn't be very useful without them):

  1. pmap can now wait to insert pbroadcast of in_axis=None or closed-over values until they're mixed with values that contain the mapped axis
  2. psum can correspond to bulk reduce-sum
  3. grad inside a pmap can correspond to bulk grad

With these changes, Sharad's code would work out of the box with his desired semantics, and the set of representable programs becomes more symmetric between the SPMD and bulk worlds.

But of course the "most common approach in data-parallel JAX today" (psum/pmean of the result of grad inside a pmap) assumes per-device grad with vjp-with-ones bulk semantics! So we have to be careful about how we make changes like these.

We could, for instance, introduce kwargs to control these behaviors (e.g., keepdims on psum could default to True at first). We could also sequence these changes to take place alongside new APIs we're adding or considering adding, like gmap and a potential sum-psum unification.

sharadmv commented 4 years ago

But of course the "most common approach in data-parallel JAX today" (psum/pmean of the result of grad inside a pmap) assumes per-device grad with vjp-with-ones bulk semantics! So we have to be careful about how we make changes like these.

How common is grad(psum) in practice? This case pops up when there is a psum in a loss function, but if in practice people are doing psum of gradients after the fact, presumably they are writing "local" loss functions that do not have a psum in them. In that case the vjp-with-ones semantics is the same as several vjp-of-scalar-loss computations running in parallel (since there are no collectives in the loss function). Wouldn't psum of gradients of a loss that has a psum in it produce unintuitive values too?

martin-marek commented 2 years ago

I came across this problem when implementing an HMC sampler, which depends on the input data only through the log-posterior density $\log p(\theta|\mathbf{x})$. The log-posterior of the whole dataset is not equal to the sum of the log-posteriors that we would compute for each batch separately: $\log p(\theta|\mathbf{x}) \ne \sum_i \log p(\theta|x_i)$. So, in this particular use case, taking the grad of psum is useful.

To understand what's going on, I first replicated the toy problem discussed in this thread using a traditional single-device program. I think the results are intuitive:

import jax.numpy as jnp
from jax import lax, vmap, pmap, grad

# data
w = jnp.array(1.)
data = jnp.array([0., 1.])

# loss function
def loss_fn(w, data):
    return (w * data).sum()

# scalar w
print(grad(loss_fn)(w, data))
# expected: 1.
# prints: 1. ✓

# vector w
print(grad(loss_fn)(lax.broadcast(w, (2,)), data))
# expected: [0. 1.]
# prints: [0. 1.] ✓

Next, I modified the above code into an SPMD program, by replacing sum with psum. My intuition was that psum should behave like sum followed by a broadcast along devices. However, this is not the case:

# loss function
def loss_fn(w, data):
    return lax.psum(w * data, 'i')

# scalar w
print(pmap(grad(loss_fn), 'i', in_axes=(None, 0))(w, data))
# expected: [1. 1.]
# prints: [0. 2.] ✗

# vector w
print(pmap(grad(loss_fn), 'i')(lax.broadcast(w, (2,)), data))
# expected: [0. 1.]
# prints: [0. 2.] ✗

I understand that the first case is tricky (as discussed above). However, why should the second output be [0. 2.] rather than [0. 1.]?

sharadmv commented 2 years ago

For me, this exact issue was motivated by doing pmap of Hamiltonian Monte Carlo. It's pretty tricky to get the gradients right in general. I encourage you to take a look at this guide that uses TFP: https://www.tensorflow.org/probability/examples/Distributed_Inference_with_JAX.

The idea is that we introduce "pbroadcasts" when unmapped random variables interact with mapped random variables. The transpose of pbroadcast is psum and vice versa. The low level implementation can be found here: https://github.com/tensorflow/probability/blob/main/tensorflow_probability/python/internal/distribute_lib.py.