I think the problem is in `np.cumsum`. The following:

```python
big_x = onp.random.randn(size)
np.cumsum(big_x, axis=0)
```

is much slower than `onp.cumsum`. I have checked on CPU for now.
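A rough way to time the two, assuming `size = int(1e7)` (chosen to match the 1e7 scale discussed below) and using `block_until_ready()` so the asynchronous JAX call is measured to completion, not just dispatched:

```python
import time
import numpy as onp
import jax.numpy as np

size = int(1e7)           # assumed size, matching the 1e7 scale discussed below
big_x = onp.random.randn(size)

t0 = time.perf_counter()
onp.cumsum(big_x, axis=0)
print("numpy cumsum:", time.perf_counter() - t0)

t0 = time.perf_counter()
np.cumsum(big_x, axis=0).block_until_ready()   # wait for the async result
print("jax cumsum:  ", time.perf_counter() - t0)
```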
@hawkinsp, do you have any advice on what I should try next?
I think the issue is that the algorithm implemented by `cumsum` is quadratic time if XLA implements it naively: https://github.com/google/jax/blob/6b157ff91cd9b0030e62b43e857fcecc32cfdf8b/jax/numpy/lax_numpy.py#L1544
I think only the TPU does something smarter than the naive implementation, so it's unsurprising this takes forever with an input size of 1e7 on CPU and GPU.
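To see where the quadratic behavior comes from: without a smarter lowering, producing each output element redoes the work of summing its entire prefix, i.e. the equivalent of this pure-NumPy sketch (illustrative only, not the actual XLA code):

```python
import numpy as onp

def naive_cumsum(x):
    # O(n^2): every output element re-sums its whole prefix, which is what
    # the naive lowering amounts to on CPU/GPU.
    return onp.array([x[: i + 1].sum() for i in range(len(x))])
```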
My personal temptation here would be to try implementing a Blelloch-style parallel sum scan algorithm using gather and scatter-add. There's a good blog post explaining them here: https://developer.nvidia.com/gpugems/gpugems3/part-vi-gpu-computing/chapter-39-parallel-prefix-sum-scan-cuda
I'm also curious what happens if you autodiff the Blelloch algorithm. Currently `cumprod` isn't arbitrarily differentiable, but I don't see any fundamental reason one cannot differentiate through the Blelloch algorithm at a 2n log(n) space cost, which doesn't seem unreasonable for reverse-mode autodiff.
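For concreteness, here is a rough sketch of the up-sweep/down-sweep structure written with gathers and scatter-adds via the `.at[...]` update syntax, assuming a power-of-two length (illustrative only, not a tuned implementation, and whether XLA fuses it well is a separate question):

```python
import jax.numpy as jnp

def blelloch_scan(x):
    """Work-efficient exclusive prefix sum; assumes len(x) is a power of two."""
    n = x.shape[0]
    d = 1
    # Up-sweep: build a reduction tree of partial sums in place.
    while d < n:
        idx = jnp.arange(2 * d - 1, n, 2 * d)
        x = x.at[idx].add(x[idx - d])   # scatter-add of gathered left partners
        d *= 2
    x = x.at[n - 1].set(0)              # clear the root (identity element)
    # Down-sweep: push prefixes back down the tree.
    while d > 1:
        d //= 2
        idx = jnp.arange(2 * d - 1, n, 2 * d)
        left = x[idx - d]               # gather
        x = x.at[idx - d].set(x[idx])
        x = x.at[idx].add(left)
    return x  # exclusive scan: out[i] = sum of inputs before position i
```

An inclusive scan can then be recovered by adding the input back, e.g. `blelloch_scan(x) + x`.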
Thank you for your insight and the link to the algorithm! Here's a version of `cumprod` that has asymptotically better complexity, although it only works for power-of-two array sizes:
```python
import math

import jax
import jax.numpy as jnp
from jax import lax


@jax.jit
def cumprod_v2(z):
    n = len(z)
    log2 = int(math.log2(n - 1))
    zs = []
    # Up-sweep: repeatedly combine adjacent pairs, keeping the even-indexed
    # halves around for the down-sweep.
    for d in range(0, log2):
        z1 = lax.slice(z, (0,), (len(z),), (2,))
        zs.append(z1)
        z2 = lax.slice(z, (1,), (len(z),), (2,))
        z = z1 * z2
    zs.append(lax.slice(z, (0,), (len(z),), (2,)))
    dtype = jnp.dtype(z.dtype).type
    z = jnp.array([1], dtype=dtype)
    # Down-sweep: the `+` only interleaves the two zero-padded arrays;
    # the combine operation is still the `*` applied to z2 and w.
    for w in reversed(zs):
        z1 = lax.pad(z, dtype(0), ((0, 1, 1),))
        z2 = lax.pad(z, dtype(0), ((1, 0, 1),))
        w = lax.pad(w, dtype(0), ((1, 0, 1),))
        z = z1 + z2 * w
    return z
```
It also has the advantage of most likely being a lot easier to differentiate than the current implementation. `cumsum` is very similar; see the sketch below.
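For reference, the `cumsum` variant only changes the combine operation and the identity element (untested sketch; as far as I can tell, both versions produce an exclusive scan that starts at the identity element):

```python
@jax.jit
def cumsum_v2(z):
    n = len(z)
    log2 = int(math.log2(n - 1))
    zs = []
    for d in range(0, log2):
        z1 = lax.slice(z, (0,), (len(z),), (2,))
        zs.append(z1)
        z2 = lax.slice(z, (1,), (len(z),), (2,))
        z = z1 + z2                      # combine with + instead of *
    zs.append(lax.slice(z, (0,), (len(z),), (2,)))
    dtype = jnp.dtype(z.dtype).type
    z = jnp.array([0], dtype=dtype)      # identity element is 0
    for w in reversed(zs):
        z1 = lax.pad(z, dtype(0), ((0, 1, 1),))
        z2 = lax.pad(z, dtype(0), ((1, 0, 1),))
        w = lax.pad(w, dtype(0), ((1, 0, 1),))
        z = z1 + z2 + w                  # interleave and combine with +
    return z
```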
I believe this is now fixed at head. Let me know how it goes!
It works and is super fast, thank you!
Dear jax team,

I implemented `np.unwrap` by looking at the numpy source and changing the in-place-modification bits to `np.where` and `np.concatenate`. While the values are correct, the runtime explodes for `x.size > 1e7` and is generally a lot worse than the numpy version. See repro below.

Now, I don't want to waste your time with some micro-optimizations for my specific code. But are there some general performance tips here? I guess the `where` and `concatenate` are problematic, but I wouldn't know how to improve this in the JAX framework. Sorry for the code being so cryptic...

Repro:
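(The original repro snippet did not come through above. Purely as an illustration of the kind of port described, not the original code, a 1-D `np.unwrap` written with `where`/`concatenate` instead of in-place updates might look like the following; the `cumsum` at the end is the part the discussion above tracks down as the bottleneck.)

```python
import jax.numpy as jnp

def unwrap_jax(p, discont=jnp.pi):
    # 1-D sketch of numpy.unwrap with the in-place updates replaced by
    # jnp.where / jnp.concatenate.
    dd = jnp.diff(p)
    ddmod = jnp.mod(dd + jnp.pi, 2.0 * jnp.pi) - jnp.pi
    # keep the sign of +pi jumps consistent with numpy's behaviour
    ddmod = jnp.where((ddmod == -jnp.pi) & (dd > 0), jnp.pi, ddmod)
    ph_correct = ddmod - dd
    ph_correct = jnp.where(jnp.abs(dd) < discont, 0.0, ph_correct)
    # prepend the first element unchanged instead of writing in place
    return jnp.concatenate([p[:1], p[1:] + jnp.cumsum(ph_correct)])
```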