Open · noahtren opened this issue 3 years ago
Some more info:

Removing the jax.random.split call inside each recurrent step actually increased the speed by 50%. I also tried a version using tf.function and tf.while_loop, and it is still slower than my model in JAX. This makes me think that custom recurrent loops are just hard to write efficiently in general?

Related issue: https://github.com/google/jax/issues/2491
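As a reference point for the jax.random.split observation, here is a minimal sketch of the usual way to avoid splitting a key inside every step: pre-split all the per-step keys outside the loop and scan over them. step_fn and rollout are hypothetical stand-ins for illustration, not the model from this issue.

```python
from functools import partial

import jax
import jax.numpy as jnp

# Hypothetical per-step computation standing in for the real recurrent cell.
def step_fn(carry, key):
    noise = jax.random.normal(key, carry.shape)
    return carry + 0.1 * noise, noise

@partial(jax.jit, static_argnames="num_steps")
def rollout(key, init, num_steps):
    # Split all per-step keys once, outside the loop, and scan over them,
    # rather than calling jax.random.split inside every step.
    keys = jax.random.split(key, num_steps)
    return jax.lax.scan(step_fn, init, keys)

final, outputs = rollout(jax.random.PRNGKey(0), jnp.zeros((8,)), 640)
```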
The code snippet you posted isn't self-contained enough; we'd need either your model.apply code (presumably a Flax neural net) or a representative replacement for it.
@noahtren thanks for raising this. Let's try to sort out what's going on!
The most helpful thing would be if you could grab a profile and share it. Does that sound feasible?
(A priori, the slicing and updating you're doing looks like the right way to write things!)
> The code snippet you posted isn't self-contained enough; we'd need either your model.apply code (presumably a Flax neural net) or a representative replacement for it.
Here's a gist of the full forward pass. I refactored the earlier code to use Flax's lifted transformations (nn.scan and nn.remat, which saves a lot of memory), but the speed is roughly the same as before.
You'll see that I'm doing a couple of dynamic slices and one dynamic update for each iteration. Besides that, the ops are mainly MLPs and reshaping/reduce operations. The latent_code is provided by another convolutional model that runs very quickly. Interestingly, increasing num_modules (and thus the size of some MLP layers) does slow this down quite a bit, which surprised me because I imagined it would only affect memory usage, not speed.
Other notes: num_modules has generally been from 4 to 12 in my experiments, receptive_field from 80 to 160, and sample_length from 640 to 1280.
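For readers without access to the gist, here's a rough sketch of the pattern being described: an nn.remat'd cell run under nn.scan, with one dynamic slice and one dynamic update per step. The Cell and Decoder modules, the MLP body, and the 1-D latent_code are stand-ins made up for illustration; only the receptive_field and sample_length fields follow the numbers in this thread.

```python
import jax
import jax.numpy as jnp
import flax.linen as nn


class Cell(nn.Module):
    """One autoregressive step: slice a window, run a small MLP, write one sample back.

    A made-up stand-in for the real cell in the gist, kept only to show the
    dynamic_slice / dynamic_update_slice pattern inside nn.scan.
    """
    receptive_field: int
    features: int

    @nn.compact
    def __call__(self, carry, t):
        buffer, latent_code = carry  # buffer: (receptive_field + sample_length,), latent_code: 1-D
        # Dynamic slice: read the receptive_field samples preceding position t.
        window = jax.lax.dynamic_slice_in_dim(buffer, t, self.receptive_field, axis=0)
        h = nn.Dense(self.features)(jnp.concatenate([window, latent_code]))
        h = nn.relu(h)
        sample = nn.Dense(1)(h)
        # Dynamic update: write the newly generated sample back into the buffer.
        buffer = jax.lax.dynamic_update_slice_in_dim(
            buffer, sample, t + self.receptive_field, axis=0)
        return (buffer, latent_code), sample


class Decoder(nn.Module):
    receptive_field: int = 128
    sample_length: int = 1024
    features: int = 256

    @nn.compact
    def __call__(self, latent_code):
        # nn.remat recomputes the cell on the backward pass (saving activation memory);
        # nn.scan broadcasts the params across steps and runs the sequential loop.
        ScanCell = nn.scan(
            nn.remat(Cell),
            variable_broadcast="params",
            split_rngs={"params": False},
        )
        cell = ScanCell(receptive_field=self.receptive_field, features=self.features)
        buffer = jnp.zeros((self.receptive_field + self.sample_length,))
        (_, _), samples = cell((buffer, latent_code), jnp.arange(self.sample_length))
        return samples.squeeze(-1)


# Hypothetical initialization with a 64-dim latent code:
# params = Decoder().init(jax.random.PRNGKey(0), jnp.zeros((64,)))
```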
> @noahtren thanks for raising this. Let's try to sort out what's going on!
> The most helpful thing would be if you could grab a profile and share it. Does that sound feasible?
> (A priori, the slicing and updating you're doing looks like the right way to write things!)
Thanks! I can profile this for sure; what's the easiest way to share a profile? (I could share a path on a public cloud bucket with my TensorBoard logs, which are <1GB.)
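One way to capture such a profile is to wrap a few training steps in jax.profiler.trace and point TensorBoard at the resulting log directory; that directory can then be copied to a shared bucket. In this sketch, train_step, params, and batches are hypothetical placeholders.

```python
import jax

# `train_step`, `params`, and `batches` are hypothetical placeholders here;
# the point is just wrapping a few steps in jax.profiler.trace.
with jax.profiler.trace("/tmp/jax-trace"):
    for batch in batches[:10]:
        params, loss = train_step(params, batch)
    loss.block_until_ready()  # make sure async dispatch finishes inside the trace
```

The trace can then be inspected locally with `tensorboard --logdir /tmp/jax-trace`.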
I'm curious if jax.lax.scan can reliably deliver high speeds on GPU/TPU for long sequences, in the range of 10k+ samples. For example, multiple seconds of audio will be 30k+ samples. I want to know if we can expect recurrent-ish models written in JAX to be as fast as or faster than CUDA RNNs.

I've been finding that hardware acceleration with my scan-based code either causes a slowdown (GPU) or makes a very minor difference in speed (TPU). For my training loop I got these results:

These functions represent the bulk of what my code is doing; I'd be curious to know if there's something more efficient than dynamic slicing and dynamic updates.
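One alternative worth sketching (not from the original post, and untested on this model): when each step only ever reads a fixed-length window of recent samples, the window itself can live in the scan carry and be shifted each step, so there is no dynamic indexing into a full-length buffer at all. Here cell_apply is a hypothetical placeholder for the per-step network.

```python
import jax
import jax.numpy as jnp

# `cell_apply` is a hypothetical placeholder for the per-step network
# (e.g. a bound Flax module returning one scalar sample).
def generate(cell_apply, init_window, sample_length):
    def step(window, _):
        sample = cell_apply(window)                           # scalar
        window = jnp.concatenate([window[1:], sample[None]])  # drop oldest, append newest
        return window, sample

    _, samples = jax.lax.scan(step, init_window, None, length=sample_length)
    return samples
```

Whether this is actually faster on GPU/TPU would need profiling; it mainly trades a dynamic slice and a dynamic update on a long buffer for a concatenate on a small array.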
@jekbradbury