I've found that the output of jt.triton_call differs depending on when and where certain metaparameters are defined (I suspect the metaparameters related to the grid).
Specifically, for the Triton repo's matmul kernel (source):
(1) jt.triton_call returns a matrix of NaNs from the second call onwards (the first call is correct) if the metaparameters BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K are passed directly into the function call.
(2) jt.triton_call returns correct results when those metaparameters are selected via triton.autotune (and not passed directly into jt.triton_call).
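For reference, here is how I understand the launch grid to depend on those metaparameters. This is a plain-Python sketch, assuming the grid used by Triton's matmul wrapper is (cdiv(M, BLOCK_M) * cdiv(N, BLOCK_N), SPLIT_K); matmul_grid and cdiv are my own illustrative names, not jax_triton or Triton API.

```python
# Hypothetical helper mirroring the launch grid of Triton's matmul wrapper
# (assumption: grid = (cdiv(M, BLOCK_M) * cdiv(N, BLOCK_N), SPLIT_K)).

def cdiv(a: int, b: int) -> int:
    """Ceiling division, equivalent to triton.cdiv for positive ints."""
    return (a + b - 1) // b


def matmul_grid(M: int, N: int, meta: dict) -> tuple:
    """Return the 2-D launch grid for an (M x K) @ (K x N) matmul.

    Axis 0 enumerates output tiles of size BLOCK_M x BLOCK_N;
    axis 1 enumerates the SPLIT_K partial reductions over K.
    """
    tiles = cdiv(M, meta["BLOCK_M"]) * cdiv(N, meta["BLOCK_N"])
    return (tiles, meta["SPLIT_K"])


# Metaparameters passed directly, as in case (1):
meta = {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}
print(matmul_grid(512, 512, meta))  # (16, 1)
```

BLOCK_K does not enter the grid; it only affects how the kernel iterates over K inside each program.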
Also, simply importing Triton's matmul_perf_model (source) further affects this: with the import, jt.triton_call fails (NaN outputs, as described in (1)) from the second call onwards; if the import is commented out, it fails from the third call onwards.
I am attaching a script that reproduces this behavior.
I'm wondering whether this is expected behavior, and if so, what jax_triton conventions I should be following regarding metaparameter/tl.constexpr passing. In general, the boundary between args and metaparams seems a bit vague: is a parameter a metaparameter if and only if it is a constexpr?
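To make the question concrete, here is the rule I'm asking about, sketched in plain Python: a parameter counts as a metaparameter if and only if it is annotated as constexpr. The constexpr class is a hypothetical stand-in for tl.constexpr, the kernel is a simplified, signature-only stub (the real matmul kernel also takes strides and other constexprs such as GROUP_M), and split_params is my own illustrative helper. Whether jax_triton actually draws the line this way is exactly what I'm unsure about.

```python
import inspect


class constexpr:
    """Hypothetical stand-in for triton.language.constexpr."""


def matmul_kernel(A, B, C, M, N, K,
                  BLOCK_M: constexpr, BLOCK_N: constexpr,
                  BLOCK_K: constexpr, SPLIT_K: constexpr):
    """Simplified signature-only stub; the real kernel body is omitted."""


def split_params(fn):
    """Partition parameter names into (runtime args, constexpr metaparams)."""
    args, metaparams = [], []
    for name, p in inspect.signature(fn).parameters.items():
        # Annotated with the constexpr marker -> compile-time metaparameter.
        (metaparams if p.annotation is constexpr else args).append(name)
    return args, metaparams


print(split_params(matmul_kernel))
# (['A', 'B', 'C', 'M', 'N', 'K'], ['BLOCK_M', 'BLOCK_N', 'BLOCK_K', 'SPLIT_K'])
```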
Thanks for the help!
matmul_repro.txt