jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

In-place updating can lead to counterintuitive behavior #226

Open craffel opened 5 years ago

craffel commented 5 years ago

Simple example:

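import jax.numpy as np
import numpy as onp

# NumPy: b is an alias of a, and += writes into the shared buffer
a = onp.array([10, 20])
b = a
b += 10
print(a)  # [20 30]

# jax.numpy: arrays are immutable, so += rebinds b to a new array
a = np.array([10, 20])
b = a
b += 10
print(a)  # [10 20]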
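Python's own built-in types show the same split, which may help explain the behaviour. This is a plain-Python sketch, and the analogy is an illustration rather than something from the thread. list defines __iadd__ and mutates in place, like a NumPy array. tuple has no __iadd__, so += falls back to rebinding the name, like a jax.numpy array.

```python
# list defines __iadd__, so += mutates the shared object (NumPy-like).
a = [10, 20]
b = a
b += [30]
print(a, b is a)  # [10, 20, 30] True

# tuple has no __iadd__, so += falls back to d = d + (30,) (jax.numpy-like).
c = (10, 20)
d = c
d += (30,)
print(c, d is c)  # (10, 20) False
```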

This difference is counterintuitive if you're used to NumPy's aliasing behaviour, where after b = a both names refer to the same mutable buffer. It came up when looping over lists of jax.numpy.ndarray objects:

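import jax.numpy as np

# Hypothetical example values; params and grads are lists of jax.numpy arrays
learning_rate = 0.1
params = [np.array([1.0, 2.0]), np.array([3.0])]
grads = [np.array([0.5, 0.5]), np.array([1.0])]
for param, gradient in zip(params, grads):
    param -= learning_rate * gradient  # rebinds param only; params is not updated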

A workaround, suggested by @mattjj, that is arguably at least as Pythonic:

params = [param - learning_rate * gradient
          for param, gradient in zip(params, grads)]
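Because the comprehension builds new arrays instead of mutating existing ones, it gives the same result under NumPy and jax.numpy. Below is a minimal sketch that wraps it in a function. The name sgd_step and the sample values are hypothetical, and NumPy stands in for jax.numpy so the snippet runs without JAX installed.

```python
import numpy as onp

def sgd_step(params, grads, learning_rate):
    """Hypothetical helper: return new parameter arrays; inputs are never mutated."""
    return [param - learning_rate * gradient
            for param, gradient in zip(params, grads)]

params = [onp.array([1.0, 2.0]), onp.array([3.0])]
grads = [onp.array([0.5, 0.5]), onp.array([1.0])]
new_params = sgd_step(params, grads, learning_rate=0.1)
print(new_params[0])  # [0.95 1.95]
print(params[0])      # [1. 2.] (unchanged)
```

Rebinding the result (params = sgd_step(...)) is the explicit update step. It works the same whether the arrays are mutable or not.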
mattjj commented 5 years ago

For now, to avoid silent surprises, we should probably override __iadd__ and its ilk to raise an error. Otherwise Python will desugar b += 10 to b = b + 10, which behaves differently from NumPy. We could add in-place operations at some point, probably with different semantics than NumPy's, but for now we should be explicit.
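The proposal can be sketched in plain Python. FrozenArray is a hypothetical stand-in, not JAX's actual array class. Defining __iadd__ to raise stops Python from silently falling back to b = b + 10, while ordinary + still works out of place. A fuller version would do the same for __isub__, __imul__, and the other in-place operators.

```python
class FrozenArray:
    """Hypothetical immutable array-like type (illustration only, not JAX's class)."""

    def __init__(self, values):
        self.values = tuple(values)

    def __add__(self, other):
        # Out-of-place: always returns a new object.
        return FrozenArray(v + other for v in self.values)

    def __iadd__(self, other):
        # Without this method, Python would silently run b = b + other.
        raise TypeError(
            "FrozenArray is immutable; write b = b + other instead of b += other"
        )


a = FrozenArray([10, 20])
b = a
try:
    b += 10
except TypeError as err:
    print(err)  # the error is explicit instead of a silent rebind

c = a + 10
print(c.values)  # (20, 30)
print(a.values)  # (10, 20)
```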

n2cholas commented 3 years ago

Any updates on this? I think this behaviour is a big pitfall for new users. Matching NumPy's read-only behaviour here would be great:

import numpy as np

x = np.arange(5)
x.flags.writeable = False
x += 1
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-42-4068e55d63c2> in <module>()
      3 x = np.arange(5)
      4 x.flags.writeable = False
----> 5 x += 1

ValueError: output array is read-only
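In NumPy the read-only flag blocks only writes into the existing buffer. Out-of-place expressions still work and return a fresh, writeable array, which is the semantics being requested for +=. Here is a short NumPy sketch of both cases:

```python
import numpy as np

x = np.arange(5)
x.flags.writeable = False

try:
    x += 1  # in-place: NumPy tries to write into x's buffer
except ValueError as err:
    print(err)

y = x + 1  # out-of-place: allocates a new array
print(y)                  # [1 2 3 4 5]
print(y.flags.writeable)  # True
```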