honnibal opened this issue 4 years ago
In general, if the inputs to your program are padded batches, you'll be best off continuing to pass around padded/batched values and performing operations on them together (predicated if necessary). This process of padding data and masking computation shows up a lot in the NLP world and is kind of annoying to do manually, so we've worked on a prototype JAX transform for doing it automatically (called `mask`, but not particularly complete yet).
I imagine that even without `mask` you could revise thinc a bit to use fixed-size (or bucket-sized) arrays and predication under the hood, rather than always using the smallest array that fits the relevant data. You could perhaps also use `vmap` and keep batches together rather than looping over individual samples, even for non-BLAS-y operations.
I like to claim that it isn't really static shapes but rather batching that causes the need for masking/padding (so JAX being built on a compiler that needs static shapes isn't really the cause of the problem you're seeing). That claim is undermined a little by the fact that certain operations aren't any faster when batched, but JAX style is to use `vmap` anyway, because it leads to simpler code that's faster to compile than either rolled or unrolled loops.
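The contrast between looping over samples and using `vmap` can be sketched as follows. The per-example function `score` is a made-up stand-in for a small, non-BLAS-heavy operation. The loop dispatches one small computation per sample. The `vmap` version is a single batched program that `jit` compiles once for the batch shape.

```python
import numpy as np
import jax
import jax.numpy as jnp


def score(example, weights):
    """Hypothetical per-example op: a small projection followed by a softmax."""
    logits = example @ weights
    return jax.nn.softmax(logits)


weights = jnp.asarray(np.random.default_rng(0).normal(size=(5, 3)), dtype=jnp.float32)
batch = jnp.asarray(np.random.default_rng(1).normal(size=(32, 5)), dtype=jnp.float32)

# Looping: one dispatch per sample, driven from Python.
looped = jnp.stack([score(row, weights) for row in batch])

# vmap: map over rows of `batch` (axis 0) and broadcast `weights` (None).
batched = jax.jit(jax.vmap(score, in_axes=(0, None)))(batch, weights)
```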
(One way of thinking about `mask` and what it does is that the JAX embedded language is statically shaped, but `mask` lets you write programs with parametric shape polymorphism in certain arguments. The main thing `mask` gets you that simple manual padding/masking won't is composition with other JAX transformations and higher-order functions, so you'd be able to write code like `vmap(mask(scan(...)))` for an RNN over a batch of variable-length sequences.)
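The `mask` transform isn't used below. Instead, this sketch shows the "simple manual padding/masking" alternative for the same RNN case: a `lax.scan` over one padded sequence, with steps past the true length predicated out, wrapped in `vmap` over the batch. The tanh RNN cell and the names `rnn_final_state`/`batched_rnn` are assumptions for illustration.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax import lax


def rnn_final_state(params, xs, length):
    """Run a simple tanh RNN over one padded sequence xs of shape (max_len, in_dim).

    Steps at positions >= length leave the carry unchanged, so the result is
    the hidden state after the last real (unpadded) step.
    """
    W, U, b = params
    h0 = jnp.zeros(U.shape[0], dtype=xs.dtype)

    def step(h, inputs):
        x_t, t = inputs
        h_new = jnp.tanh(x_t @ W + h @ U + b)
        h = jnp.where(t < length, h_new, h)  # predication instead of a shorter loop
        return h, None

    h_final, _ = lax.scan(step, h0, (xs, jnp.arange(xs.shape[0])))
    return h_final


# Params are shared across the batch (None); sequences and lengths are mapped (0).
batched_rnn = jax.jit(jax.vmap(rnn_final_state, in_axes=(None, 0, 0)))
```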
I see what you're saying, but consider: Jax is 100 times slower than numpy here. I don't think it's quite right to say that this is a normal NLP situation: ordinarily the trade-off around padding is that you gain some efficiency, perhaps up to 2x, in exchange for using more memory and facing a different set of programming problems. The rigidity of the fixed-shape requirement means that Jax's performance characteristics are wildly different, so you really have to program quite differently for it.
As I mentioned the other day, I think the `__cuda_array_interface__` will be a very important feature for Jax, because it will give users a way to solve problems that would otherwise require a huge rewrite or prevent Jax from being viable at all.
Anyway, I'll think about the advice, thanks. I guess I've been a bit aggressive about avoiding padding in the past, as it's usually better to keep the working set small for CPU performance.
These are great points! You're 100% right that this is something JAX is unacceptably bad at, and it's on us to fix.
I'd like to try out our `mask` transform here, and/or XLA's new shape-polymorphism support. (Maybe this week, if no fires pop up...) We brainstormed a bit about this issue with some XLA devs in our chat room earlier today. We have some promising leads but nothing ready yet.
+1 to this being a great motivation for `__cuda_array_interface__`. We've started some work on that.
@mattjj Great to hear the `__cuda_array_interface__` is going ahead! I'd be glad to test it once you have a dev build.
Jax will work especially well with the cupy library, thanks to its near-identical numpy interface. Cupy also lets you solve the memory contention problem that you otherwise get when two libraries share a GPU device.
cupy allows you to set a custom allocator function, and has an `UnownedMemory` type. You can therefore prevent cupy from requesting memory directly from the device and instead have it route its memory requests via XLA. This eliminates the problem of having two memory pools. Example for cupy+PyTorch: https://github.com/explosion/thinc/blob/develop/thinc/backends/_cupy_allocators.py#L39
When the `__cuda_array_interface__` is ready, you might consider including a little cupy allocator function like this as well.
@mattjj can I ask whether the `mask` transform was ever made publicly available? It still seems relevant, and you've mentioned it in a couple of places where boolean indexing or value-dependent shapes have come up (e.g., https://github.com/google/jax/issues/5013#issuecomment-757386620).
I've been stumbling into a lot of performance problems where simple steps I hadn't paid much attention to end up dominating the runtime. I guess it's because I'm triggering a lot of compilation. I'm hoping for some advice on how to approach this.
I made a small example to illustrate the kind of case I'm hitting. I've also attached a flame-graph from the py-spy profiler.
The benchmark generates a bunch of randomly sized arrays, and then selects rows from them using random indices. Results on CPU with 1000 samples (in seconds of runtime, lower is better):
The results refer to CPU execution, but the problem is if anything worse on GPU.
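The benchmark script isn't reproduced in this thread. The following is a hypothetical reconstruction of the pattern it describes: randomly sized arrays, with rows selected using random indices. The indexing is wrapped in `jax.jit` here (an assumption; the original used eager indexing) so that a Python-side counter can make compilation visible. The counter increments only when JAX traces a new combination of input shapes, so nearly every iteration pays for a fresh compile.

```python
import numpy as np
import jax
import jax.numpy as jnp

trace_count = 0  # incremented only when jit traces (and compiles) a new shape signature


@jax.jit
def take_rows(x, idx):
    global trace_count
    trace_count += 1  # runs at trace time, i.e. once per distinct (x.shape, idx.shape)
    return x[idx]


rng = np.random.default_rng(0)
shapes_seen = set()
for _ in range(50):
    n_rows = int(rng.integers(1, 100))   # randomly sized input array
    n_idx = int(rng.integers(1, 20))     # random number of rows to select
    x = rng.normal(size=(n_rows, 8)).astype(np.float32)
    idx = rng.integers(0, n_rows, size=n_idx).astype(np.int32)
    take_rows(x, idx)
    shapes_seen.add((x.shape, idx.shape))
```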
To interpret the flame graph, look at the width of each block (the ordering and colours aren't significant). A block's children are the calls made from it, with widths proportional to their runtime. For a program like Jax, which is Python on top of native calls, we're basically interested in three things: 1) How much "leakage" is there to general Python inefficiency? This shows up as blocks narrowing down the call stack, as time is lost to overhead. 2) What are the leaves, i.e. where do we end up spending our time? 3) Which top-level calls resulted in those leaf calls?
The profile shows that there's little leakage to Jax's Python code, and we end up spending all our time compiling, triggered within calls to `_rewriting_take`.
I understand this is sort of a worst case for Jax, but my problem is that analogous situations really do come up for me (I'm writing neural network code for NLP). Even if the inputs to my program are padded batches, at various points my program might be passing around irregularly sized chunks of data, e.g. due to conditional indexing operations.
Is there anything I should be doing differently?
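One common workaround for conditional (boolean) indexing is to give the result a static upper bound, so its shape no longer depends on the data. `jnp.nonzero` accepts `size` and `fill_value` arguments that pad the index array to a fixed length. This sketch assumes a hypothetical bound `MAX_SELECTED` and a made-up selection condition (first column positive).

```python
import numpy as np
import jax
import jax.numpy as jnp

MAX_SELECTED = 16  # hypothetical static upper bound on the number of selected rows


@jax.jit
def select_positive_rows(x):
    keep = x[:, 0] > 0  # boolean condition with a data-dependent count
    # x[keep] would have a data-dependent shape and fails under jit. nonzero with a
    # static size pads the index list with fill_value (and truncates if more than
    # MAX_SELECTED entries are True).
    (idx,) = jnp.nonzero(keep, size=MAX_SELECTED, fill_value=0)
    count = keep.sum()
    valid = jnp.arange(MAX_SELECTED) < count
    rows = jnp.where(valid[:, None], x[idx], 0.0)  # zero out the padding rows
    return rows, count
```

The output always has shape `(MAX_SELECTED, width)`, so the function compiles once per input shape regardless of how many rows pass the condition. Downstream code then carries `count` (or `valid`) as a mask.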