Closed rtqichen closed 5 years ago
Thanks for raising this!
I don't think in-place modification is necessary to collect batch norm statistics, in the same way that in-place modification isn't necessary to update the parameters, though stax itself isn't set up to collect batch norm statistics. (I think the same is true for power iteration, and any iterative scheme, though a more detailed example of what you're thinking about might help.)
Actually, the way most stax code is written, it also doesn't do in-place updating of parameters, unless the update loop is under a `jit`.

Under a `jit`, XLA can automatically generate code that performs in-place modification of memory buffers when possible. It's not necessary for the programmer to manually manage memory buffers; the programmer only needs to think about values. Consider this abstract iterative scheme:
```python
for x in xs:
    y = f(y, x)
```
That abstract formulation models updating the parameters of a neural network, accumulating statistics like batch-norm statistics, or even power iteration. When the whole loop is under a `jit`, XLA will happily reuse the memory buffer that stores `y` (unless it decides it can do something even more efficient). No in-place updating is needed in the program for it to be executed efficiently!
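A minimal sketch of such a functional accumulation, here a batch-norm-style running mean with the whole loop under `jit` via `lax.scan` (the names are illustrative, not part of stax):

```python
import jax
import jax.numpy as jnp

@jax.jit
def accumulate_mean(xs):
    # Functional running-mean update: no in-place mutation anywhere;
    # under jit, XLA is free to reuse the buffer for `mean` across steps.
    def step(carry, x):
        mean, count = carry
        count = count + 1.0
        mean = mean + (x - mean) / count
        return (mean, count), None
    (mean, _), _ = jax.lax.scan(step, (jnp.zeros(xs.shape[1:]), 0.0), xs)
    return mean

xs = jnp.arange(12.0).reshape(4, 3)
print(accumulate_mean(xs))  # equals xs.mean(axis=0)
```

The incremental-mean update is exact, so the result matches `xs.mean(axis=0)` while the program itself never mutates a buffer.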
That only applies when the whole loop is under a `jit`, though, since that's when XLA gets to optimize things. Another thing we might want, and which is on our todo list somewhere, is to allow donating argument buffers to XLA. Consider the case where the above loop is not under a `jit`, but just the function `f` is. Then the buffer for `y` won't be modified in place, because XLA doesn't know that the input and output are effectively aliased; that is, it doesn't know we're about to drop the reference to the original `y` buffer, and so it doesn't know it is free to perform in-place updating of that buffer. That's what's happening now with neural net training (and any other iterative scheme) when the training loop isn't under a `jit`, even if the update function is; it means there's not yet a good way to train a model for which we can only afford to store one copy of the parameters.
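For illustration, buffer donation did later land in JAX as `jit`'s `donate_argnums` parameter; a sketch assuming a recent JAX version (donation is only a hint, and some backends, e.g. CPU, ignore it with a warning):

```python
import functools
import jax
import jax.numpy as jnp

# `donate_argnums=0` promises XLA that the caller will drop its reference
# to the first argument, so the output may reuse that argument's buffer.
@functools.partial(jax.jit, donate_argnums=0)
def step(y, x):
    return y + x

y = jnp.zeros(3)
for x in [jnp.ones(3)] * 4:
    y = step(y, x)  # rebinds y; the old buffer may be reused in place
print(y)  # [4. 4. 4.]
```

After the call, the original donated array must not be read again; rebinding `y` as above is the intended pattern.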
Writing programs with explicit in-place modification isn't necessary for efficiency; in fact, it constrains a compiler's ability to optimize code, which is why the XLA compiler itself (which is designed for the highest performance possible) doesn't support programs with in-place updating semantics, even though the code it generates performs plenty of in-place updates. But the even bigger reason for JAX not to support in-place updating semantics in user programs is that it makes transforming code, e.g. with autodiff or with `vmap`, much more complex and maybe impossible. It's not that we like functional code for some aesthetic reason; it's all a practical consideration, because we want composable function transformations and the fastest performance.
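A small sketch of what that composability buys: because a pure function has no hidden mutable state, `grad`, `vmap`, and `jit` stack freely (the loss and shapes below are illustrative):

```python
import jax
import jax.numpy as jnp

# A pure loss function: output depends only on the explicit inputs.
def loss(w, x, y):
    return jnp.mean((x @ w - y) ** 2)

# Compose transformations: differentiate, then batch, then compile.
grad_fn = jax.jit(jax.vmap(jax.grad(loss), in_axes=(None, 0, 0)))

w = jnp.ones(3)
xs = jnp.ones((5, 4, 3))   # 5 batches of 4 examples each
ys = jnp.zeros((5, 4))
g = grad_fn(w, xs, ys)     # per-batch gradients, shape (5, 3)
```

If `loss` mutated shared state, none of these transformations could be applied mechanically like this.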
I can think of two main requests here:

1. A more general stax/optimizers API that can thread extra state (like batch-norm statistics) through the layer functions.
2. Buffer donation, so that a line like `y = f(y, x)` can perform in-place updating when `f` is jitted but the line itself is outside of a `jit`.

Do one or both of those capture your main ask here, or did I miss it?
Thank you for the detailed reply!
Yes, I see now that a more general API for stax and optimizers would be ideal for the shortish term. (Perhaps the method of updating can somehow be passed between them, but in a way that's decoupled from the parameter itself, because some cases may require updating parameters differently depending on the situation.)
Donating buffers to JAX sounds very interesting for improving `jit`, but relying on `jit` in order to obtain in-place modification (which is only one particular way of achieving stateful `apply_fun`s) seems too fragile.
> but relying on jit in order to obtain in-place modification [...] seems too fragile
Maybe, but relying on `jit` for performance and reaping the composable-transformation benefits of functionally pure code is kind of what we're all about :)

Actually, I think relying on mutation semantics is more fragile. That's what leads to `tf.control_dependencies` and other surprises in TF.
Closing this issue: I think we're fairly comfortable that immutable buffers are the right design choice. Of course, if you have concrete cases where immutability causes problems, we can look into them specifically.
Here's an example of a program that AFAIK needs donating buffers to XLA:

```python
import jax.numpy as np
import jax.random as random
from jax import lax
from jax import jit

@jit
def f(x):
    for _ in range(10):
        x = lax.conv_general_dilated(x, np.ones((3, 3, 1, 1)), (1, 1), 'SAME',
                                     dimension_numbers=('NHWC', 'HWIO', 'NHWC'))
    return x

x = random.normal(random.PRNGKey(1), (2**19, 2**5, 2**5, 1))
# (2**20, 2**5, 2**5, 1)) OOMs!
x = f(x)
```
Peak memory allocation of this program on a 2 GB input `x` is 4 × 2 GB = 8 GB, while I believe it should be only 2 × 2 GB; I should be able to fit a 4 GB input, but currently I cannot.
Is there currently a way to have updatable buffers that aren't parameters in the stax framework (and so won't get updated by gradient descent, but by some other user-specified method)?

This is very important for many deep learning applications: for example, running averages (e.g. for batch normalization; #139 seems to be a special case of this), the `u` and `v` vectors from power iteration for spectral normalization, and differentiable plasticity.
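For concreteness, here is a functional power-iteration step of the kind used in spectral normalization (an illustrative sketch, not stax code); the `u` vector it returns is exactly the non-parameter state that needs to persist between calls:

```python
import jax
import jax.numpy as jnp

def power_iteration_step(W, u):
    # One power-iteration step: returns fresh u, v instead of mutating
    # stored buffers; `u` must be carried between calls by the caller.
    v = W.T @ u
    v = v / jnp.linalg.norm(v)
    u = W @ v
    u = u / jnp.linalg.norm(u)
    return u, v

W = jax.random.normal(jax.random.PRNGKey(0), (4, 3))
u = jnp.ones(4)
for _ in range(50):
    u, v = power_iteration_step(W, u)
sigma = u @ W @ v  # approximates the largest singular value of W
```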
One simple solution would be to allow in-place modification of certain arrays. These could then be placed in the scope of a stax function without being returned by `init_fun`. This probably goes against the jax design principle of being functional, though, so any pointers to best practices would be appreciated!
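One functional alternative (a hypothetical sketch, not an existing stax API) is to have the apply function accept and return the extra state explicitly, so running statistics flow through the program as values:

```python
import jax.numpy as jnp

def batch_norm_apply(params, state, x, train=True, momentum=0.9, eps=1e-5):
    # Hypothetical signature: running statistics are threaded through as
    # explicit inputs/outputs rather than mutated in place.
    scale, bias = params
    running_mean, running_var = state
    if train:
        mean, var = x.mean(axis=0), x.var(axis=0)
        state = (momentum * running_mean + (1 - momentum) * mean,
                 momentum * running_var + (1 - momentum) * var)
    else:
        mean, var = running_mean, running_var
    y = scale * (x - mean) / jnp.sqrt(var + eps) + bias
    return y, state

x = jnp.arange(12.0).reshape(4, 3)
params = (jnp.ones(3), jnp.zeros(3))
state = (jnp.zeros(3), jnp.ones(3))
y, state = batch_norm_apply(params, state, x)
```

The caller rebinds `state` on each call, the same `y = f(y, x)` pattern discussed above, so no mutation is needed.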