jax-ml / jax-triton

jax-triton contains integrations between JAX and OpenAI Triton
Apache License 2.0

ROCM updates #281

Closed rahulbatra85 closed 2 months ago

rahulbatra85 commented 3 months ago

Adds support for ROCm

rahulbatra85 commented 3 months ago

@superbobry @sharadmv @hawkinsp Hi, this PR adds ROCm support for JAX-Triton. Please review. Thanks!

rahulbatra85 commented 3 months ago

@superbobry Thanks for reviewing! I fixed most of them, but I have a question about one. Please provide feedback. Thanks!

rahulbatra85 commented 2 months ago

@superbobry Please review again. Thanks!

rahulbatra85 commented 2 months ago

@superbobry Please review again. Thanks!

@superbobry Please see the new changes. Thanks!

rahulbatra85 commented 2 months ago

@superbobry If everything looks good, can we merge this?

hawkinsp commented 2 months ago

FYI: I'm reverting this PR because it broke some of our internal jax-triton users on CUDA. I'm not sure exactly why. @superbobry can you PTAL?