jax-ml / jax-triton

jax-triton contains integrations between JAX and OpenAI Triton
Apache License 2.0

ROCM updates #287

Closed: rahulbatra85 closed this 3 months ago

rahulbatra85 commented 3 months ago

Adds support for ROCm.

- Adds separate lowering rules for ROCm and CUDA, instead of using runtime information to determine the platform.
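The design change above picks a lowering rule from the compilation target instead of probing the runtime for devices. The plain-Python sketch below shows that general pattern under stated assumptions. The registry, `register_lowering`, `lower`, the primitive name `"triton_call"`, and the placeholder rules are all hypothetical and are not the jax-triton API. In JAX itself, the analogous hook is `jax.interpreters.mlir.register_lowering`, which accepts a `platform` argument.

```python
from typing import Callable, Dict, Tuple

# Hypothetical registry: (primitive name, platform) -> lowering rule.
# Illustrative only; this is NOT the JAX or jax-triton API.
_LOWERINGS: Dict[Tuple[str, str], Callable[[str], str]] = {}


def register_lowering(prim: str, rule: Callable[[str], str], platform: str) -> None:
    """Register a rule for one specific platform, e.g. "cuda" or "rocm"."""
    _LOWERINGS[(prim, platform)] = rule


def lower(prim: str, kernel: str, platform: str) -> str:
    """Select the rule from the target platform, not from devices seen at runtime."""
    try:
        rule = _LOWERINGS[(prim, platform)]
    except KeyError:
        raise NotImplementedError(f"no lowering for {prim!r} on {platform!r}") from None
    return rule(kernel)


def _cuda_rule(kernel: str) -> str:
    # Placeholder: a real rule would compile the Triton kernel for NVIDIA GPUs.
    return f"cuda-call({kernel})"


def _rocm_rule(kernel: str) -> str:
    # Placeholder: a real rule would compile the Triton kernel for AMD GPUs.
    return f"rocm-call({kernel})"


register_lowering("triton_call", _cuda_rule, platform="cuda")
register_lowering("triton_call", _rocm_rule, platform="rocm")
```

For example, `lower("triton_call", "add_kernel", "rocm")` returns `"rocm-call(add_kernel)"`. Because the platform is an explicit input, the same program can be lowered for either backend on a machine that has no GPU attached, and an unsupported platform fails loudly instead of silently falling back.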

rahulbatra85 commented 3 months ago

@hawkinsp @superbobry Please review this new PR. Thanks!

hawkinsp commented 3 months ago

Our internal presubmits complain about trailing whitespace in this file.

rahulbatra85 commented 3 months ago

@hawkinsp @superbobry Looks like we merged this through other means. Thanks for all the feedback!

Should we close this PR now?

hawkinsp commented 3 months ago

Yeah, I merged it with some fixes to make it work.