jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Modifiable buffers #926

Closed: rtqichen closed this issue 5 years ago

rtqichen commented 5 years ago

Is there currently a way to have updatable buffers in the stax framework that aren't parameters (and so get updated not by gradient descent but by some other user-specified method)?

This is very important for many deep learning applications: for example, running averages (e.g. for batch normalization; #139 seems to be a special case of this), the u and v vectors maintained by power iteration for spectral normalization, and differentiable plasticity.

One simple solution would be to allow in-place modification of certain arrays. These could then be placed in the scope of a stax function without being returned by init_fun. This probably goes against the JAX design principle of being functional, though, so any pointers to best practices would be appreciated!

mattjj commented 5 years ago

Thanks for raising this!

I don't think in-place modification is necessary to collect batch norm statistics, in the same way that in-place modification isn't necessary to update the parameters, though stax itself isn't set up to collect batch norm statistics. (I think the same is true for power iteration, and any iterative scheme, though a more detailed example of what you're thinking about might help.)

Actually, the way most stax code is written, it doesn't do in-place updating of the parameters either, unless the update loop is under a jit.

Under a jit, XLA can automatically generate code that performs in-place modification of memory buffers when possible. It's not necessary for the programmer to manually manage memory buffers; the programmer only needs to think about values. Consider this abstract iterative scheme:

for x in xs:
  y = f(y, x)

That abstract formulation models updating the parameters of a neural network, accumulating statistics like batch-norm statistics, or even power iteration. When the whole loop is under a jit, XLA will happily reuse the memory buffer that stores y (unless it decides it can do something even more efficient). No in-place updating is needed in the program for it to be executed efficiently!
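
To make that concrete, here is a minimal sketch (f is a hypothetical stand-in for any update rule) with the whole loop under a jit, written with lax.scan so XLA compiles the loop as one program and is free to reuse y's buffer across iterations:

import jax.numpy as jnp
from jax import jit, lax

def f(y, x):
  # hypothetical update rule, standing in for a parameter update,
  # a running-statistics update, or a power-iteration step
  return 0.9 * y + 0.1 * x

@jit  # the whole loop is compiled as a single XLA program
def run(y, xs):
  def step(y, x):
    return f(y, x), None  # carry the new y; emit no per-step output
  y, _ = lax.scan(step, y, xs)
  return y

y = run(jnp.zeros(3), jnp.ones((100, 3)))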

That only applies when the whole loop is under a jit, though, since that's when XLA gets to optimize things. Another thing we might want, and which is on our to-do list somewhere, is to allow donating argument buffers to XLA. Consider the case where the above loop is not under a jit, but just the function f is. Then the buffer for y won't be modified in place, because XLA doesn't know that the input and output are effectively aliased; that is, it doesn't know we're about to drop the reference to the original y buffer, and so it doesn't know it is free to update that buffer in place. That's what happens now with neural net training (and any other iterative scheme) when the training loop isn't under a jit even if the update function is; it means there's not yet a good way to train a model for which we can only afford to store one copy of the parameters.
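
As a sketch of what donation could look like: jax.jit later gained a donate_argnums parameter for exactly this (donation is honored on GPU/TPU and ignored with a warning on CPU). Assuming the caller drops its reference to the donated argument:

import jax.numpy as jnp
from jax import jit

# Donating argument 0 tells XLA the caller will drop its reference to y,
# so the output is free to reuse y's buffer even though only f is jitted.
f = jit(lambda y, x: 0.9 * y + 0.1 * x, donate_argnums=0)

y = jnp.zeros(3)
for x in jnp.ones((5, 3)):
  y = f(y, x)  # each call may write its result into the old y buffer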

Writing programs with explicit in-place modification isn't necessary for efficiency; in fact, it constrains the ability of a compiler to optimize code, which is why the XLA compiler itself (which is designed for the highest performance possible) doesn't support programs with in-place updating semantics, even though it generates code with a lot of in-place updating operations. But the even bigger reason for JAX not to support in-place updating semantics in user programs is that it makes transforming code, e.g. with autodiff or with vmap, much more complex and maybe impossible. It's not that we like functional code for some aesthetic reason; it's a practical consideration, because we want composable function transformations and the fastest performance.

I can think of two main requests here:

  1. adapt stax and optimizers to make it easy to express the collection of batch norm statistics (a state-threading sketch follows this list)
  2. add a way to donate buffers to XLA (perhaps automatically using Python runtime refcounting) so that y = f(y, x) can perform in-place updating when f is jitted but the line itself is outside of a jit.
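
For request 1, a minimal sketch of what state threading could look like (the function name and state layout here are hypothetical, not a proposed stax API): the running statistics are ordinary values that the apply function takes in and returns, and the caller rebinds them instead of mutating anything.

import jax.numpy as jnp

def batch_norm_apply(params, state, x, momentum=0.99, eps=1e-5):
  # params are trained by gradient descent; state holds running statistics
  # updated by a user-specified rule. Both are plain values: a new state
  # is returned rather than written in place.
  scale, bias = params
  running_mean, running_var = state
  mean = x.mean(axis=0)
  var = x.var(axis=0)
  y = scale * (x - mean) / jnp.sqrt(var + eps) + bias
  new_state = (momentum * running_mean + (1 - momentum) * mean,
               momentum * running_var + (1 - momentum) * var)
  return y, new_state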

Do one or both of those capture your main ask here, or did I miss it?

rtqichen commented 5 years ago

Thank you for the detailed reply!

Yes, I see now that a more general API for stax and optimizers would be ideal for the shortish term. (Perhaps the method of updating can somehow be passed between them, but in a way that's decoupled from the parameter itself, because some cases may require updating parameters differently depending on the situation.)

Donating buffers to XLA sounds very interesting for improving jit, but relying on jit to obtain in-place modification (which is only one particular way of achieving stateful apply_funs) seems too fragile.

mattjj commented 5 years ago

but relying on jit in order to obtain in-place modification [...] seems too fragile

Maybe, but relying on jit for performance and reaping the composable transformation benefits of functionally pure code is kind of what we're all about :)

Actually, I think relying on mutation semantics is more fragile. That's what leads to tf.control_dependencies and other surprises in TF.

hawkinsp commented 5 years ago

Closing this issue: I think we're fairly comfortable that immutable buffers are the right design choice. Of course, if you have concrete cases where immutability causes problems, we can look into them specifically.

romanngg commented 5 years ago

Here's an example of a program that, AFAIK, needs buffer donation to XLA:

import jax.numpy as np
import jax.random as random
from jax import lax
from jax import jit

@jit
def f(x):
  # Apply the same 3x3 convolution ten times, rebinding x each iteration;
  # in principle only two image-sized buffers need to be live at once.
  for _ in range(10):
    x = lax.conv_general_dilated(x, np.ones((3, 3, 1, 1)), (1, 1), 'SAME',
                                 dimension_numbers=('NHWC', 'HWIO', 'NHWC'))
  return x

x = random.normal(random.PRNGKey(1), (2**19, 2**5, 2**5, 1))
# (2**20, 2**5, 2**5, 1) OOMs!
x = f(x)

Peak memory allocation of this program on a 2 GB input x is 4 × 2 GB = 8 GB, while I believe it should be only 2 × 2 GB, and I should be able to fit a 4 GB input, but currently I cannot.
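
The buffer donation discussed above later landed in jax.jit as donate_argnums; a sketch of applying it to this program, assuming the caller drops its reference to x (donation is honored on GPU/TPU and ignored with a warning on CPU):

import jax.numpy as np
import jax.random as random
from jax import jit, lax

def conv_loop(x):
  for _ in range(10):
    x = lax.conv_general_dilated(x, np.ones((3, 3, 1, 1)), (1, 1), 'SAME',
                                 dimension_numbers=('NHWC', 'HWIO', 'NHWC'))
  return x

# donate_argnums=0 tells XLA the caller drops its reference to x, so the
# output may be written into x's buffer instead of a fresh allocation,
# shaving one input-sized buffer off the peak; the remaining peak depends
# on XLA's scheduling.
f_donating = jit(conv_loop, donate_argnums=0)

x = random.normal(random.PRNGKey(1), (2**19, 2**5, 2**5, 1))
x = f_donating(x)  # the old x buffer is invalidated after this call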