Attempting to use jax-triton nightly fails on CUDA with the following error:
File "/home/michael/.local/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py", line 1433, in jaxpr_subcomp
ans = rule(rule_ctx, *rule_inputs, **eqn.params)
File "/home/michael/.local/lib/python3.10/site-packages/jax_triton/triton_lib.py", line 383, in triton_kernel_call_lowering
kernel, specialization = get_or_create_triton_kernel(
File "/home/michael/.local/lib/python3.10/site-packages/jax_triton/triton_lib.py", line 252, in get_or_create_triton_kernel
module = code_gen.ast_to_ttir(
TypeError: ast_to_ttir() got an unexpected keyword argument 'target'
This appears to be because requirements.txt lists triton_nightly-2.1.0.dev20230714011643, but a newer triton nightly is actually required.
(That version doesn't have a 'target' argument to ast_to_ttir).
I worked around this by installing triton_nightly-2.1.0.dev20231014192330, but it would be great to have correct dependencies.
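For anyone who wants to check whether their installed triton is affected before hitting the traceback, here is a minimal sketch that inspects a function's signature for a given keyword. The helper `accepts_keyword` and the two stand-in functions are hypothetical and have simplified parameter lists; they are not triton's real signatures. The import path `triton.compiler.code_generator` is my assumption about where `ast_to_ttir` lives in these nightlies.

```python
import inspect


def accepts_keyword(func, name):
    """Return True if `func` can be called with the keyword argument `name`."""
    params = inspect.signature(func).parameters
    param = params.get(name)
    if param is not None and param.kind in (
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
        inspect.Parameter.KEYWORD_ONLY,
    ):
        return True
    # A **kwargs catch-all would also accept the keyword.
    return any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values())


# Hypothetical stand-ins for the two behaviours (NOT triton's real signatures).
def old_style_ast_to_ttir(fn, signature, debug=False):
    return None


def new_style_ast_to_ttir(fn, signature, debug=False, target=None):
    return None


print(accepts_keyword(old_style_ast_to_ttir, "target"))
print(accepts_keyword(new_style_ast_to_ttir, "target"))

# Check the real function if triton is installed (module path is an assumption).
try:
    from triton.compiler import code_generator as code_gen
except ImportError:
    code_gen = None

if code_gen is not None and hasattr(code_gen, "ast_to_ttir"):
    if not accepts_keyword(code_gen.ast_to_ttir, "target"):
        print("Installed triton is too old for this jax-triton nightly.")
```

If the last check prints the warning, the installed triton predates the `target` argument and the same TypeError should be expected.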