jax-ml / jax-triton

jax-triton contains integrations between JAX and OpenAI Triton.
Apache License 2.0

Refactors triton_kernel_call_lowering to support both cuda and rocm. #290

Closed: copybara-service[bot] closed this 3 months ago.

copybara-service[bot] commented 3 months ago

Refactors triton_kernel_call_lowering to support both cuda and rocm.

This is a rollforward of https://github.com/jax-ml/jax-triton/pull/287 with fixes.