cubed-dev / cubed

Bounded-memory serverless distributed N-dimensional array processing
https://cubed-dev.github.io/cubed/
Apache License 2.0

Is there a parallel between tiled GPU/TPU kernels and Cubed chunks? #490

Open alxmrs opened 1 week ago

alxmrs commented 1 week ago

Tile-based operations have been quite successful for creating optimal GPU kernels. As I understand it, the programming model offers flexibility while taking advantage of cache hierarchies.

http://www.eecs.harvard.edu/~htk/publication/2019-mapl-tillet-kung-cox.pdf

The Triton language takes advantage of this model by providing a sort of MLIR/LLVM middleware for writing custom accelerated kernels for specific NN ops. JAX now even offers its own portable version of kernel control with tile semantics via Pallas.

https://jax.readthedocs.io/en/latest/pallas/index.html

I can’t help but think that there are parallels between Cubed’s chunked blockwise op and these tile based techniques. What could an intersection look like?
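To make the parallel concrete, here is a toy NumPy sketch (not Cubed's actual API) of a chunked blockwise op: the array is processed one chunk at a time, with only a single chunk resident in memory during each step, which is structurally the same loop shape a tiled kernel uses over its tiles.

```python
import numpy as np

def blockwise_add_one(x, chunk_shape):
    """Toy stand-in for a chunk-level blockwise op: apply a function
    to each chunk independently, touching one chunk at a time."""
    out = np.empty_like(x)
    n0, n1 = chunk_shape
    for i in range(0, x.shape[0], n0):
        for j in range(0, x.shape[1], n1):
            chunk = x[i:i + n0, j:j + n1]        # "load" one chunk
            out[i:i + n0, j:j + n1] = chunk + 1  # per-chunk kernel
    return out

a = np.arange(16).reshape(4, 4)
b = blockwise_add_one(a, (2, 2))
assert np.array_equal(b, a + 1)
```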

alxmrs commented 1 week ago

I believe that Cubed chunks are “macro tiles” within the tile hierarchy.

[attached image: IMG_5639]

tomwhite commented 1 week ago

Very interesting - thanks for the pointers Alex!

I believe that Cubed chunks are “macro tiles” within the tile hierarchy.

To be clear, do you mean that a Cubed chunk would be composed of multiple Triton tiles?

alxmrs commented 1 week ago

I was taking a bit of poetic license. :)

I think that Cubed chunks, which live in RAM in userspace, could be considered part of the overall memory hierarchy for accelerated computation in the Triton model. I do think that they provide natural affordances for efficient kernel construction that could be automated by JAX via Pallas or Triton.
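A minimal sketch of that two-level hierarchy, in plain NumPy (the chunk and tile sizes and the doubling kernel are illustrative, not anything from Cubed, Triton, or Pallas): a RAM-resident chunk is itself processed tile by tile, the way an accelerator kernel stages sub-tiles through fast memory.

```python
import numpy as np

CHUNK = 4  # "macro tile": the unit Cubed keeps in worker RAM
TILE = 2   # "micro tile": the unit a kernel stages through SRAM/registers

def process_chunk(chunk):
    """Process one RAM-resident chunk tile by tile, mimicking how an
    accelerator kernel streams sub-tiles through fast memory."""
    out = np.empty_like(chunk)
    for i in range(0, chunk.shape[0], TILE):
        for j in range(0, chunk.shape[1], TILE):
            tile = chunk[i:i + TILE, j:j + TILE]    # stage into "fast memory"
            out[i:i + TILE, j:j + TILE] = tile * 2  # per-tile compute
    return out

x = np.arange(CHUNK * CHUNK).reshape(CHUNK, CHUNK)
y = process_chunk(x)
assert np.array_equal(y, x * 2)
```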