google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

jax.lax.scan for recurrent models on very long sequences (does hardware acceleration help?) #7009

Open noahtren opened 3 years ago

noahtren commented 3 years ago

I'm curious if jax.lax.scan can reliably deliver high speeds on GPU/TPU for long sequences, in the range of 10k+ samples. For example, multiple seconds of audio will be 30k+ samples. I want to know if we can expect recurrent-ish models written in JAX to be as fast or faster than CUDA RNNs.

I've been finding that hardware acceleration with my scan-based code either causes a slowdown (GPU) or makes a very minor difference in speed (TPU). For my training loop I got these results:

| Device | Speed    |
|--------|----------|
| CPU    | 3.4 it/s |
| GPU    | 0.8 it/s |
| TPU    | 3.6 it/s |

These functions represent the bulk of what my code is doing — I'd be curious to know if there's something more efficient than dynamic slicing and dynamic updates.

import jax
import numpy as np

sample_length = 10_000
receptive_field = 50

def recurrent_forward(args, t):
  [nkey, memory, params] = args
  # Read a window of `receptive_field` samples starting at t, run the model
  # on it, then write the model's output back into the buffer at position t.
  memory_slice = jax.lax.dynamic_slice(memory, [t], [receptive_field])
  x = model.apply({'params': params}, memory_slice, t)
  memory = jax.lax.dynamic_update_index_in_dim(memory, x, t, 0)
  # scan expects the body to return (carry, output); only the carry is needed here.
  return [nkey, memory, params], None

def unroll_recurrent_forward(nkey, memory, params):
  [nkey, memory, params], _ = jax.lax.scan(recurrent_forward, [nkey, memory, params],
                                           np.arange(sample_length))
  return memory
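
For reference, a minimal sketch of how one might drive and time this loop, assuming `params` comes from `model.init(...)` as in the snippet above; the buffer shape and step count are placeholders, and `block_until_ready` matters because JAX dispatches device work asynchronously:

import time
import jax
import jax.numpy as jnp

# Hypothetical driver: jit the unrolled scan and time it fairly by blocking
# on the result before reading the clock.
unroll_jit = jax.jit(unroll_recurrent_forward)

nkey = jax.random.PRNGKey(0)
memory = jnp.zeros(sample_length + receptive_field)
# `params` would come from model.init(...) in the real code.

out = unroll_jit(nkey, memory, params)   # first call compiles
out.block_until_ready()

start = time.perf_counter()
out = unroll_jit(nkey, memory, params)
out.block_until_ready()                  # wait for the device to finish
print(f"one forward pass: {time.perf_counter() - start:.3f} s")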


noahtren commented 3 years ago

Some more info:

noahtren commented 3 years ago

Related issue: https://github.com/google/jax/issues/2491

jekbradbury commented 3 years ago

The code snippet you posted isn't self-contained enough; we'd need either your model.apply code (presumably a Flax neural net) or a representative replacement for it.

mattjj commented 3 years ago

@noahtren thanks for raising this. Let's try to sort out what's going on!

The most helpful thing would be if you could grab a profile and share it. Does that sound feasible?

(A priori, the slicing and updating you're doing looks like the right way to write things!)
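
For anyone following along, one way to capture such a profile is JAX's built-in TensorBoard tracing; a minimal sketch, reusing the jitted forward pass and inputs from the snippet above as placeholders (the log directory is also a placeholder):

import jax

# Capture a TensorBoard-viewable trace of a few steps. "/tmp/jax-trace" is a
# placeholder log directory; inspect it with the TensorBoard profile plugin
# (`tensorboard --logdir /tmp/jax-trace`).
with jax.profiler.trace("/tmp/jax-trace"):
    for _ in range(3):
        out = unroll_jit(nkey, memory, params)
    out.block_until_ready()  # ensure the device work lands inside the trace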

noahtren commented 3 years ago

> The code snippet you posted isn't self-contained enough; we'd need either your model.apply code (presumably a Flax neural net) or a representative replacement for it.

Here's a gist of the full forward pass. I refactored the earlier code to use Flax's lifted transformations (nn.scan and nn.remat, which saves a lot of memory), but speed is roughly the same as before.

You'll see that I'm doing a couple of dynamic slices and one dynamic update for each iteration. Beyond that, the ops are mainly MLPs and reshape/reduce operations. The latent_code is provided by another convolutional model that runs very quickly. Interestingly, increasing num_modules (and thus the size of some MLP layers) slows this down quite a bit, which surprised me because I expected it to affect only memory usage, not speed.

Other notes: num_modules has generally been from 4 to 12 in my experiments. receptive_field has been 80 to 160. sample_length has been 640 to 1280.
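
Since the gist itself isn't reproduced here, the following is only a rough sketch of the pattern being described (Flax's lifted nn.scan wrapped around nn.remat, with a dynamic slice and a dynamic update per step); the cell body, layer sizes, and module names are stand-ins, not the actual model:

import flax.linen as nn
import jax
import jax.numpy as jnp

class Cell(nn.Module):
  receptive_field: int

  @nn.compact
  def __call__(self, memory, t):
    # Read a window of the buffer, run a small MLP on it (stand-in for the
    # real per-step computation), and write the result back at position t.
    window = jax.lax.dynamic_slice(memory, (t,), (self.receptive_field,))
    x = nn.Dense(1)(nn.relu(nn.Dense(64)(window)))
    memory = jax.lax.dynamic_update_index_in_dim(memory, x[0], t, 0)
    return memory, None

class UnrolledModel(nn.Module):
  sample_length: int
  receptive_field: int

  @nn.compact
  def __call__(self, memory):
    # nn.scan lifts lax.scan over a module; nn.remat rematerializes
    # activations on the backward pass, trading compute for memory.
    ScanCell = nn.scan(
        nn.remat(Cell),
        variable_broadcast='params',
        split_rngs={'params': False},
    )
    memory, _ = ScanCell(self.receptive_field)(
        memory, jnp.arange(self.sample_length))
    return memory

Broadcasting the parameters over the scan (variable_broadcast='params') keeps a single set of weights shared across all timesteps, while the remat wrapper is what saves the activation memory mentioned above.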

noahtren commented 3 years ago

> @noahtren thanks for raising this. Let's try to sort out what's going on!
>
> The most helpful thing would be if you could grab a profile and share it. Does that sound feasible?
>
> (A priori, the slicing and updating you're doing looks like the right way to write things!)

Thanks! I can certainly profile this. What's the easiest way to share a profile? (I could share a path on a public cloud bucket with my TensorBoard logs, which are <1GB.)