davidberard98 opened this issue 1 day ago (status: open)
@bertmaher, for convenience, https://github.com/davidberard98/triton/tree/revert-4582 has the revert commit, since the revert had a conflict. I verified that it fixes this issue, but I haven't run any broader Inductor/Triton tests.
cc @ThomasRaoux, since #4582 is yours: we're a bit stumped as to whether there's an issue with that PR or whether it's just uncovering a ptxas bug.
Huh, yes, I'm very surprised this would cause such a failure.
Quoting the report: "Confusingly, when we set `init=false` here, ..."
This does suggest a problem in ptxas, since this init is there only to help ptxas figure out the liveness of this register.
Setting the ptxas optimization level to 0 or 1 fixes the issue.
I guess that's another hint pointing at ptxas.
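One way to follow up on the optimization-level observation outside of Triton is to assemble the same PTX by hand at each level and compare the resulting kernels. The following is a minimal sketch, assuming the kernel's PTX has already been dumped to a file (the name `kernel.ptx` and the `.O<level>.cubin` output naming are hypothetical) and that `ptxas` from the CUDA toolkit is on `PATH`. It only builds the cubins; loading and running them is left to whatever harness produced the repro.

```python
import subprocess
from pathlib import Path

# ptxas accepts optimization levels 0-3; 3 is the default.
VALID_OPT_LEVELS = (0, 1, 2, 3)


def ptxas_cmd(ptx_path: str, opt_level: int, arch: str = "sm_80") -> list[str]:
    """Build a ptxas command line that assembles ptx_path at the given opt level.

    sm_80 matches the A100 from the report. The output name (e.g. kernel.O1.cubin)
    is a hypothetical convention so the per-level cubins don't overwrite each other.
    """
    if opt_level not in VALID_OPT_LEVELS:
        raise ValueError(f"opt_level must be one of {VALID_OPT_LEVELS}, got {opt_level}")
    src = Path(ptx_path)
    cubin = src.with_name(f"{src.stem}.O{opt_level}.cubin")
    return [
        "ptxas",
        "--gpu-name", arch,
        "--opt-level", str(opt_level),
        "--output-file", str(cubin),
        ptx_path,
    ]


def assemble_all(ptx_path: str, levels=VALID_OPT_LEVELS) -> list[str]:
    """Run ptxas once per level (needs the CUDA toolkit on PATH); return the cubin paths."""
    outputs = []
    for level in levels:
        cmd = ptxas_cmd(ptx_path, level)
        subprocess.run(cmd, check=True)
        outputs.append(cmd[cmd.index("--output-file") + 1])
    return outputs
```

Keeping the command construction separate from the `subprocess` call makes it easy to print or log the exact invocation, which is useful when attaching a standalone reproducer to a ptxas bug report.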
Should we file a bug against ptxas? Not sure what else to do here.
Yeah, I think @davidberard98 has filed an nvbug for ptxas (is there a shareable link for this one?).
I'm thinking we should revert #4582 for the 3.2 release, since it causes a few failures for PyTorch, but leave main alone and just see what happens with ptxas. How does that sound?
Reverted on the rc/3.2.x branch as https://github.com/triton-lang/triton/commit/35c6c7c6284582b3f41c71c150e11b517acf074a
Describe the bug
PyTorch reference: https://github.com/pytorch/pytorch/issues/141079
repro:
`addmm(a, b, bias) = a @ b + bias`, where `a` is [512, 33], `b` is [33, 34], and `bias` is [34]. The actual output differs from the expected output in that it is missing the added bias, but only at certain indices. Specifically, the bias is missing where `row_idx % 64` is in `[16, 24)` and where `col_idx % 4 == 1`.
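To make the failure pattern concrete, here is a pure-Python sketch (no PyTorch or Triton) of the reference computation and of a checker that reports every index where the output equals the expected value minus the bias. It assumes that both index conditions must hold at the same time; the helper names are hypothetical and are not part of the original repro.

```python
import random

M, K, N = 512, 33, 34  # a: [M, K], b: [K, N], bias: [N], as in the report


def addmm_ref(a, b, bias):
    """Reference addmm: out[i][j] = sum_k a[i][k] * b[k][j] + bias[j] (bias broadcast over rows)."""
    inner, cols = len(b), len(b[0])
    return [
        [sum(row[k] * b[k][j] for k in range(inner)) + bias[j] for j in range(cols)]
        for row in a
    ]


def in_reported_pattern(i, j):
    """Assumption: the bias is dropped only where BOTH reported conditions hold."""
    return 16 <= i % 64 < 24 and j % 4 == 1


def bias_missing_at(expected, actual, bias, tol=1e-6):
    """Return {(i, j)} where actual == expected - bias[j] (within tol) and bias[j] is nonzero."""
    return {
        (i, j)
        for i, row in enumerate(expected)
        for j, e in enumerate(row)
        if abs(bias[j]) > tol and abs(actual[i][j] - (e - bias[j])) <= tol
    }


if __name__ == "__main__":
    random.seed(0)
    # Small integers keep the arithmetic exact.
    a = [[random.randint(-3, 3) for _ in range(K)] for _ in range(M)]
    b = [[random.randint(-3, 3) for _ in range(N)] for _ in range(K)]
    bias = [random.randint(1, 9) for _ in range(N)]
    expected = addmm_ref(a, b, bias)
    # Simulate the bug: drop the bias exactly where the reported pattern holds.
    buggy = [
        [v - bias[j] if in_reported_pattern(i, j) else v for j, v in enumerate(row)]
        for i, row in enumerate(expected)
    ]
    missing = bias_missing_at(expected, buggy, bias)
    print(len(missing), "of", M * N, "elements are missing the bias")
    # prints: 576 of 17408 elements are missing the bias
```

Under this reading, 64 of the 512 rows and 9 of the 34 columns are affected. When checking real GPU output, `tol` would need to be loosened to match the kernel's floating-point error.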
Confusingly, when we set `init=false` here, the issue goes away. (This isn't a viable workaround, and it also doesn't make sense why uninitializing the registers would fix anything.)
Similarly, if we add `tail call void @llvm.nvvm.barrier0(), !dbg !34` in the LLIR before the first `ld.global.L1::evict_last.v2.b32`, the issue also goes away.
We think this could be a ptxas bug.
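The barrier experiment amounts to a small text edit on the dumped LLIR. Below is a minimal sketch, assuming the LLIR is available as a string (dumping it and recompiling afterwards are not shown). The function name is hypothetical. The sketch omits the `!dbg !34` attachment from the report because that metadata ID is specific to the original module. It also adds a declaration of the intrinsic only if the module does not already declare it.

```python
BARRIER_CALL = "tail call void @llvm.nvvm.barrier0()"
BARRIER_DECL = "declare void @llvm.nvvm.barrier0()"
LOAD_MARKER = "ld.global.L1::evict_last.v2.b32"


def insert_barrier_before_first_load(llir: str, marker: str = LOAD_MARKER) -> str:
    """Insert a barrier0 call on its own line before the first line containing marker.

    In Triton's LLIR the load usually lives inside an inline-asm string, so the
    barrier is placed before the whole instruction that carries that asm.
    """
    lines = llir.splitlines(keepends=True)
    for idx, line in enumerate(lines):
        if marker in line:
            # Reuse the load's indentation so the edited IR stays readable.
            indent = line[: len(line) - len(line.lstrip())]
            lines.insert(idx, f"{indent}{BARRIER_CALL}\n")
            break
    else:
        raise ValueError(f"no line containing {marker!r}")
    out = "".join(lines)
    # The intrinsic must be declared in the module; add a declaration only if none exists.
    already_declared = any(
        ln.startswith("declare") and "@llvm.nvvm.barrier0(" in ln
        for ln in llir.splitlines()
    )
    if not already_declared:
        if not out.endswith("\n"):
            out += "\n"
        out += BARRIER_DECL + "\n"
    return out
```

Only the first matching load gets a barrier, mirroring the experiment described above; later loads are left untouched.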
For Triton 3.2, we're considering reverting https://github.com/triton-lang/triton/pull/4582 in the release branch to work around the issue (cc @bertmaher).
cc @embg who helped with debugging.
Environment details
Triton: main branch as of Nov 19
GPU: A100