jax-ml / jax-triton

jax-triton contains integrations between JAX and OpenAI Triton

Question about Purpose of Pallas #124

Open adam-hartshorne opened 1 year ago

adam-hartshorne commented 1 year ago

If I understand it correctly, the idea of Pallas is to provide a level of abstraction over Triton, enabling one to define a kernel using JAX functions.
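
For concreteness, a minimal Pallas kernel might look something like this (a sketch only; the import path is an assumption, since Pallas has lived under both `jax_triton.pallas` and `jax.experimental.pallas`):

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl  # assumed import path

def add_kernel(x_ref, y_ref, o_ref):
    # The kernel body uses ordinary JAX-style indexing and ops,
    # rather than Triton's pointer arithmetic.
    o_ref[...] = x_ref[...] + y_ref[...]

def add(x, y):
    return pl.pallas_call(
        add_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    )(x, y)

add(jnp.arange(8.0), jnp.arange(8.0))  # [0., 2., 4., ..., 14.]
```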

Obviously, Triton doesn't provide any AutoDiff functionality; it is just a way of interacting with GPU memory in a more user-friendly way than CUDA (and the ultimate plan is to enable it to also work on non-NVIDIA hardware).
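
For contrast, here is a sketch of the standard vector-add pattern from Triton's own tutorials, with explicit program IDs, offsets, and masked loads/stores in place of CUDA's manual bounds checking:

```python
import triton
import triton.language as tl

@triton.jit
def triton_add_kernel(x_ptr, y_ptr, o_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    # Each program instance handles one BLOCK_SIZE-wide slice of the arrays;
    # the mask guards against reading or writing past the end.
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(o_ptr + offsets, x + y, mask=mask)
```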

Is the idea of Pallas to seamlessly provide AutoDiff as well, so that any kernel defined with it comes with the ability to take gradients (the way normal JAX functions do)?

wangkuiyi commented 1 year ago

I have exactly the same question. cc @zhangqiaorjc

Could we have an example that trains a linear regression model using jax-triton? It would be a good way to show (1) candidate primitive ops and (2) how the autodiff works.

sharadmv commented 1 year ago

I'd say compatibility with JAX transformations is one of the major aspects of Pallas. Vmapping Pallas kernels, in my opinion, is really convenient. AD of Pallas, on the other hand, has more uncertain value: the current implementation is not complete, and even when it is done, differentiating kernels won't necessarily produce efficient backward-pass kernels.
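
As a sketch of what that convenience looks like (reusing the `pallas_call` pattern from above, with the import path again assumed; batching is handled by Pallas's vmap rule as I understand it, rather than by rewriting the kernel):

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl  # assumed import path

def mul_kernel(x_ref, y_ref, o_ref):
    o_ref[...] = x_ref[...] * y_ref[...]

def mul(x, y):  # written for a single (unbatched) input
    return pl.pallas_call(
        mul_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    )(x, y)

xs = jnp.ones((32, 128))
ys = jnp.ones((32, 128))
out = jax.vmap(mul)(xs, ys)  # batches the kernel without changing its body
```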

Another aspect is that I think Pallas is a friendlier front end than Triton's AST-based one, but that's mostly personal preference.

> Could we have an example that trains a linear regression model using jax-triton? It would be a good way to show (1) candidate primitive ops and (2) how the autodiff works.

That sounds like a good thing to have, but I won't have the bandwidth to implement it for some time. Maybe it'll be a good first project for someone!
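
For anyone picking this up, a rough sketch of the shape such an example could take (hypothetical: a plain `jnp.dot` stands in for a hand-written jax-triton/Pallas matmul kernel, and `jax.custom_vjp` marks where explicit forward/backward kernels would plug in):

```python
import jax
import jax.numpy as jnp

@jax.custom_vjp
def matvec(w, x):
    # Stand-in for a hand-written jax-triton / Pallas kernel.
    return jnp.dot(x, w)

def matvec_fwd(w, x):
    return matvec(w, x), (w, x)

def matvec_bwd(res, g):
    w, x = res
    # These two contractions are where hand-written backward kernels would go.
    return jnp.dot(x.T, g), jnp.outer(g, w)

matvec.defvjp(matvec_fwd, matvec_bwd)

def loss(w, x, y):
    return jnp.mean((matvec(w, x) - y) ** 2)

key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (256, 4))
true_w = jnp.array([1.0, -2.0, 3.0, 0.5])
y = x @ true_w

w = jnp.zeros(4)
for _ in range(200):
    w -= 0.1 * jax.grad(loss)(w, x, y)  # gradient descent through the custom VJP
```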