Moerafaat closed this issue.
Interesting. Taking a look now.
FYI, I now have a solution that works for this using stmatrix. Will upstream soon.
where out dims are: [offset (size 4096), iteration (size 1)]
tensor([[8192., 8192., 8192., ..., 8192., 8192., 8192.],
[8192., 8192., 8192., ..., 8192., 8192., 8192.],
[8192., 8192., 8192., ..., 8192., 8192., 8192.],
...,
[8192., 8192., 8192., ..., 8192., 8192., 8192.],
[8192., 8192., 8192., ..., 8192., 8192., 8192.],
[8192., 8192., 8192., ..., 8192., 8192., 8192.]], device='cuda:0',
dtype=torch.bfloat16)
tensor([[8192., 8192., 8192., ..., 8192., 8192., 8192.],
[8192., 8192., 8192., ..., 8192., 8192., 8192.],
[8192., 8192., 8192., ..., 8192., 8192., 8192.],
...,
[8192., 8192., 8192., ..., 8192., 8192., 8192.],
[8192., 8192., 8192., ..., 8192., 8192., 8192.],
[8192., 8192., 8192., ..., 8192., 8192., 8192.]], device='cuda:0',
dtype=torch.bfloat16)
Thanks! Really appreciate the fast reply on this and looking forward to your fix 🙏
https://github.com/triton-lang/triton/pull/5277 is a partial fix. More general fixes will be pushed next week.
Tested it and it works great! Thanks for the fast turn-around!
Marking this fixed. Thanks for the assistance!
Describe the bug
To reproduce the issue, run the following Python test:
You will encounter the following error:
I noticed a similar report at https://github.com/triton-lang/triton/issues/4727 before that issue was re-opened. Interestingly, the failure actually started with the commit linked to that issue. The culprit commit is https://github.com/triton-lang/triton/commit/49266aa908d29a7029348ff480a75a3ea4d6e704
The test passes if num_warps is set to 4 instead of 8, and it worked properly before the culprit commit.
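Because the failure depends only on num_warps, one way to triage it is to run the same test under each setting and record the outcome. The following is a minimal sketch of that idea, not the original test. `sweep_num_warps` and `fake_run` are hypothetical names: `fake_run` stands in for "launch the kernel with this num_warps and return its output", and it simply mimics the reported behavior (passes with 4 warps, errors with 8). The value 8192.0 is taken from the printed tensors above.

```python
from typing import Callable, Dict, Iterable, List, Sequence


def sweep_num_warps(
    run: Callable[[int], Sequence[float]],
    reference: Sequence[float],
    warps: Iterable[int] = (4, 8),
    atol: float = 0.0,
) -> Dict[int, str]:
    """Call run(num_warps) for each setting and classify the outcome.

    Returns "ok", "mismatch", or "error: <ExceptionType>" per num_warps value.
    """
    outcomes: Dict[int, str] = {}
    for nw in warps:
        try:
            out = run(nw)
        except Exception as exc:  # a launch-time error counts as a failure
            outcomes[nw] = f"error: {type(exc).__name__}"
            continue
        # Compare element-wise against the reference within an absolute tolerance.
        same = len(out) == len(reference) and all(
            abs(a - b) <= atol for a, b in zip(out, reference)
        )
        outcomes[nw] = "ok" if same else "mismatch"
    return outcomes


def fake_run(num_warps: int) -> List[float]:
    """Hypothetical stand-in for launching the kernel; mimics the report."""
    if num_warps == 8:
        raise RuntimeError("stand-in for the reported error")
    return [8192.0] * 4


if __name__ == "__main__":
    print(sweep_num_warps(fake_run, [8192.0] * 4))
```

With a real kernel, `fake_run` would be replaced by a function that launches it with the given num_warps and returns its output. Catching exceptions, rather than only comparing values, matters here because the reported symptom with 8 warps is an error, not a wrong result.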
Environment details
The issue reproduces on H100 with the latest Triton main: commit 8b29bb752033c8db578bf88624a6dc91dc8f45cf