jax-ml / jax-triton

jax-triton contains integrations between JAX and OpenAI Triton
Apache License 2.0

triton requirement is out of date #238

Open ywrt opened 1 year ago

ywrt commented 1 year ago

Attempting to use jax-triton nightly fails on CUDA with:

  File "/home/michael/.local/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py", line 1433, in jaxpr_subcomp
    ans = rule(rule_ctx, *rule_inputs, **eqn.params)
  File "/home/michael/.local/lib/python3.10/site-packages/jax_triton/triton_lib.py", line 383, in triton_kernel_call_lowering
    kernel, specialization = get_or_create_triton_kernel(
  File "/home/michael/.local/lib/python3.10/site-packages/jax_triton/triton_lib.py", line 252, in get_or_create_triton_kernel
    module = code_gen.ast_to_ttir(
TypeError: ast_to_ttir() got an unexpected keyword argument 'target'

This appears to be because requirements.txt pins triton_nightly-2.1.0.dev20230714011643, but jax-triton now needs a newer Triton nightly: in the pinned version, ast_to_ttir has no 'target' argument.
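For anyone who wants to confirm this kind of mismatch before running a kernel, here is a minimal sketch that inspects a function's signature for a given keyword. The helper names (`accepts_keyword`, `installed_triton_accepts_target`) and the two stand-in functions are hypothetical and only for illustration. The import path `triton.compiler.code_generator` is an assumption based on the `code_gen.ast_to_ttir` call in the traceback.

```python
import inspect


def accepts_keyword(func, name):
    """Return True if `func` can be called with keyword argument `name`."""
    for param in inspect.signature(func).parameters.values():
        # A **kwargs parameter accepts any keyword.
        if param.kind is inspect.Parameter.VAR_KEYWORD:
            return True
        if param.name == name and param.kind in (
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
            inspect.Parameter.KEYWORD_ONLY,
        ):
            return True
    return False


# Hypothetical stand-ins for an older and a newer signature.
# These are NOT Triton's real signatures; they only exercise the helper.
def old_style(fn, signature, specialization, constants):
    pass


def new_style(fn, signature, specialization, constants, target=None):
    pass


def installed_triton_accepts_target():
    """Check the installed Triton; return None if it cannot be inspected."""
    try:
        # Assumed module path, inferred from the traceback above.
        from triton.compiler import code_generator as code_gen
        func = code_gen.ast_to_ttir
    except (ImportError, AttributeError):
        return None
    return accepts_keyword(func, "target")


if __name__ == "__main__":
    print("old_style accepts 'target':", accepts_keyword(old_style, "target"))
    print("new_style accepts 'target':", accepts_keyword(new_style, "target"))
    print("installed Triton accepts 'target':", installed_triton_accepts_target())
```

If the last line prints False, the installed Triton is likely too old for the jax-triton build in use. None means Triton could not be imported or has no `ast_to_ttir` at that path.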

I worked around this by installing triton_nightly-2.1.0.dev20231014192330, but it would be great to have correct dependencies.
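To check which of the two builds is installed, something like the sketch below may help. It compares nightly version strings by base version and `.dev` timestamp using only the standard library. The distribution name `triton-nightly` is an assumption inferred from the wheel names above, and the second version is only the build that worked in this report, not a verified minimum. The helper names are hypothetical.

```python
import re
from importlib import metadata

PINNED = "2.1.0.dev20230714011643"      # version listed in requirements.txt
KNOWN_GOOD = "2.1.0.dev20231014192330"  # build that worked in this report


def parse_dev_version(version):
    """Split 'X.Y.Z.devN' into ((X, Y, Z), N) so tuples compare in order."""
    match = re.match(r"^(\d+(?:\.\d+)*)\.dev(\d+)$", version)
    if match is None:
        raise ValueError(f"not a .dev version string: {version!r}")
    base = tuple(int(part) for part in match.group(1).split("."))
    return base, int(match.group(2))


def is_at_least(version, reference):
    """Return True if `version` is the same as or newer than `reference`."""
    return parse_dev_version(version) >= parse_dev_version(reference)


def installed_version(dist_name="triton-nightly"):
    """Return the installed version string, or None if not installed.

    The default distribution name is an assumption, not confirmed here.
    """
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


if __name__ == "__main__":
    print("pinned build is new enough:", is_at_least(PINNED, KNOWN_GOOD))
    found = installed_version()
    if found is None:
        print("triton-nightly does not appear to be installed")
    else:
        print("installed:", found)
```

The parser deliberately rejects anything that is not a plain `.dev` build, so a stable release string raises ValueError rather than being compared incorrectly.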

sharadmv commented 1 year ago

Thanks for investigating. I probably won't get around to this for a week or so, but a PR would be greatly appreciated!