jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Fusing loops into kernels #9080

Open EelcoHoogendoorn opened 2 years ago

EelcoHoogendoorn commented 2 years ago

I am wondering what the current state of JAX is with respect to loop fusion, or what the roadmap for it looks like.

In the simplest example, if we have an expression a * x + b of JAX scalars and we've vmapped over it, what we want to avoid is materializing intermediates of minimal compute intensity. That is not because using a fused multiply-add instruction instead will make much of a difference, but because on-device (and also on modern CPUs) code like this is almost always 100% memory-bandwidth bound.
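One way to check what XLA actually does with the vmapped a * x + b case is to look at the optimized HLO that `jax.jit` produces. The sketch below uses JAX's ahead-of-time API (`.lower(...).compile().as_text()`); the function name `axpb` and the array sizes are illustrative assumptions. In the printed HLO you would typically expect the multiply and add to sit inside a single `fusion` instruction rather than appearing as two separate kernels, though the exact text varies by backend and JAX version.

```python
import jax
import jax.numpy as jnp

def axpb(a, x, b):
    # Scalar expression: one multiply and one add per element.
    return a * x + b

# vmap over a batch of scalars, then jit so XLA sees the whole batched graph.
batched = jax.jit(jax.vmap(axpb))

a = jnp.linspace(0.0, 1.0, 1024)
x = jnp.linspace(1.0, 2.0, 1024)
b = jnp.ones(1024)

out = batched(a, x, b)

# Optimized HLO after XLA's passes; inspect it for "fusion" instructions to
# see whether the intermediate a * x was materialized or fused away.
hlo_text = batched.lower(a, x, b).compile().as_text()
print(hlo_text)
```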

I'm hoping JAX already handles a trivial case like this, but I haven't dug very deep, and Google isn't turning up much documentation.

More generally, the pattern you often get in physics simulations, for instance (as opposed to typical ML workloads), looks like this:

import jax

foo = 1.0  # placeholder for whatever per-step update is applied

@jax.vmap
def timestep(state):
  state = state + foo
  # rest of a humongous compute graph of mostly puny compute intensity operations goes here
  return state
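A runnable elaboration of that pattern, under the assumption that the per-element state is a (position, velocity) pair and the "humongous graph" is stood in for by a short damped-spring update (all names here are hypothetical). The design point it illustrates: run eagerly, each JAX op is dispatched on its own and every intermediate is materialized, whereas wrapping the vmapped step (and the time loop, via `jax.lax.fori_loop`) in `jax.jit` hands XLA the whole graph so it at least has the opportunity to fuse it.

```python
import jax
import jax.numpy as jnp

def timestep(state, dt):
    # Per-element state [position, velocity]; a short chain of cheap
    # elementwise ops standing in for the much larger real graph.
    pos, vel = state[0], state[1]
    force = -pos - 0.1 * vel      # damped spring force
    vel = vel + dt * force        # semi-implicit Euler update
    pos = pos + dt * vel
    return jnp.stack([pos, vel])

# vmap exposes the parallelism over elements; dt is shared (in_axes=None).
batched_step = jax.vmap(timestep, in_axes=(0, None))

@jax.jit
def simulate(states, dt):
    # fori_loop keeps all 100 steps inside a single compiled program.
    return jax.lax.fori_loop(0, 100, lambda i, s: batched_step(s, dt), states)

states = jnp.stack([jnp.linspace(-1.0, 1.0, 4096), jnp.zeros(4096)], axis=1)
final = simulate(states, 0.01)
```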

From experimentation, I can tell that JAX isn't exactly going all-out on fusing entire compute graphs like these.

What I'd hope would be possible is something like an @jax.kernel decorator. It would behave somewhat like a Numba kernel, in the sense that it would support a subset of JAX functionality: only the most basic stuff, like for-loops and array broadcasting. Dots, sums, einsums and the like would be unrolled; certainly no dynamic allocation (duh); and no calling into other libraries, like BLAS or whatever. Just support the stuff that maps trivially onto OpenCL or whatever intermediate language is used.

Moreover, there would be no attempt to get clever about exposing internal parallelism inside the kernel. Each element of a vmap would map onto a single GPU/TPU thread, which would serially crunch through the entire kernel. No attempts to write super clever device code with memory tricks or whatnot; just have each thread read each input in a nicely memory-coalesced manner. So if we compiled the timestep function above as a kernel, each input would be read once from global memory, and each output would be written once to global memory. Just let vmap do all the work of exposing the parallelism. If you don't have any such parallelism to expose in your problem, tough luck; you probably shouldn't be using JAX then, and should be rolling your own device kernels instead.
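The programming model described above (scalar per-element code with serial loops, parallelism supplied only by vmap) can be roughly approximated with existing JAX pieces. The `kernel` decorator and `horner` function below are hypothetical illustrations, not a JAX API. The important caveat: this only composes `jax.vmap` and `jax.jit`, so unlike the proposal it does not guarantee one thread per element or a single fused kernel; whether the result fuses is still up to XLA's heuristics.

```python
import jax
import jax.numpy as jnp

def kernel(f):
    """Hypothetical stand-in for the proposed @jax.kernel decorator.

    `f` is written as scalar, per-element code; vmap supplies the
    parallelism and jit hands the batched graph to XLA. This does NOT
    force a single fused kernel; it relies on XLA's fusion heuristics.
    """
    return jax.jit(jax.vmap(f))

@kernel
def horner(x):
    # Serial per-element loop: evaluate 1 + x + x^2 + ... + x^7 with
    # Horner's rule (start at the x^7 coefficient, then acc * x + 1).
    def body(i, acc):
        return acc * x + 1.0
    return jax.lax.fori_loop(0, 7, body, jnp.ones_like(x))

x = jnp.linspace(0.0, 1.0, 8)
y = horner(x)
```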

You could support fancier functionality than that in such a kernel, but implementing this subset shouldn't actually be that hard. I might even be able to contribute to it somewhat, having done similar fused-loop code generation for CUDA kernels in the past. It's not too hard, and I think it would address 95% of use cases; my use cases, at least. Let's call it 80/20 to be conservative.

Implementing forward-mode differentiation for such kernels should also be easy given the code and tools that I imagine are already on hand. (I think reverse mode generally does not make much sense for kernels like these, but I might be missing something; in any case, there wouldn't be a use case for the intermediate gradients inside the kernel. Or if there were, just make it two kernels.)
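Forward mode already composes with vmap and jit in today's JAX via `jax.jvp`, which pushes a tangent through the computation alongside the primal values. The per-element function `element` below is a hypothetical example; the point is that, unlike reverse mode, nothing has to be saved for a later backward pass, which fits the "read once, write once" kernel model described above.

```python
import jax
import jax.numpy as jnp

def element(x):
    # Per-element "kernel" body: a few cheap scalar ops.
    return jnp.sin(x) * x + 0.5 * x * x

batched = jax.vmap(element)

@jax.jit
def value_and_tangent(x, t):
    # Forward mode: evaluate the batched kernel and its directional
    # derivative along t in one pass; no residuals are stored.
    return jax.jvp(batched, (x,), (t,))

x = jnp.linspace(0.0, 2.0, 16)
y, dy = value_and_tangent(x, jnp.ones_like(x))
```

With a tangent of all ones, `dy` is the elementwise derivative cos(x) * x + sin(x) + x.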

Does this sound at all coherent? Is it already on the roadmap? Or should general future compiler improvements make more specialized functionality like this obsolete? I suppose that's true in the long run, but you know what they say about the long run... Torch has a number of initiatives along these lines going on. Whether something like this deserves any developer time also depends on who you see as the stakeholders in this project. I imagine that people writing chatbots or image classifiers will view it as a waste of developer time, but I bet people working on, say, AlphaFold or Brax would absolutely love a feature like this.

shoyer commented 2 years ago

My impression is that XLA actually is pretty good at fusing together the "humongous compute graph of mostly puny compute intensity operations," even when vmapped. E.g., see the "equation_of_state" benchmark from https://github.com/dionhaefner/pyhpc-benchmarks

I suppose it probably depends a bit on exactly what those puny operations are. A concrete example would certainly be helpful here.

As for your proposal here, it sounds totally reasonable, but I'm not sure that XLA in particular is the right compiler technology for it. We've definitely speculated about using tools like Numba or Dex as "kernel languages" for generating efficient kernels to drop into larger JAX programs, which sounds like it would fill a similar niche. From my perspective, figuring out exactly how to write this "kernel compiler" (or identifying an existing system) would probably be the first challenge to solve, before plugging it into JAX (which I would hope would be relatively straightforward using XLA's CustomCall).

EelcoHoogendoorn commented 2 years ago

I'm motivated by toying with geometric algebra in JAX, and trying to decide whether it's best to write JAX code using 'dense' kernels, such as can be found here, or 'unrolled' approaches, such as can be found a few lines below. Note that with general products/operations in geometric algebra, one is dealing with products of n-ary tensors that are usually quite sparse. Here is another example; alternatively, one could write that as a contraction with a [3,3,3] tensor of mostly zeros.
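To make the dense-versus-unrolled distinction concrete: the linked examples are not preserved here, so this sketch assumes the "[3,3,3] tensor of mostly zeros" is something like the 3D cross product, whose structure tensor is the Levi-Civita tensor (6 of 27 entries nonzero). `cross_dense` contracts against the full tensor with `jnp.einsum`, multiplying by zeros; `cross_unrolled` writes out only the nonzero terms. Both function names are hypothetical.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Levi-Civita tensor: the [3, 3, 3] structure tensor of the cross product.
eps = np.zeros((3, 3, 3), dtype=np.float32)
for i, j, k in [(0, 1, 2), (1, 2, 0), (2, 0, 1)]:
    eps[i, j, k] = 1.0   # even permutations
    eps[i, k, j] = -1.0  # odd permutations

def cross_dense(a, b):
    # Dense approach: contract with the full tensor, zeros included.
    return jnp.einsum("ijk,...j,...k->...i", eps, a, b)

def cross_unrolled(a, b):
    # Unrolled approach: only the six nonzero terms.
    ax, ay, az = a[..., 0], a[..., 1], a[..., 2]
    bx, by, bz = b[..., 0], b[..., 1], b[..., 2]
    return jnp.stack([ay * bz - az * by,
                      az * bx - ax * bz,
                      ax * by - ay * bx], axis=-1)

a = jax.random.normal(jax.random.PRNGKey(0), (100, 3))
b = jax.random.normal(jax.random.PRNGKey(1), (100, 3))
```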

From my (admittedly limited) benchmarking, the unrolled versions get absolutely crushed by the dense products. Even on CPU, for a [16, 16, 16, 16] full 4D multivector sandwich product, which is a tensor with only 3% nonzeros, the dense version is 3x faster than the unrolled version. That seems to indicate that the unrolled loops do not get fused by the compiler very much; if they did, you'd expect to see some benefit from the roughly 30x reduction in FLOPs. For more reasonably sized products, you are easily looking at a 10-20x advantage for the dense approach, and I imagine this gap will only widen on GPU/TPU. Given that a GPU can do hundreds of FLOPs per float it can fetch from global memory, it makes sense that sparsity and multiplying by zero a bunch aren't the problem in the first place; the real risk is writing your code in a way that makes it fetch your data from global memory more often than necessary.
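Benchmarks like these are easy to get wrong in JAX, so a minimal timing harness may help. This is a sketch under stated assumptions (the helper name `benchmark` is hypothetical, and `fn` is assumed to return a single array): it excludes the first call, which includes tracing and compilation, and calls `block_until_ready()` because JAX dispatches work asynchronously, so timing without it mostly measures dispatch overhead.

```python
import time
import jax
import jax.numpy as jnp

def benchmark(fn, *args, repeats=20):
    """Median wall-clock seconds per call of jax.jit(fn)(*args).

    Assumes fn returns a single array (so .block_until_ready() exists).
    """
    jitted = jax.jit(fn)
    jitted(*args).block_until_ready()  # warmup: trace + compile once
    times = []
    for _ in range(repeats):
        start = time.perf_counter()
        jitted(*args).block_until_ready()  # wait for the async result
        times.append(time.perf_counter() - start)
    times.sort()
    return times[len(times) // 2]

x = jnp.ones((100_000,))
t = benchmark(lambda v: jnp.sin(v) * v + 1.0, x)
print(f"median: {t * 1e6:.1f} us")
```

Comparing a dense and an unrolled product then amounts to calling `benchmark` on each with identical inputs, on the device you actually care about.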

But perhaps I should try those GPU benchmarks first? Maybe the compiler toolchain pursues loop fusion more aggressively for those devices? (I notice that focusing on FLOPs rather than memory bandwidth is a common vestigial instinct among people who grew up tuning C code in the 90s; but modern CPUs are more like GPUs than like 90s CPUs in this regard anyway.)

I'll try to cobble together a minimal piece of code to see if others can reproduce these differences.

Note that fully fusing a single GA product will easily require fusing many dozens of terms, and your physics logic would then consist of many dozens of such operations chained together. From a memory-bandwidth point of view, the optimal solution is clearly to fuse all of these intermediates wherever there is a nice bottleneck in the compute graph, like state -> state; but I don't know that XLA has been tuned very much for these scenarios, as they are quite different from typical ML ones.

It's quite possible that enabling such tuning in XLA would be a path of less resistance than a completely different compilation path... but I don't know. Perhaps it's best to view a hypothetical @jax.kernel as more of a compiler hint to XLA. That being said, restricting yourself to a subset of regular JAX does seem necessary to me to make the most of such loop fusion opportunities.

EelcoHoogendoorn commented 2 years ago

Interesting benchmark you linked, by the way, and indeed a good example of a pile of low-compute-intensity trash thrown together. JAX is doing quite well for itself there... but that doesn't mean it couldn't be doing better. There is no direct comparison to a manually, fully fused kernel, and I'm not sure I trust any of these frameworks to have it completely figured out. Frankly, if they did, I would expect bigger differences relative to NumPy on CPU.

Seeing it compared with a CuPy ElementwiseKernel would be great; I can't think of a way a compiler might screw up the loop fusion with that, and it should also let me write it in a way where I don't think the compiler can screw up emitting code with coalesced memory access. I should probably do that for my own benchmarks as well; it should be trivial to transform my GA products into CuPy elementwise kernels. I'm mostly interested in getting my kernels to run optimally in JAX, but I think that would serve as a nice ground truth for what should be possible.

EelcoHoogendoorn commented 2 years ago

Hmmm; doing more benchmarking and finding some embarrassing limitations in my benchmarking code... I'll take back everything I said above. There is definite evidence of aggressive loop fusion going on. I'm still curious how it will end up comparing to a GPU kernel that I trust to really make the most of available compute and bandwidth, though.

It makes me wonder what the AlphaFold people were thinking with the way they implemented their quaternion multiplication, since it is now beaten handily in every permutation of configurations I tried. This is still only CPU testing, though, and I imagine they optimized for on-device compute. A TPU has a single instruction for 4x4 matrix multiplies, I think... so that might be why?
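For reference, here are the two quaternion-multiplication styles side by side. This is not AlphaFold's actual code; it is a sketch assuming (w, x, y, z) component order. `quat_mul_unrolled` writes out the Hamilton product term by term, and `quat_mul_dense` contracts against a [4, 4, 4] structure tensor `C` (16 of 64 entries nonzero). Building `C` by multiplying pairs of basis quaternions with the unrolled version keeps the two formulations consistent by construction.

```python
import numpy as np
import jax
import jax.numpy as jnp

def quat_mul_unrolled(p, q):
    # Hamilton product written out term by term; components are (w, x, y, z).
    pw, px, py, pz = p[..., 0], p[..., 1], p[..., 2], p[..., 3]
    qw, qx, qy, qz = q[..., 0], q[..., 1], q[..., 2], q[..., 3]
    return jnp.stack([
        pw * qw - px * qx - py * qy - pz * qz,
        pw * qx + px * qw + py * qz - pz * qy,
        pw * qy - px * qz + py * qw + pz * qx,
        pw * qz + px * qy - py * qx + pz * qw,
    ], axis=-1)

# Structure tensor C[i, j, k]: coefficient of p_i * q_j in output component k,
# obtained by multiplying every pair of basis quaternions.
basis = np.eye(4, dtype=np.float32)
C = np.stack([np.stack([np.asarray(quat_mul_unrolled(basis[i], basis[j]))
                        for j in range(4)]) for i in range(4)])

def quat_mul_dense(p, q):
    # Dense contraction against the full [4, 4, 4] tensor, zeros included.
    return jnp.einsum("ijk,...i,...j->...k", C, p, q)

p = jax.random.normal(jax.random.PRNGKey(0), (64, 4))
q = jax.random.normal(jax.random.PRNGKey(1), (64, 4))
```

As a sanity check, multiplying the basis quaternions i = (0, 1, 0, 0) and j = (0, 0, 1, 0) with either function gives k = (0, 0, 0, 1).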

All the more reason to have a nicely configurable library that auto-generates these types of products and makes its code-generation strategy configurable; because whichever way these benchmark results swing, I have yet to see a subtle difference. (Also, compilation speed is a 10x win for the dense approach, so it's nice to be able to toggle between debug/release modes. And no, I wasn't missing a warmup call in my benchmark; it was another, equally stupid mistake.)
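A minimal sketch of what such a configurable generator could look like, assuming products are described by a structure tensor `C` with out_k = sum_ij C[i, j, k] * a_i * b_j (the helper `make_product` and its `mode` argument are hypothetical). `mode="dense"` emits a single einsum against the full tensor, which is the formulation reported above to compile much faster; `mode="unrolled"` emits one multiply-add per nonzero entry, so FLOPs scale with sparsity but the traced graph grows with the number of terms.

```python
import numpy as np
import jax.numpy as jnp

def make_product(C, mode="dense"):
    """Build the bilinear product out_k = sum_ij C[i, j, k] * a_i * b_j.

    mode="dense": one einsum against the full tensor (zeros included).
    mode="unrolled": one term per nonzero entry, fixed at build time.
    """
    C = np.asarray(C)
    if mode == "dense":
        return lambda a, b: jnp.einsum("ijk,...i,...j->...k", C, a, b)
    if mode == "unrolled":
        # Collect (i, j, coefficient) for each output component k.
        terms = [[(int(i), int(j), float(C[i, j, k]))
                  for i, j in zip(*np.nonzero(C[:, :, k]))]
                 for k in range(C.shape[2])]

        def product(a, b):
            batch = np.broadcast_shapes(a.shape[:-1], b.shape[:-1])
            outs = []
            for k_terms in terms:
                acc = jnp.zeros(batch, dtype=a.dtype)
                for i, j, c in k_terms:
                    acc = acc + c * a[..., i] * b[..., j]
                outs.append(acc)
            return jnp.stack(outs, axis=-1)

        return product
    raise ValueError(f"unknown mode: {mode!r}")
```

Since both modes compute the same result, switching between them becomes a one-argument change, which is what makes a debug/release toggle cheap.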

But it's good to know that JAX can in fact be trusted to reason about these things fairly effectively. Still... I'm going to work on that CuPy kernel as a ground truth.