jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

np.unwrap runtime explodes #2418

Closed: clemisch closed this issue 4 years ago

clemisch commented 4 years ago

Dear jax team,

I implemented np.unwrap by following the NumPy source and replacing the in-place modifications with np.where and np.concatenate.
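As a minimal illustration of that kind of rewrite (the array values are made up), NumPy's masked in-place assignment and the functional `jnp.where` form produce the same result. JAX arrays are immutable, so the JAX version builds a new array instead of mutating one:

```python
import numpy as onp
import jax.numpy as jnp

ddmod = onp.array([-onp.pi, 0.5, -onp.pi, 1.0])
dd = onp.array([3.5, 0.5, -3.0, 1.0])
mask = (ddmod == -onp.pi) & (dd > 0)

# NumPy style: overwrite the masked entries in place.
ddmod_np = ddmod.copy()
ddmod_np[mask] = onp.pi

# JAX style: select between the replacement value and the old array.
ddmod_jax = jnp.where(mask, onp.pi, ddmod)
```

Only the first entry satisfies both conditions, so it is the only one changed to pi in either version.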

While the values are correct, the runtime explodes for x.size > 1e7 and is generally much worse than NumPy's version. See the repro below.

Now, I don't want to waste your time with micro-optimizations for my specific code. But are there any general performance tips here? I suspect the where and concatenate calls are the problem, but I wouldn't know how to improve this within JAX. Sorry that the code is so cryptic...

Repro:

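from functools import partial

import jax
import jax.numpy as np
import numpy as onp

# `axis` is static because it determines the slice shapes below.
@partial(jax.jit, static_argnums=1)
def unwrap(p, axis):
    # Differences between consecutive samples along `axis`.
    dd = np.diff(p, axis=axis)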

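    # Wrap each difference into [-pi, pi).
    ddmod = np.mod(dd + np.pi, 2 * np.pi) - np.pi
    # Keep NumPy's convention: a positive jump that wrapped to -pi becomes +pi.
    ddmod = np.where(
        np.isclose(ddmod, -np.pi) & (dd > 0),
        np.pi,
        ddmod)

    # Correct only the jumps whose magnitude is at least pi.
    ph_correct = np.where(
        np.abs(dd) < np.pi,
        0,
        ddmod - dd)

    # Keep the first sample; shift every later sample by the running correction.
    up = np.concatenate((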
        jax.lax.slice_in_dim(p, 0, 1, axis=axis),
        jax.lax.slice_in_dim(p, 1, None, axis=axis) + np.cumsum(ph_correct, axis=axis)
    ), axis=axis)

    return up

# OKAY
x = onp.random.randn(1000) * 10
assert onp.allclose(
    onp.unwrap(x),
    unwrap(x, 0),
    atol=1e-3, rtol=1e-3
)

# NOT OKAY
x = onp.random.randn(int(1e7)) * 10
unwrap(x, 0).block_until_ready()
gnecula commented 4 years ago

I think that the problem is in np.cumsum. The following:

size = int(1e7)
big_x = onp.random.randn(size)
np.cumsum(big_x, axis=0).block_until_ready()

is much slower than onp.cumsum. I have only checked on CPU so far. @hawkinsp, do you have any advice on what I should try next?
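A minimal timing harness for that comparison might look like this (the helper name `time_cumsum` is hypothetical). It runs the jitted function once before timing so compilation is excluded, and calls `block_until_ready()` because JAX dispatches work asynchronously:

```python
import time

import numpy as onp
import jax
import jax.numpy as jnp

def time_cumsum(size, repeats=3, seed=0):
    """Return (jax_seconds, numpy_seconds, jax_result, numpy_result), seconds per call."""
    x = onp.random.default_rng(seed).standard_normal(size).astype(onp.float32)
    f = jax.jit(lambda a: jnp.cumsum(a, axis=0))
    x_dev = jnp.asarray(x)
    f(x_dev).block_until_ready()  # warm-up: exclude compilation from the timing

    t0 = time.perf_counter()
    for _ in range(repeats):
        out = f(x_dev).block_until_ready()  # wait for the device to finish
    jax_s = (time.perf_counter() - t0) / repeats

    t0 = time.perf_counter()
    for _ in range(repeats):
        ref = onp.cumsum(x, axis=0)
    np_s = (time.perf_counter() - t0) / repeats
    return jax_s, np_s, out, ref
```

Calling `time_cumsum(int(1e7))` reproduces the comparison at the size from the report.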

hawkinsp commented 4 years ago

I think the issue is that the algorithm implemented by cumsum takes quadratic time if XLA implements it naively: https://github.com/google/jax/blob/6b157ff91cd9b0030e62b43e857fcecc32cfdf8b/jax/numpy/lax_numpy.py#L1544

I think only the TPU does something smarter than the naive implementation, so it's unsurprising this takes forever with an input size of 1e7 on CPU and GPU.
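To see how a naive evaluation becomes quadratic, here is a minimal sketch that expresses a 1-D cumulative sum as a sliding-window reduction with `lax.reduce_window`. This illustrates the general idea and is an assumption about the approach, not a copy of the linked code. The input is left-padded with n - 1 zeros and each output element sums a window of n inputs, so evaluating every window directly costs O(n^2) additions:

```python
import numpy as onp
import jax.numpy as jnp
from jax import lax

def cumsum_reduce_window(x):
    """1-D cumulative sum as a windowed reduction (O(n^2) work if evaluated naively)."""
    n = x.shape[0]
    return lax.reduce_window(
        x,
        0.0,            # init value: identity for addition
        lax.add,        # reduction applied within each window
        (n,),           # every window spans n elements
        (1,),           # stride 1: one output per input position
        ((n - 1, 0),),  # left-pad with n - 1 zeros so window i ends at x[i]
    )
```

For `x = [1, 2, 3, 4]` the padded input is `[0, 0, 0, 1, 2, 3, 4]`, and the four length-4 windows sum to `[1, 3, 6, 10]`.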

My personal inclination here would be to try implementing a Blelloch-style parallel prefix-sum (scan) algorithm using gather and scatter-add. There's a good explanation of these algorithms here: https://developer.nvidia.com/gpugems/gpugems3/part-vi-gpu-computing/chapter-39-parallel-prefix-sum-scan-cuda
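For reference, here is a minimal sketch of the classic Blelloch scan (an up-sweep reduction followed by a down-sweep) for power-of-2 lengths. It uses strided `.at[...]` indexed updates instead of hand-written gather/scatter-add calls. The function name and the power-of-2 restriction are my own assumptions, not JAX API. Across its O(log n) levels it performs O(n) additions in total:

```python
import numpy as onp
import jax
import jax.numpy as jnp

@jax.jit
def blelloch_cumsum(x):
    """Inclusive prefix sum via Blelloch up-sweep/down-sweep; len(x) must be a power of 2."""
    n = x.shape[0]
    assert n > 0 and n & (n - 1) == 0, "length must be a power of 2"
    a = x
    # Up-sweep: each right child accumulates its left sibling; a[n - 1] ends as the total.
    stride = 2
    while stride <= n:
        half = stride // 2
        a = a.at[stride - 1::stride].add(a[half - 1::stride])
        stride *= 2
    # Down-sweep: clear the root, then push partial sums back down the tree.
    a = a.at[n - 1].set(0)
    stride = n
    while stride >= 2:
        half = stride // 2
        left = a[half - 1::stride]
        right = a[stride - 1::stride]
        a = a.at[half - 1::stride].set(right)
        a = a.at[stride - 1::stride].set(left + right)
        stride //= 2
    # `a` now holds the exclusive scan; adding x gives the inclusive one.
    return a + x
```

The loops run in Python at trace time, so `jit` unrolls them into a fixed sequence of O(log n) strided updates.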

hawkinsp commented 4 years ago

I'm also curious what happens if you autodiff the Blelloch algorithm. Currently cumprod isn't arbitrarily differentiable, but I don't see any fundamental reason why one cannot differentiate through the Blelloch algorithm at a 2n log(n) space cost, which doesn't seem unreasonable for reverse-mode autodiff.
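A minimal sketch of that idea uses the simpler Hillis-Steele log-depth scan (a different algorithm from Blelloch's) for a prefix product. Because it is built only from slicing, concatenation and multiplication, `jax.grad` differentiates straight through it, and reverse mode keeps roughly one length-n intermediate per level, i.e. O(n log n) memory. The function name is hypothetical:

```python
import numpy as onp
import jax
import jax.numpy as jnp

def log_depth_cumprod(x):
    """Inclusive prefix product in ceil(log2(n)) steps (Hillis-Steele scan)."""
    n = x.shape[0]
    shift = 1
    while shift < n:
        # Multiply each element by the one `shift` positions earlier (1 if there is none).
        shifted = jnp.concatenate([jnp.ones(shift, x.dtype), x[:-shift]])
        x = x * shifted
        shift *= 2
    return x

# Reverse-mode gradient of a scalar loss through the scan.
grad_fn = jax.grad(lambda x: log_depth_cumprod(x).sum())
```

Hillis-Steele does O(n log n) multiplications instead of Blelloch's O(n), but its regular structure makes the autodiff behaviour easy to see.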

clemisch commented 4 years ago

Thank you for your insight and the link to the algorithm!

hawkinsp commented 4 years ago

Here's a version of cumprod that has asymptotically better complexity, although it only works for power-of-2 array sizes:

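import math

import jax
import jax.numpy as jnp
from jax import lax

@jax.jit
def cumprod_v2(x):
  # Only valid when len(x) is a power of 2 (and at least 2).
  n = len(x)
  log2 = int(math.log2(n - 1))
  z = x
  zs = []
  # Up-sweep: repeatedly multiply adjacent pairs, remembering each level's evens.
  for d in range(0, log2):
    z1 = lax.slice(z, (0,), (len(z),), (2,))  # even-indexed elements
    zs.append(z1)
    z2 = lax.slice(z, (1,), (len(z),), (2,))  # odd-indexed elements
    z = z1 * z2                                # pairwise products, half the length
  zs.append(lax.slice(z, (0,), (len(z),), (2,)))

  # Down-sweep: rebuild prefix products level by level, doubling the length each time.
  dtype = jnp.dtype(z.dtype).type
  z = jnp.array([1], dtype=dtype)
  for w in reversed(zs):
    z1 = lax.pad(z, dtype(0), ((0, 1, 1),))  # parent prefix in even slots
    z2 = lax.pad(z, dtype(0), ((1, 0, 1),))  # parent prefix in odd slots
    w = lax.pad(w, dtype(0), ((1, 0, 1),))   # left child in odd slots
    z = z1 + z2 * w
  # z now holds the exclusive prefix products [1, x0, x0*x1, ...];
  # multiplying by x gives the inclusive result that jnp.cumprod returns.
  return z * x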

It also has the advantage of most likely being a lot easier to differentiate than the current implementation.

cumsum is very similar.
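Following that remark, here is a sketch of a cumsum counterpart to `cumprod_v2` (the name `cumsum_v2` is hypothetical). The up-sweep uses pairwise sums, the identity starts at 0 instead of 1, and in the down-sweep each odd slot becomes parent + left child, so the combine step is `z1 + z2 + w` instead of `z1 + z2 * w`. It shares the power-of-2 restriction:

```python
import math

import numpy as onp
import jax
import jax.numpy as jnp
from jax import lax

@jax.jit
def cumsum_v2(x):
  # Assumes len(x) is a power of 2 and at least 2.
  n = len(x)
  log2 = int(math.log2(n - 1))
  z = x
  zs = []
  for d in range(0, log2):
    z1 = lax.slice(z, (0,), (len(z),), (2,))  # even-indexed elements
    zs.append(z1)
    z2 = lax.slice(z, (1,), (len(z),), (2,))  # odd-indexed elements
    z = z1 + z2                                # pairwise sums, half the length
  zs.append(lax.slice(z, (0,), (len(z),), (2,)))

  dtype = jnp.dtype(z.dtype).type
  z = jnp.array([0], dtype=dtype)              # identity for addition
  for w in reversed(zs):
    z1 = lax.pad(z, dtype(0), ((0, 1, 1),))    # parent prefix in even slots
    z2 = lax.pad(z, dtype(0), ((1, 0, 1),))    # parent prefix in odd slots
    w = lax.pad(w, dtype(0), ((1, 0, 1),))     # left child in odd slots
    z = z1 + z2 + w
  # z holds exclusive prefix sums; adding x makes them inclusive like jnp.cumsum.
  return z + x
```

For `x = [1, ..., 8]` the down-sweep produces the exclusive sums `[0, 1, 3, 6, 10, 15, 21, 28]`, and adding `x` gives `[1, 3, 6, 10, 15, 21, 28, 36]`.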

hawkinsp commented 4 years ago

I believe this is now fixed at head. Let me know how it goes!
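For readers arriving later: current JAX releases also provide `jax.numpy.unwrap` directly, so a hand-written version is no longer needed. Here is a quick sanity check on a smoothly increasing phase that has been wrapped into (-pi, pi] (the test signal is made up):

```python
import numpy as onp
import jax.numpy as jnp

# A phase ramp from 0 to 20 rad, wrapped into (-pi, pi] via the complex angle.
phase = onp.linspace(0.0, 20.0, 200)
wrapped = onp.angle(onp.exp(1j * phase))

unwrapped_jax = jnp.unwrap(wrapped)  # computed in float32 unless x64 is enabled
unwrapped_np = onp.unwrap(wrapped)
```

Both versions should recover the original ramp, with the JAX result agreeing to float32 precision.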

clemisch commented 4 years ago

It works and is super fast, thank you!