sharadmv opened this issue 4 years ago
This is confusing! It seems to be an issue about how we define transposes of SPMD functions (i.e. it's not just a dumb software bug). No conclusions yet, but wanted to dump some thoughts here.
Here are some test programs:
import jax.numpy as jnp
from jax import vmap, pmap, grad, jvp, lax

def distributed(w, data):
    def agg_loss(w):
        return lax.psum(w * data, 'batch')
    return jvp(agg_loss, (w,), (1.,))[1]

print(pmap(distributed, axis_name='batch')(jnp.ones(2), jnp.arange(2)))
# prints [1. 1.]

def distributed(w, data):
    def agg_loss(w):
        return lax.psum(w * data, 'batch')
    return (agg_loss(w + 1e-3) - agg_loss(w)) / 1e-3

print(pmap(distributed, axis_name='batch')(jnp.ones(2), jnp.arange(2)))
# prints [1.0000466 1.0000466]

def distributed(w, data):
    def agg_loss(w):
        return lax.psum(w * data, 'batch')
    return grad(agg_loss)(w)

print(pmap(distributed, axis_name='batch')(jnp.ones(2), jnp.arange(2)))
# prints [0. 2.]
The jaxpr being transposed in the last one is

{ lambda c ; b.
  let d = mul b c
      e = psum[ axis_index_groups=None
                axis_name=batch ] d
  in (e,) }
Operationally, the transpose evaluation proceeds like this: grad feeds in a cotangent for e of 1.0. Then we compute a cotangent for d of 2.0, reflecting the fact that if we perturb the value of d by epsilon (on each replica) then the value of e changes by 2 epsilon. Then we compute a cotangent for b by multiplying by 0. on the first replica and 1. on the second replica, leading to the final result [0. 2.].
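To make that cotangent flow concrete, here is a small sketch (mine, not from the thread) that calls jax.vjp directly instead of grad; under the transpose rule described above it reproduces the same result:

import jax.numpy as jnp
from jax import pmap, vjp, lax

def distributed(w, data):
    def agg_loss(w):
        return lax.psum(w * data, 'batch')
    out, agg_loss_vjp = vjp(agg_loss, w)
    # feed in a cotangent of 1.0 for e (the psum output) on each replica
    w_bar, = agg_loss_vjp(jnp.ones_like(out))
    return w_bar

print(pmap(distributed, axis_name='batch')(jnp.ones(2), jnp.arange(2)))
# under the transpose semantics described above, this prints [0. 2.] as well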
Beyond the numbers, we have a symmetry issue. The first two versions produce a symmetric-across-replicas result, which makes sense because agg_loss's last operation is a psum and so its output must be symmetric across replicas. But the reverse-mode version can't produce a symmetric result because it multiplies by the mapped value data at the end. (The symmetric result makes the most sense because agg_loss always produces a symmetric result given a symmetric input.)
Hmmmm...
IMO, we need to keep track of replication to do a good job here. I see two reasonable behaviors:
Behavior A: grad of a replicated thing throws an error, and grad of a non-replicated thing computes a separate grad for each thread.
Behavior B: grad of a non-replicated thing throws an error (the idea being that grad requires a scalar function, and an SPMD scalar function isn't a scalar function), and grad of a replicated thing computes a single (replicated) summed grad.
Or I guess both could work?
("replicated" means guaranteed-symmetric, out_axes=None
is allowed)
Can you define 'replicated thing' and 'non-replicated thing'? Do you mean a function (agg_loss in this case), and if so what does it mean to be replicated (maybe: has any collectives in it)?
By "replicated thing" I mean a guaranteed-symmetric value, i.e. one where out_axes=None
would be allowed if it were a return value of the map.
In terms of functions, I mean those whose return values are replicated.
Thanks! One possible desideratum is not to do something different on guaranteed-symmetric values versus just-happens-to-be-symmetric values. Notice that in this example if data just happens to be symmetric, we get the right answer. Also, if we noticed (say in backward_pass) that the function to be transposed is symmetric (in the sense that when it's given a symmetric-across-replicas input it produces a symmetric output), and in that case always symmetrized the cotangents of the jaxpr inputs, then we'd compute the right answer in both guaranteed-symmetric (new!) and just-happens-to-be-symmetric (already got that one) cases.
EDIT: I think that, minus the "throws an error" parts, this may be what you meant by "both". Is that right?
I'm starting to agree with James that we might need to track spmd-symmetry in jaxprs, and reverse mode should do something interesting with that information. It's tricky: if we have a jaxpr to transpose with both symmetric and non-symmetric outputs, we may need to run two backward passes, one for the symmetric outputs (where we symmetrize the resulting input cotangents) and one for the non-symmetric outputs (where we don't).
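As a rough illustration of the "symmetrize the input cotangents" idea (my own sketch, not a proposed fix): averaging the per-device cotangent across the named axis turns the [0. 2.] above into the symmetric answer.

import jax.numpy as jnp
from jax import grad, pmap, lax

def distributed(w, data):
    def agg_loss(w):
        return lax.psum(w * data, 'batch')
    w_bar = grad(agg_loss)(w)          # [0., 2.] across devices, as above
    return lax.pmean(w_bar, 'batch')   # symmetrized cotangent

print(pmap(distributed, axis_name='batch')(jnp.ones(2), jnp.arange(2)))
# expected to print [1. 1.] on a machine with 2 devices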
I have thought about this issue a bit in another context and I think you need the notion/type of an SPMD-global scalar to make grad make sense with an internal pmap. The API difficulty with pmap is that it has no way to return something of rank < 1, so semantically the program never has a function on which grad should be defined, since the output is never a scalar. Having a ReplicatedArray be a tensor in some contexts and a scalar in others seems like it could lead to semantic problems.
I think we shouldn't necessarily expect invariants that are true in ordinary code to also hold for SPMD code; we often assume that SPMD code will behave like a "batch of (independent) programs", but the presence of collectives means that can be a bad analogy! Instead, I like to think of SPMD programming as a different perspective on bulk array programming (the named axis is a real axis that just doesn't have a position, and collectives are just bulk primitives, like reductions, that operate on that axis).
That means that a scalar -> scalar SPMD program is not actually a scalar -> scalar program, and nor is it a batch of scalar -> scalar programs. It's a vector -> vector program (that acts elementwise/has a diagonal Jacobian if there aren't any collectives, and doesn't if there are). This is a particularly important distinction for autodiff, since a single application of forward-mode autodiff or finite differences can only give the derivative of a scalar-input function, while a single application of reverse-mode autodiff can only give the derivative of a scalar-output function (with an exception in both cases if the Jacobian is known to be diagonal).
In bulk array terms, Matt's three examples are:
import jax.numpy as jnp
from jax import vmap, pmap, grad, lax, jvp

def f_jvp(w, data):
    def agg_loss(w):
        return jnp.sum(w * data, 0)
    return jvp(agg_loss, (w,), (jnp.ones_like(w),))[1]  # tangent is a vector of ones

print(f_jvp(jnp.ones(2), jnp.arange(2)))
# prints 1.0

def f_fd(w, data):
    def agg_loss(w):
        return jnp.sum(w * data, 0)
    return (agg_loss(w + 1e-3) - agg_loss(w)) / 1e-3

print(f_fd(jnp.ones(2), jnp.arange(2)))
# prints 1.0000467

def f_grad(w, data):
    def agg_loss(w):
        return jnp.sum(w * data, 0)
    return grad(agg_loss)(w)

print(f_grad(jnp.ones(2), jnp.arange(2)))
# prints [0. 1.]
These results are a little less surprising! Here agg_loss is an ordinary vector -> scalar function and f_grad is computing the gradient, while f_jvp and f_fd are both computing something else (the product of the Jacobian with a vector of ones).
In this bulk array context, another way to compute that same Jacobian-vector product is to take the grad of a scalar -> scalar function that broadcasts its input (expressed with the same agg_loss Python code through NumPy rank polymorphism/implicit broadcasting):
def f_grad_bcast(w, data):
    def agg_loss(w):
        return jnp.sum(w * data, 0)
    return grad(agg_loss)(w)

print(f_grad_bcast(1., jnp.arange(2)))
# prints 1.0
Here the forward pass performs a broadcast, multiply, then sum, so the transpose performs a broadcast, multiply, then sum too (broadcast is fan-out and its transpose is fan-in-sum). How to make sure the transpose contains a sum when the forward pass actually needed to broadcast, but doesn't when it didn't, is the implicit broadcasting problem in autodiff systems for dynamically-ranked languages like TF graphs and TorchScript.
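As a quick check of that pairing (my own snippet, not from the thread), jax.vjp shows a broadcast transposing to a sum and a sum transposing to a broadcast:

import jax.numpy as jnp
from jax import vjp, lax

# broadcast: scalar -> vector; its transpose is a fan-in-sum
y, bcast_vjp = vjp(lambda x: lax.broadcast(x, (2,)), 1.0)
print(bcast_vjp(jnp.ones_like(y)))   # cotangent for x is 2.0

# sum: vector -> scalar; its transpose is a broadcast
z, sum_vjp = vjp(lambda x: x.sum(0), jnp.ones(2))
print(sum_vjp(jnp.ones_like(z)))     # cotangent for x is [1. 1.]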
My claim is that all of these concerns apply in the SPMD case, too. An SPMD program cannot run except with all of its named axes bound by corresponding maps, and the resulting program is a bulk array program with well-defined semantics. We need to treat those semantics as the source of truth for thinking about SPMD programs, rather than intuitions about batches of independent programs, because those intuitions break down when collectives are involved.
In bulk array programs, a particular logical axis can either be present or absent in each of the arguments and return values. In the context of bulk array programs where that axis was created by mapping an SPMD axis using JAX's map APIs, this corresponds to whether the in_axes/out_axes setting for each argument/return is an integer or None. Here's what I think that means for the SPMD versions of our examples:
import jax.numpy as jnp
from jax import vmap, pmap, grad, lax, jvp

def f_jvp(w, data):
    def agg_loss(w):
        return lax.psum(w * data, 'batch')
    return jvp(agg_loss, (w,), (1.,))[1]

print(pmap(f_jvp, axis_name='batch')(jnp.ones(2), jnp.arange(2)))
# prints [1. 1.]

def f_fd(w, data):
    def agg_loss(w):
        return lax.psum(w * data, 'batch')
    return (agg_loss(w + 1e-3) - agg_loss(w)) / 1e-3

print(pmap(f_fd, axis_name='batch')(jnp.ones(2), jnp.arange(2)))
# prints [1.0000466 1.0000466]

def f_grad(w, data):
    def agg_loss(w):
        return lax.psum(w * data, 'batch')
    return grad(agg_loss)(w)

print(pmap(f_grad, axis_name='batch')(jnp.ones(2), jnp.arange(2)))
# should print [0. 1.]

def f_grad_bcast(w, data):
    def agg_loss(w):
        return lax.psum(w * data, 'batch')
    return grad(agg_loss)(w)

print(pmap(f_grad_bcast, axis_name='batch', in_axes=(None, 0))(1., jnp.arange(2)))
# should print [1. 1.], or just 1.0 with out_axes=None
(That last one is Sharad's code.)
We don't currently produce the desired output in the last case, because we include a spurious psum at the beginning of the transpose (from [1. 1.] to [2. 2.]) and don't include a necessary one at the end. There might be multiple ways to fix the system and get the right answer, but my understanding is that the best way forward looks something like this:
First, we should distinguish between "scalars" and "batched scalars" with respect to a particular SPMD axis, and introduce a pbroadcast to take a scalar to a batched scalar with the same value (while psum should take a batched scalar and produce a scalar). Then, the transpose of pbroadcast should be psum and vice versa, and the cotangent of a scalar should always be a scalar and the cotangent of a batched scalar should always be a batched scalar. This means that (just like in the rank-polymorphic bulk case!) we should insert a pbroadcast when a scalar needs to be promoted to a batched scalar (as in w * data) and the backward pass should have a psum there.
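Here is a minimal bulk-array sketch of that typing (my own code and names; AXIS_SIZE stands in for the size of the mapped 'batch' axis): with the promotion made explicit, the backward pass gets its sum exactly where the forward pass broadcast.

import jax.numpy as jnp
from jax import grad, lax

AXIS_SIZE = 2  # stands in for the size of the mapped 'batch' axis

def pbroadcast_bulk(x):   # proposed pbroadcast: scalar -> batched scalar
    return lax.broadcast(x, (AXIS_SIZE,))

def psum_bulk(x):         # proposed psum: batched scalar -> scalar
    return x.sum(0)

def agg_loss(w, data):    # bulk counterpart of agg_loss, with the promotion explicit
    return psum_bulk(pbroadcast_bulk(w) * data)

print(grad(agg_loss)(1., jnp.arange(2.)))
# 1.0: the transpose has a sum exactly where the forward pass broadcast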
Ideally we would also match the bulk-array error semantics, and constrain the return value of the function passed to grad to be a scalar, not a batched scalar; we should also support out_axes=None in maps so that we can return these scalars out to the bulk world and make the semantics less obscure.
@jekbradbury that's brilliant! Let's do it.
I changed the issue title because I think the current (i.e. at HEAD) semantics make sense, but can make grad(f)(x) and jvp(f, (x,), (1.,)) differ for scalar-input scalar-output f when inside a pmap (as in my comment above), which seems worth revising. The current semantics just take a different correspondence to bulk array programs than the one in @jekbradbury's previous comment.
The original intention with psum was to correspond to a bulk array operation that included broadcasting, not just a reduce-sum as in James's comment above:

pmap(lambda x: psum(x, 'i'), axis_name='i')
==
lambda x: lax.broadcast(x.sum(0), (num_devices,))
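That correspondence can be sanity-checked on one device by binding the axis with vmap instead of pmap (my snippet; psum computes the same values under either map):

import jax.numpy as jnp
from jax import vmap, lax

x = jnp.arange(3.)
spmd = vmap(lambda v: lax.psum(v, 'i'), axis_name='i')(x)
bulk = lax.broadcast(x.sum(0), (x.shape[0],))
print(spmd)  # [3. 3. 3.]
print(bulk)  # [3. 3. 3.]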
So given this pmap code:
def f(w, data):
    def agg_loss(w):
        return lax.psum(w * data, 'batch')
    return grad(agg_loss)(w)

print(pmap(f, axis_name='batch')(jnp.ones(2), jnp.arange(2)))
we should expect it to behave like this bulk array version:
import jax.numpy as jnp
from jax import grad
from jax import lax

def f(w, data):
    def agg_loss(w):
        return lax.broadcast((w * data).sum(), (2,))
    return grad(agg_loss)(w)

print(f(jnp.ones(2), jnp.arange(2)))
But that code triggers an error, because in api.py we check to ensure grad is only being applied to scalar-output functions. Just for defining semantics, let's sidestep that error check:
import jax.numpy as jnp
from jax import vjp
from jax import lax

def grad(f):
    def gradfun(x):
        ans, f_vjp = vjp(f, x)
        x_bar, = f_vjp(jnp.ones_like(ans))
        return x_bar
    return gradfun

def f(w, data):
    def agg_loss(w):
        return lax.broadcast((w * data).sum(), (2,))  # notice the broadcast!
    return grad(agg_loss)(w)

print(f(jnp.ones(2), jnp.arange(2)))
# prints [0. 2.]
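Tracing that transpose by hand (my own worked example) makes the [0. 2.] explicit:

import jax.numpy as jnp

data = jnp.arange(2.)                  # [0., 1.]
e_bar = jnp.ones(2)                    # cotangent fed in for the broadcasted output
s_bar = e_bar.sum()                    # transpose of the broadcast is a sum -> 2.0
d_bar = jnp.broadcast_to(s_bar, (2,))  # transpose of the sum is a broadcast -> [2., 2.]
w_bar = d_bar * data                   # transpose of (* data) is (* data)   -> [0., 2.]
print(w_bar)                           # [0. 2.]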
These SPMD semantics explain the reverse-mode answer in my first comment above.
The trouble with this definition, though, is that it can make grad disagree with jvp and numerical differences on what appears locally to be a scalar-input scalar-output function when inside a pmap, as in the examples in my first comment. I want to write out an explanation for what's going on, but I've got to step away for a moment and wanted to send this comment first. To be continued!
In any case, I think @jekbradbury's proposal for revising the semantics is likely to be better. I just want to pin down both the old and new semantics as best we can.
Okay, back to it!
Notice the Jacobian of the agg_loss function written with the broadcast is [[0, 1], [0, 1]]:
import jax.numpy as jnp
from jax import lax
from jax import jacfwd

def f(w, data):
    def agg_loss(w):
        return lax.broadcast((w * data).sum(0), (2,))
    return jacfwd(agg_loss)(w)

print(f(jnp.ones(2), jnp.arange(2)))
# [[0. 1.]
#  [0. 1.]]
So, while keeping the current (at HEAD) bulk array definition of psum as a reduce-sum followed by a broadcast, the SPMD AD semantics is consistent so long as we take grad to mean "compute the VJP against a ones vector broadcast along all named axes":
import jax.numpy as jnp
from jax import vjp, jvp, pmap, grad
from jax import lax

### reverse-mode

# At HEAD, we define this SPMD program:
def f(w, data):
    def agg_loss(w):
        return lax.psum(w * data, 'batch')
    return grad(agg_loss)(w)

print(pmap(f, axis_name='batch')(jnp.ones(2), jnp.arange(2)))
# [0. 2.]

# To mean the same as this bulk array vjp-with-ones program (`grad` is always
# defined as vjp-with-ones plus an error check for scalar outputs that we don't
# include in the definition of SPMD semantics):
def grad2(f):
    def gradfun(x):
        ans, f_vjp = vjp(f, x)
        x_bar, = f_vjp(jnp.ones_like(ans))
        return x_bar
    return gradfun

def f(w, data):
    def agg_loss(w):
        return lax.broadcast((w * data).sum(), (2,))  # bulk array version of psum
    return grad2(agg_loss)(w)

print(f(jnp.ones(2), jnp.arange(2)))
# [0. 2.]

### forward-mode

# At HEAD, we define this SPMD program:
def f(w, data):
    def agg_loss(w):
        return lax.psum(w * data, 'batch')
    return jvp(agg_loss, (w,), (1.,))[1]

print(pmap(f, axis_name='batch')(jnp.ones(2), jnp.arange(2)))
# [1. 1.]

# To mean the same as this bulk array jvp-with-ones program:
def f(w, data):
    def agg_loss(w):
        return lax.broadcast((w * data).sum(), (2,))  # bulk array version of psum
    return jvp(agg_loss, (w,), (jnp.ones_like(w),))[1]

print(f(jnp.ones(2), jnp.arange(2)))
# [1. 1.]
(In Autograd, like TF today, we used to define grad as vjp-with-ones, but exactly this sort of confusion is why we made it raise an error for non-scalar outputs. Yet we didn't make that check work with SPMD functions, in the sense that grad will in effect happily allow broadcasting along named mapped axes!)
If the semantics at HEAD are self-consistent, except for the error semantics of grad, do we need to change anything, other than perhaps avoiding this potential confusion by making grad's error semantics consistent across the positional and SPMD worlds?
Maybe yes. One problem with the current semantics is that if we make @sharadmv's use of grad here an error (rather than a vjp-with-ones), not only would that have been surprising to him, but also it would break pretty much all existing SPMD neural net training; they'd have to write vjp-with-ones themselves, e.g. by defining grad2 as above. Even then, the answers can be surprising: within the context of the SPMD function, it looks like we're calling grad/grad2 on a scalar-input scalar-output function (but for the closed-over value of data which is different on each device) with the same primal input value on every device, yet getting different grad2 results on different devices (perhaps not noticing that if we looked at the primal output value we'd also have a different value on each device, which might make getting different gradients less surprising). That is, in the expression grad(agg_loss)(w) in @sharadmv's original code, the function agg_loss is a different function on each device because it closes over mapped data, which is why we should expect to get different answers.
In any case, while I now think the semantics at HEAD are actually consistent (modulo error semantics for grad) and defensible, this example has clearly shown that they can be confusing, especially when differentiating functions that close over mapped values. The main difference with @jekbradbury's proposed semantics (as he's pointed out too) is whether psum corresponds in bulk array land to a reduce-sum, or a reduce-sum-then-broadcast.
As Matt points out, the semantics of SPMD programs at head are more consistent than I thought—they just correspond to different (and in my view less useful) bulk array semantics. Adopting my preferred approach would be breaking, and would need to be sequenced carefully with other enhancements we're planning to make.
In particular: at head, values in SPMD code always "contain" the pmapped axis (by which I mean that their bulk array counterparts always contain the corresponding logical axis). This has a few consequences:
(1) pmap has to insert pbroadcasts of in_axes=None or closed-over values at the beginning of the mapped region, rather than later on/when needed, because a value inside a pmap that hasn't been pbroadcasted isn't representable.
(2) psum has to have reduce-sum-then-broadcast semantics, because the output of a reduce-sum without a broadcast isn't representable.
(3) grad inside a pmap has to correspond to bulk vjp-with-ones, because a bulk scalar isn't representable.
These are essentially the things that surprised Sharad. (1) meant that the output of the gradient function wasn't psummed, because the input of the forward function had already been broadcasted; (2) meant that the values flowing back through the VJP were 2 rather than 1, because those values had been psummed; and (3) meant that the output values computed were no longer gradients, since the bulk version of the forward function had a non-scalar output.
Collectively they mean that a pattern that would otherwise be natural and useful, and that represents the SPMD counterpart of standard bulk code for NN training (taking the gradient of a psummed loss with respect to replicated parameters), can't be expressed in SPMD JAX. There are two alternatives: using grad of pmap, which has significant runtime overheads, and moving the psum outside the loss function, which is the most common approach in data-parallel JAX today, but represents an unfortunate loss of expressiveness and a gap between SPMD code and bulk code.
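For reference, the "most common approach in data-parallel JAX today" mentioned here looks roughly like this (a sketch with my own function names, not code from the thread):

import jax.numpy as jnp
from jax import grad, pmap, lax

def local_loss(w, x):                 # hypothetical per-device loss, no collectives inside
    return jnp.sum((w * x) ** 2)

def train_step(w, x):
    g = grad(local_loss)(w, x)        # per-device gradient of the local loss
    return lax.pmean(g, 'batch')      # then average gradients across devices

# usage, with replicated params and sharded data (needs >= 2 devices):
# grads = pmap(train_step, axis_name='batch', in_axes=(None, 0))(w, x_sharded)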
These things and others would become easier, and the system more aligned with Sharad's (and my) mental model, if SPMD JAX were extended with the ability to represent values that don't contain the mapped axis. This is a strict increase in expressiveness, because it increases the set of bulk programs that have SPMD counterparts. It would not require a change in the semantics of existing programs, but it would open up the possibility of certain changes (and wouldn't be very useful without them):
(1) pmap can now wait to insert pbroadcasts of in_axes=None or closed-over values until they're mixed with values that contain the mapped axis.
(2) psum can correspond to a bulk reduce-sum.
(3) grad inside a pmap can correspond to bulk grad.
With these changes, Sharad's code would work out of the box with his desired semantics, and the set of representable programs becomes more symmetric between the SPMD and bulk worlds.
But of course the "most common approach in data-parallel JAX today" (psum/pmean of the result of grad inside a pmap) assumes per-device grad with vjp-with-ones bulk semantics! So we have to be careful about how we make changes like these.
We could, for instance, introduce kwargs to control these behaviors (e.g., keepdims on psum could default to True at first). We could also sequence these changes to take place alongside new APIs we're adding or considering adding, like gmap and a potential sum-psum unification.
But of course the "most common approach in data-parallel JAX today" (psum/pmean of the result of grad inside a pmap) assumes per-device grad with vjp-with-ones bulk semantics! So we have to be careful about how we make changes like these.
How common is grad(psum) in practice? This case pops up when there is a psum in a loss function, but if in practice people are doing psum of gradients after the fact, presumably they are writing "local" loss functions that do not have a psum in them. This means the vjp-of-ones semantics is the same as several vjp-of-scalar-loss computations in parallel (since there are no collectives in the loss function). Wouldn't psum of gradients of a loss that has a psum in it produce unintuitive values too?
I came across this problem when implementing an HMC sampler, which depends on the input data only through the log-posterior density $\log p(\theta|\mathbf{x})$. The log-posterior of the whole dataset is not equal to the sum of the log-posteriors that we would compute for each batch separately: $\log p(\theta|\mathbf{x}) \ne \sum_i \log p(\theta|x_i)$. So, in this particular use case, taking the grad of psum is useful.
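To spell out why those differ (my own expansion, assuming a likelihood that factorizes over $N$ data points $x_i$): the full-data log-posterior contains the log-prior once, while the per-point sum counts it $N$ times:

$$\log p(\theta \mid \mathbf{x}) = \log p(\theta) + \sum_i \log p(x_i \mid \theta) - \log p(\mathbf{x}), \qquad \sum_i \log p(\theta \mid x_i) = N \log p(\theta) + \sum_i \log p(x_i \mid \theta) - \sum_i \log p(x_i).$$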
To understand what's going on, I first replicated the toy problem discussed in this thread using a traditional single-device program. I think the results are intuitive:
import jax.numpy as jnp
from jax import lax, vmap, pmap, grad

# data
w = jnp.array(1.)
data = jnp.array([0., 1.])

# loss function
def loss_fn(w, data):
    return (w * data).sum()

# scalar w
print(grad(loss_fn)(w, data))
# expected: 1.
# prints: 1. ✓

# vector w
print(grad(loss_fn)(lax.broadcast(w, (2,)), data))
# expected: [0. 1.]
# prints: [0. 1.] ✓
Next, I modified the above code into an SPMD program by replacing sum with psum. My intuition was that psum should behave like sum followed by a broadcast along devices. However, this is not the case:
# loss function
def loss_fn(w, data):
    return lax.psum(w * data, 'i')

# scalar w
print(pmap(grad(loss_fn), 'i', in_axes=(None, 0))(w, data))
# expected: [1. 1.]
# prints: [0. 2.] ✗

# vector w
print(pmap(grad(loss_fn), 'i')(lax.broadcast(w, (2,)), data))
# expected: [0. 1.]
# prints: [0. 2.] ✗
I understand that the first case is tricky (as discussed above). However, why should the second output be [0. 2.] rather than [0. 1.]?
This exact issue was, for me, motivated by doing pmap of Hamiltonian Monte Carlo. It's pretty tricky to get the gradients right in general. I encourage you to take a look at this guide that uses TFP: https://www.tensorflow.org/probability/examples/Distributed_Inference_with_JAX.
The idea is that we introduce "pbroadcasts" when unmapped random variables interact with mapped random variables. The transpose of pbroadcast is psum and vice versa. The low level implementation can be found here: https://github.com/tensorflow/probability/blob/main/tensorflow_probability/python/internal/distribute_lib.py.
Running this on a machine with >=2 devices:
The losses are correct but gradients are incorrect across each shard.
What should be the intended behavior here?