jax-ml / jax-triton

jax-triton contains integrations between JAX and OpenAI Triton
Apache License 2.0

Inconsistent NaN results on Triton matmul kernel #185

Open karen-sy opened 1 year ago

karen-sy commented 1 year ago

I've found a behavior in which the output of jt.triton_call differs depending on when/where certain metaparameters (I suspect the metaparameters related to the grid) are defined.

Specifically, for the Triton repo's matmul kernel (source):

(1) jt.triton_call returns a matrix of NaNs from the second call onwards (the first call is correct) if the metaparameters BLOCK_M, BLOCK_N, BLOCK_K, and SPLIT_K are passed directly into the function call.

(2) jt.triton_call returns correct results when those metaparameters are selected via triton.autotune (and not passed directly into jt.triton_call).
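To make the grid suspicion concrete, here is a minimal sketch of how the launch grid would depend on these metaparameters. It assumes the grid formula used by Triton's reference matmul wrapper: one program per BLOCK_M x BLOCK_N output tile along the first axis, and SPLIT_K programs along the second. The helper names `cdiv` and `matmul_grid` and the block sizes are illustrative, not taken from the attached script.

```python
def cdiv(a, b):
    """Ceiling division: number of blocks of size b needed to cover a."""
    return (a + b - 1) // b


def matmul_grid(M, N, meta):
    """Launch grid for an (M, K) x (K, N) matmul.

    Assumed formula: (number of output tiles, SPLIT_K).
    BLOCK_K does not affect the grid; it only sets the inner loop step.
    """
    tiles = cdiv(M, meta["BLOCK_M"]) * cdiv(N, meta["BLOCK_N"])
    return (tiles, meta["SPLIT_K"])


# Illustrative metaparameter values (hypothetical, chosen for this sketch).
meta = {"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}
grid = matmul_grid(512, 512, meta)
print(grid)
```

If the grid is computed from one set of values and the kernel is compiled with another, the two can disagree, which is one reason to keep the grid and the metaparameters defined in the same place.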

Also, simply importing Triton's matmul_perf_model (source) changes when the failure starts: with the import, jt.triton_call fails (NaN outputs, as described in (1)) on the second call and beyond; with the import commented out, it fails on the third call and beyond.
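The "fails on the Nth call and beyond" observation can be checked with a small harness like the sketch below. It is plain Python and does not call jax_triton itself. `first_nan_call` and `FlakyMatmul` are hypothetical names introduced here, and `FlakyMatmul` is only a stand-in for a zero-argument function that wraps the real jt.triton_call and returns the result as nested lists.

```python
import math


def first_nan_call(fn, num_calls=5):
    """Call fn repeatedly; return the 1-based index of the first call
    whose output (a list of rows) contains a NaN, or None if none does."""
    for i in range(1, num_calls + 1):
        out = fn()
        if any(math.isnan(x) for row in out for x in row):
            return i
    return None


class FlakyMatmul:
    """Hypothetical stand-in for the kernel call: returns a valid 2x2
    result for the first `good_calls` calls and NaNs afterwards."""

    def __init__(self, good_calls):
        self.good_calls = good_calls
        self.calls = 0

    def __call__(self):
        self.calls += 1
        value = 1.0 if self.calls <= self.good_calls else float("nan")
        return [[value, value], [value, value]]


with_import = first_nan_call(FlakyMatmul(good_calls=1))
without_import = first_nan_call(FlakyMatmul(good_calls=2))
print(with_import, without_import)
```

In a real run, `fn` would perform the kernel call and convert the output to host memory before the check, so each iteration inspects a finished result.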

I am attaching a script that reproduces this behavior.

I'm wondering if this is expected behavior, and if so, what jax_triton conventions I should be following regarding metaparameter/tl.constexpr passing. In general, the boundary between args and metaparams seems a bit vague; is a parameter a metaparameter if and only if it is a constexpr?

Thanks for the help!

matmul_repro.txt

sharadmv commented 1 year ago

This repros in Triton as well, so it appears to be a Triton compiler issue.