Open cloneofsimo opened 6 months ago
@sayakpaul
How about using fused linear and fused layer norm from xformers?
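Roughly what I have in mind (a minimal sketch; the `FusedLinear` / `FusedLayerNorm` import paths are from older xformers releases and may have moved or been removed in the version you're on, so treat this as illustrative):

```python
import torch.nn as nn

# Assumed exports from older xformers releases; fall back to eager modules if absent.
try:
    from xformers.triton import FusedLayerNorm, FusedLinear
except ImportError:
    FusedLayerNorm, FusedLinear = nn.LayerNorm, nn.Linear

class MLPBlock(nn.Module):
    def __init__(self, dim: int, hidden: int):
        super().__init__()
        self.norm = FusedLayerNorm(dim)      # single Triton kernel instead of several eager ops
        self.fc1 = FusedLinear(dim, hidden)  # can fuse bias/activation into the projection kernel
        self.act = nn.GELU()
        self.fc2 = FusedLinear(hidden, dim)

    def forward(self, x):
        return x + self.fc2(self.act(self.fc1(self.norm(x))))
```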
Additionally, could you leave a training command / references to the lines of code showing how you're using torch.compile()? That would be helpful for me to verify certain things.
Also, just checking -- you're doing bf16 or fp16 training?
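To be concrete, something along these lines is what I'm asking about (just a generic, self-contained sketch, obviously not your actual code): torch.compile on the model plus bf16 autocast, which avoids the gradient scaling fp16 would need.

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(1024, 4096), nn.GELU(), nn.Linear(4096, 1024)).cuda()
opt = torch.optim.AdamW(model.parameters(), lr=1e-4)
model = torch.compile(model)  # default mode; mode="max-autotune" tunes harder but compiles much longer

x = torch.randn(8, 1024, device="cuda")
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    loss = model(x).float().pow(2).mean()  # dummy loss just to exercise forward/backward
loss.backward()
opt.step()
opt.zero_grad(set_to_none=True)
```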
Either via torch.compile or Triton: the forward/backward ops currently materialize too many activations, and that is probably what's bottlenecking training. For some reason I got about a 30% speedup at the 1B scale, but it doesn't seem to do better at larger scales. Either way, attaching a good fused linear, fused layernorm, and fused MH layernorm would be very helpful.
The reason I would prefer this over torch.compile is that torch.compile with max-autotune takes an eternity. :P
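One middle ground (a sketch of the general idea, not tied to any particular repo): compile only the hot pointwise pattern, e.g. an AdaLN-style modulated layer norm, so Inductor fuses the norm + scale/shift into a couple of kernels while the rest of the model stays eager and compile time stays small.

```python
import torch
import torch.nn.functional as F

@torch.compile  # default mode; compiling a function this small is quick
def modulated_layernorm(x, scale, shift, eps: float = 1e-6):
    # Non-affine layer norm followed by a per-sample scale/shift; Inductor can
    # fuse the elementwise tail into the normalization kernels.
    x = F.layer_norm(x, (x.shape[-1],), eps=eps)
    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)

x = torch.randn(4, 256, 1024, device="cuda")
scale = torch.randn(4, 1024, device="cuda")
shift = torch.randn(4, 1024, device="cuda")
out = modulated_layernorm(x, scale, shift)
```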
Some references
https://github.com/crowsonkb/k-diffusion/blob/21d12c91ad4550e8fcf3308ff9fe7116b3f19a08/k_diffusion/models/image_transformer_v2.py#L90
Trident has a lot of them implemented: https://github.com/kakaobrain/trident
Unsloth has many backward passes implemented (a pure-PyTorch sketch of that pattern is below the list): https://github.com/unslothai/unsloth/blob/main/unsloth/kernels/geglu.py
Triton code extraction: https://youtu.be/LuhJEEJQgUM?t=2234
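Re the unsloth link above: the pattern those kernels follow is essentially a hand-written backward. A pure-PyTorch sketch of the same math for exact-GELU GEGLU (my own illustration, not unsloth's code; the fused versions do this inside a single Triton kernel):

```python
import math
import torch

class GEGLU(torch.autograd.Function):
    # output = gelu(gate) * up, with the backward written out by hand instead of
    # letting autograd keep every intermediate activation.

    @staticmethod
    def forward(ctx, gate, up):
        ctx.save_for_backward(gate, up)
        return torch.nn.functional.gelu(gate) * up

    @staticmethod
    def backward(ctx, grad_out):
        gate, up = ctx.saved_tensors
        cdf = 0.5 * (1.0 + torch.erf(gate / math.sqrt(2.0)))             # Phi(gate)
        pdf = torch.exp(-0.5 * gate * gate) / math.sqrt(2.0 * math.pi)   # phi(gate)
        dgelu = cdf + gate * pdf                   # d/dx [x * Phi(x)]
        return grad_out * up * dgelu, grad_out * gate * cdf  # gelu(gate) = gate * Phi(gate)

# quick correctness check against autograd's numerical gradients
gate = torch.randn(4, 16, dtype=torch.float64, requires_grad=True)
up = torch.randn(4, 16, dtype=torch.float64, requires_grad=True)
torch.autograd.gradcheck(GEGLU.apply, (gate, up))
```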