pytorch / torchtitan

A native PyTorch Library for large model training
BSD 3-Clause "New" or "Revised" License

selective compilation - norm layers only #320

Closed lessw2020 closed 1 month ago

lessw2020 commented 4 months ago

This PR adds an option to compile only the norm layers, and is mainly targeted at RMSNorm. By compiling just the norm layers when using rmsnorm, we get speedups nearly comparable to those of the fusedRMSNorm triton kernel. Credit to @wconstab for this idea.
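As a rough illustration of the idea (not the code in this PR), compiling only the norm layers can be sketched by walking the module tree and wrapping each norm submodule with `torch.compile`. The `RMSNorm` class, the `compile_norm_layers` helper, and its `norm_types` and `compile_fn` parameters are hypothetical names chosen for this sketch; only `torch.compile` and the `nn.Module` traversal methods are real PyTorch APIs.

```python
import torch
import torch.nn as nn


class RMSNorm(nn.Module):
    """Plain (unfused) RMSNorm, written here only for illustration."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Scale by the reciprocal root mean square over the last dimension.
        inv_rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return x * inv_rms * self.weight


def compile_norm_layers(model, norm_types=(RMSNorm,), compile_fn=torch.compile):
    """Hypothetical helper: wrap every norm submodule with compile_fn.

    Returns the number of submodules that were wrapped.
    """
    wrapped = 0
    # Snapshot the module list first, since we mutate the tree while iterating.
    for parent in list(model.modules()):
        for name, child in list(parent.named_children()):
            if isinstance(child, norm_types):
                # torch.compile returns a wrapper module; compilation itself
                # is deferred until the first forward call.
                setattr(parent, name, compile_fn(child))
                wrapped += 1
    return wrapped
```

Calling `compile_norm_layers(model)` on a model leaves every non-norm layer in eager mode, so the rest of the model is untouched. The `compile_fn` argument exists only so the wrapping logic can be exercised without invoking the compiler.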

regular rmsnorm:

Screenshot 2024-05-09 at 5 09 17 PM

with the new compile_rmsnorm enabled:

Screenshot 2024-05-09 at 5 24 39 PM Screenshot 2024-05-09 at 5 09 57 PM

2 - UX - I enabled compile_rmsnorm as its own option for now so users can quickly try whole-model or norm-only compile. If compile is true, the rmsnorm layers are not compiled separately (they are already covered by the generic full-model compile) and a minor note is issued in the logs.

3 - Using other norms with this option enabled does not appear to add any speedup (but also causes no errors), so I did not add a check to only compile if the norm is rmsnorm (but I can add that).
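The precedence between the two options described in points 2 and 3 can be sketched as a small decision function. This is a minimal sketch under assumptions: `TrainingConfig` and `resolve_compile_mode` are hypothetical names, and only the `compile` and `compile_rmsnorm` option names come from this PR.

```python
import logging
from dataclasses import dataclass

logger = logging.getLogger(__name__)


@dataclass
class TrainingConfig:
    # Hypothetical stand-in for the two config options discussed above.
    compile: bool = False
    compile_rmsnorm: bool = False


def resolve_compile_mode(cfg: TrainingConfig) -> str:
    """Return 'full', 'norm_only', or 'none' for the given flags."""
    if cfg.compile:
        if cfg.compile_rmsnorm:
            # Full-model compile already covers the norm layers,
            # so the norm-only option is ignored with a note.
            logger.info(
                "compile=True already includes the norm layers; "
                "compile_rmsnorm is ignored."
            )
        return "full"
    if cfg.compile_rmsnorm:
        # No check on the norm type here: other norms compile without
        # errors, they just showed no speedup.
        return "norm_only"
    return "none"
```

With this precedence, `compile=True` alone is enough for a full compile, and `compile_rmsnorm=True` only has an effect when `compile` is false.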

drisspg commented 4 months ago

Is 2 saying that in order to have "full" compile you need to set both compile=true and compile_rmsnorm=true?

lessw2020 commented 4 months ago

> Is 2 saying that in order to have "full" compile you need to set both compile=true and compile_rmsnorm=true?

I updated the text to be more specific, but no - if compile = true in the config, then you get full compile including the rmsnorm layers.

tianyu-l commented 1 month ago

Closing, as we removed the feature in #535.