pytorch / torchtitan

A native PyTorch Library for large model training
BSD 3-Clause "New" or "Revised" License

[fused_rmsnorm] Register as a custom operator for tracing #303

Closed: wconstab closed this 3 weeks ago

wconstab commented 2 months ago

Stack from ghstack (oldest at bottom):

This just refactors the fused_rmsnorm kernel into torch.library functions, so that export tracing can treat the kernel as an opaque op instead of tracing inside it. The kernel contains several tracing-unfriendly constructs, including dynamic stride usage.

wconstab commented 2 months ago

It hits an IMA. @tianyu-l also had a PR for this, but I didn't know about it :/ Hopefully we'll figure it out soon.

kwen2501 commented 2 months ago

What is "IMA" short for?

lessw2020 commented 2 months ago


Illegal Memory Access: the generic CUDA error raised when a kernel accesses memory outside its valid address range.

kwen2501 commented 1 month ago

Thanks @lessw2020. Do you think the IMA is related to the Triton kernel? Can you help fix it? PP needs this fix to land. I'd appreciate your help.

lessw2020 commented 1 month ago


Hi @kwen2501, I'm debugging this. It's not an issue with the kernel per se. Rather, for some reason, when the kernel runs as a registered op, the Triton-masked values are randomly polluted with values that exceed the CUDA addressable memory space, and this causes the IMA.

[rank0]:pid (18, 0, 0) idx (150) x_hat: 0.000000
[rank0]:pid (18, 0, 0) idx (151) x_hat: 0.000000
[rank0]:pid (18, 0, 0) idx (152) x_hat: -174770156674672237865863089087886393344.000000
[rank0]:pid (18, 0, 0) idx (153) x_hat: 0.000000
[rank0]:pid (18, 0, 0) idx (154) x_hat: 0.000000

These should all be 0 because these indices hold no input data, but somehow this huge number, and others like it, are randomly showing up in the masked-off input values. This is what causes the IMA.

For reference, here are normal input values (at indices that do hold data):

[rank0]:pid (2, 0, 0) idx ( 36) x: 0.673390
[rank0]:pid (2, 0, 0) idx ( 37) x: 0.314899
[rank0]:pid (2, 0, 0) idx ( 38) x: 0.522899
[rank0]:pid (2, 0, 0) idx ( 39) x: -0.126250
[rank0]:pid (2, 0, 0) idx ( 40) x: -0.819483

Will continue investigating, but this issue is likely some kind of bug between triton load masking and what is going awry when run as a custom op, not a kernel-specific issue.

kwen2501 commented 1 month ago

Thanks @lessw2020 for the demonstration.

some kind of bug between triton load masking and what is going awry when run as a custom op

Can you point me to the code where triton load masking is done?

Also cc @tugsbayasgalan @zou3519

lessw2020 commented 1 month ago

Hi @kwen2501, sure, here's the specific line that has the issue: https://github.com/pytorch/torchtitan/blob/f72a2a0da0bdfc394faaab9b3c0f35d0b6f5be50/torchtitan/models/norms.py#L198 It loads the inputs and masks off any values past the known column length, which should set those to zero. However, some of those zeros are being randomly polluted.
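To illustrate why those masked lanes must come back as exact zeros, here is a CPU-only PyTorch emulation of a masked block load followed by an RMS reduction. It assumes the kernel follows the usual pattern of `tl.load(..., mask=cols < n_cols, other=0.0)` and divides the sum of squares by the true row length. The helpers `masked_row_load` and `rms_norm_block` are hypothetical, and this is not the code at the linked line. In fp32, a single masked-off lane holding a value of the magnitude seen in the log overflows the sum of squares to inf, which zeroes the whole normalized row.

```python
import torch


def masked_row_load(x_row: torch.Tensor, n_cols: int, block_size: int) -> torch.Tensor:
    """Hypothetical emulation of tl.load(ptr + cols, mask=cols < n_cols, other=0.0)."""
    cols = torch.arange(block_size)  # like tl.arange(0, BLOCK_SIZE)
    mask = cols < n_cols  # lanes past the row length are masked off
    # Clamp masked lanes to index 0 so the gather never reads out of bounds.
    safe_cols = torch.where(mask, cols, torch.zeros_like(cols))
    return torch.where(mask, x_row[safe_cols], x_row.new_zeros(block_size))


def rms_norm_block(x_block: torch.Tensor, n_cols: int, eps: float = 1e-6) -> torch.Tensor:
    # The mean of squares divides by the true row length, not the padded block
    # size, so correctness relies on masked lanes contributing exactly 0.
    var = (x_block * x_block).sum() / n_cols
    return x_block * torch.rsqrt(var + eps)


# Input values taken from the debug log above (fp32 by default).
row = torch.tensor([0.673390, 0.314899, 0.522899, -0.126250, -0.819483])
block = masked_row_load(row, n_cols=5, block_size=8)
y = rms_norm_block(block, n_cols=5)

# Simulate one polluted masked-off lane, similar to the value in the log.
polluted = block.clone()
polluted[6] = -1.7477e38
y_bad = rms_norm_block(polluted, n_cols=5)
# The sum of squares overflows to inf, rsqrt(inf) == 0, so y_bad is all zeros.
```

The clamp to index 0 mirrors what masking buys in the real kernel: out-of-range lanes are never dereferenced, and the `other` value is what they contribute to the reduction.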

wconstab commented 3 weeks ago

@lessw2020 I was just cleaning up old PRs and closing many of them. Is this one still in the running, or should we abandon it? Feel free to reopen.