Closed by wconstab 3 weeks ago
It IMAs. @tianyu-l also had a PR for this but I didn't know about it :/ hopefully we'll figure it out soon.
What is "IMA" short for?
What is "IMA" short for?
Illegal Memory Access - the generic CUDA error indicating that something has indexed past the memory it is allowed to touch.
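For illustration only (not code from this PR), here is a minimal sketch of how an unmasked Triton load can trigger that error; the kernel and names are hypothetical:

```python
import torch
import triton
import triton.language as tl


@triton.jit
def copy_kernel(x_ptr, y_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    # BUG for illustration: this load is unmasked, so the last block reads past
    # the end of x. Depending on the allocation this may silently read garbage
    # or raise "CUDA error: an illegal memory access was encountered".
    x = tl.load(x_ptr + offsets)
    # The fixed version would be: tl.load(x_ptr + offsets, mask=mask, other=0.0)
    tl.store(y_ptr + offsets, x, mask=mask)


x = torch.randn(1000, device="cuda")
y = torch.empty_like(x)
copy_kernel[(triton.cdiv(x.numel(), 256),)](x, y, x.numel(), BLOCK_SIZE=256)
```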
Thanks @lessw2020 . Do you think the IMA relates to the triton kernel? Can you help fix it? PP needs this fix to land. Would appreciate your help.
Hi @kwen2501 - I'm debugging into this. It's not an issue with the kernel per se. Rather, for some reason when the kernel is run as a registered op, the Triton masking is being randomly polluted with values that exceed the CUDA-addressable memory space, and this causes the IMA.
[rank0]:pid (18, 0, 0) idx (150) x_hat: 0.000000
[rank0]:pid (18, 0, 0) idx (151) x_hat: 0.000000
[rank0]:pid (18, 0, 0) idx (152) x_hat: -174770156674672237865863089087886393344.000000
[rank0]:pid (18, 0, 0) idx (153) x_hat: 0.000000
[rank0]:pid (18, 0, 0) idx (154) x_hat: 0.000000
The above should be all 0s because it's where we have no input data... but somehow this ginormous number and others like it are randomly being added into the masked-off input values. This is what causes the IMA.
For reference, a normal value for inputs (where we have values):
[rank0]:pid (2, 0, 0) idx ( 36) x: 0.673390
[rank0]:pid (2, 0, 0) idx ( 37) x: 0.314899
[rank0]:pid (2, 0, 0) idx ( 38) x: 0.522899
[rank0]:pid (2, 0, 0) idx ( 39) x: -0.126250
[rank0]:pid (2, 0, 0) idx ( 40) x: -0.819483
Will continue investigating, but this is likely some kind of bug in the interaction between Triton load masking and running the kernel as a custom op, not a kernel-specific issue.
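In case it helps anyone reproduce: prints like the ones above can be emitted from inside a Triton kernel with tl.device_print. This is a hypothetical sketch of that debugging approach, not the actual debug code used here:

```python
import triton
import triton.language as tl


@triton.jit
def debug_load_kernel(x_ptr, stride, n_cols, BLOCK_SIZE: tl.constexpr):
    row = tl.program_id(0)
    cols = tl.arange(0, BLOCK_SIZE)
    mask = cols < n_cols
    # Masked-off lanes should come back as exactly 0.0 because of other=0.0.
    x = tl.load(x_ptr + row * stride + cols, mask=mask, other=0.0).to(tl.float32)
    # Emits one line per lane in the "pid (p, 0, 0) idx (i) x: value" format,
    # which is how the huge non-zero values in the masked region showed up.
    tl.device_print("x: ", x)
```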
Thanks @lessw2020 for the demonstration.
some kind of bug between triton load masking and what is going awry when run as a custom op
Can you point me to the code where triton load masking is done?
Also cc @tugsbayasgalan @zou3519
Hi @kwen2501 - sure, here's the specific line that has the issue: https://github.com/pytorch/torchtitan/blob/f72a2a0da0bdfc394faaab9b3c0f35d0b6f5be50/torchtitan/models/norms.py#L198 That line loads the inputs and masks off any values past the known column length, which should set them to zero. However, some of those zeros are being randomly polluted.
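For readers following along, the pattern at that line is roughly the standard masked-load idiom sketched below; this is a simplified illustration, not the exact norms.py code:

```python
import triton
import triton.language as tl


@triton.jit
def rmsnorm_fwd_sketch(x_ptr, rstd_ptr, stride, n_cols, eps, BLOCK_SIZE: tl.constexpr):
    row = tl.program_id(0)
    cols = tl.arange(0, BLOCK_SIZE)
    mask = cols < n_cols
    # mask prevents reads past the row length; other=0.0 should fill the
    # masked-off lanes. Per the prints above, those zeros are sometimes
    # polluted when the kernel runs behind a registered custom op.
    x = tl.load(x_ptr + row * stride + cols, mask=mask, other=0.0).to(tl.float32)
    # Masked-off zeros contribute nothing to the mean of squares, unless they
    # are polluted with huge values, which then blows up rstd and x_hat.
    rstd = 1.0 / tl.sqrt(tl.sum(x * x, axis=0) / n_cols + eps)
    tl.store(rstd_ptr + row, rstd)
```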
@lessw2020 I was just cleaning up old PRs and closing many. Is this one still in the running, or should we abandon it? Feel free to reopen.
Stack from ghstack (oldest at bottom):
* #161
This just refactors the fused_rmsnorm kernel into torch_library functions so export tracing can avoid tracing inside the kernel, which has several tracing-unfriendly things, including dynamic stride usage.
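As a rough sketch of the general pattern (assuming a recent PyTorch with torch.library.custom_op; the op name, signature, and eager placeholder below are illustrative, not the actual PR code):

```python
import torch


@torch.library.custom_op("titan::fused_rmsnorm", mutates_args=())
def fused_rmsnorm(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    # The real implementation would launch the Triton kernel here; export/compile
    # only records this op's signature, so the kernel's dynamic stride handling
    # is never traced. Eager RMSNorm math stands in as a placeholder.
    rstd = torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + eps)
    return (x.float() * rstd * weight.float()).to(x.dtype)


@fused_rmsnorm.register_fake
def _(x, weight, eps):
    # Shape/dtype propagation for tracing and export without running the kernel.
    return torch.empty_like(x)
```

With this kind of registration, a call such as fused_rmsnorm(x, w, 1e-6) shows up as a single opaque op node in the exported graph rather than a trace through the kernel body.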