jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Converting from a list to jnp array is unnecessarily slow #19500

Open timellemeet opened 9 months ago

timellemeet commented 9 months ago


Given an arbitrary native Python list, for example test_list = list(range(50000)), we found a more than 20x performance difference between converting it to a jnp array via NumPy first and converting it directly.

This seems quite unnecessary and might be unexpected for users.
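A minimal reproduction sketch of the comparison described above. The function names direct and via_numpy are hypothetical, block_until_ready() is used so that asynchronous dispatch does not hide the conversion cost, and no timings are shown here because they depend on hardware and JAX version.

```python
import timeit

import jax.numpy as jnp
import numpy as np

test_list = list(range(50000))

def direct():
    # jnp.array receives the raw Python list and must inspect its elements itself
    return jnp.array(test_list).block_until_ready()

def via_numpy():
    # NumPy converts the list first; JAX then receives a single ndarray
    return jnp.array(np.array(test_list)).block_until_ready()

t_direct = timeit.timeit(direct, number=10)
t_numpy = timeit.timeit(via_numpy, number=10)
print(f"direct: {t_direct:.3f}s  via numpy: {t_numpy:.3f}s")
```

Both paths produce the same values and dtype; only the conversion cost differs.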

jakevdp commented 9 months ago

This is expected, unfortunately. If you want fast array conversion for large Python lists, I'd suggest converting to numpy array first.

timellemeet commented 9 months ago

But would it not make sense to convert it like this within the library in that case?

jakevdp commented 9 months ago

We do convert to numpy array first if it's possible, but checking whether that is possible is what takes extra time (a cost that scales linearly with the size of the list).

In general, working with lists in JAX is an anti-pattern, and we raise errors in this case nearly everywhere (e.g. jnp.sum([1, 2, 3]) will error). jnp.array is one place where we've allowed list inputs, mainly for convenience.
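A short sketch of the behavior mentioned above: reductions such as jnp.sum reject a plain list with a TypeError, while converting it explicitly with jnp.array first works. The variable names are illustrative only.

```python
import jax.numpy as jnp

try:
    jnp.sum([1, 2, 3])  # list input is rejected by jnp.sum
    list_raised = False
except TypeError as err:
    list_raised = True
    print("TypeError:", err)

# Converting explicitly first is the supported pattern
total = jnp.sum(jnp.array([1, 2, 3]))
print(total)
```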

You might wonder why we don't do something like:

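```python
import numpy as np

def array(x):
    try:
        x = np.asarray(x)  # fast path: let NumPy convert the list in C
    except Exception:
        pass  # fall back to JAX's own handling
    # ...continue with JAX logic
```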

The problem is that there are cases where this would succeed but do semantically the wrong thing. For example:

```python
import jax.numpy as jnp

x = jnp.zeros(10)
y = jnp.ones(10)
jnp.array([x, y])
```

np.asarray would succeed here, but would do so by transferring the contents of x and y to the host via their __array__ method, which in general is not something we want to do silently.
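A sketch of the difference, assuming a recent JAX where jax.Array and jax.errors.TracerArrayConversionError exist; stack_jnp and stack_np are hypothetical helper names. Outside of jit, np.asarray quietly produces a host NumPy array; inside jit the elements are tracers, whose __array__ raises, while jnp.array handles them and keeps the result in JAX.

```python
import jax
import jax.numpy as jnp
import numpy as np

x = jnp.zeros(10)
y = jnp.ones(10)

host = np.asarray([x, y])  # copies x and y to host memory via __array__
dev = jnp.array([x, y])    # result stays a jax.Array

@jax.jit
def stack_jnp(a, b):
    return jnp.array([a, b])  # tracers are handled by JAX

@jax.jit
def stack_np(a, b):
    return np.asarray([a, b])  # tracers cannot be converted to NumPy arrays

try:
    stack_np(x, y)
    np_raised = False
except jax.errors.TracerArrayConversionError:
    np_raised = True

print(type(host).__name__, isinstance(dev, jax.Array), np_raised)
```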

So in jnp.array we first check whether there are any arrays or tracers in the list, and this check ends up being expensive for long lists.
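To illustrate why such a check scales linearly, here is a hypothetical helper (not JAX's internal implementation) that walks every leaf of a possibly nested list looking for JAX values. It relies on isinstance(..., jax.Array), which is also True for tracers in recent JAX versions.

```python
import jax
import jax.numpy as jnp

def contains_jax_values(obj) -> bool:
    """Hypothetical check: True if any leaf of obj is a jax.Array (or tracer).

    jax.tree_util.tree_leaves flattens the whole structure up front, so the
    cost is proportional to the number of elements even when the answer is
    False, which is exactly the case for a long list of plain Python ints.
    """
    return any(isinstance(leaf, jax.Array) for leaf in jax.tree_util.tree_leaves(obj))

print(contains_jax_values(list(range(5))))              # plain ints only
print(contains_jax_values([jnp.zeros(3), jnp.ones(3)]))  # JAX arrays present
```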

The only alternative would be to error on all list inputs to jnp.array, but we've judged that the convenience of allowing lists is worth the potential performance foot-gun.

What do you think?

timellemeet commented 9 months ago

Personally, I would go for consistency and raise an error here as well. I think anyone using JAX is old and wise enough to know how to handle that error and convert the list to NumPy first. That would be more convenient than finding out about this performance difference the hard way.

If not, I would at least add a prominent note about it in the docs, especially if this is the only place where lists are allowed.

Just my two cents.