cloneofsimo / minRF

Minimal implementation of scalable rectified flow transformers, based on SD3's approach
Apache License 2.0

Support better kernel fusion for MMDiT architecture #1

Open cloneofsimo opened 2 months ago

cloneofsimo commented 2 months ago

With either torch.compile or Triton, the forward/backward passes materialize too many activations, which is probably bottlenecking training. For some reason I got about a 30% speedup at the 1B scale, but it does not seem to do better at larger scales. Either way, adding good fusedlinear, fusedlayernorm, and fusedMHlayernorm kernels would be very helpful.
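
A minimal sketch of the kind of fusion target this points at, assuming the modulated (adaLN-style) LayerNorm used in MMDiT blocks; the function name and shapes are illustrative, not the repo's actual code:

```python
import torch
import torch.nn.functional as F

# Sketch: the modulated LayerNorm written as one function so torch.compile
# (via Triton) can fuse the normalization, scale, and shift into a single
# kernel instead of materializing each intermediate activation.
@torch.compile
def modulated_layernorm(x: torch.Tensor, scale: torch.Tensor, shift: torch.Tensor) -> torch.Tensor:
    x = F.layer_norm(x, (x.size(-1),))  # no affine params here
    return x * (1 + scale) + shift      # per-token adaLN modulation
```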

The reason I would prefer this over torch.compile is that torch.compile with max-autotune takes an eternity to compile. :P
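
For reference, the two compile paths being compared (a minimal sketch; the model is a stand-in):

```python
import torch

# "max-autotune" benchmarks many Triton kernel configurations per op, which
# is where the very long compile times come from; the default mode compiles
# far faster at some throughput cost.
model = torch.nn.Linear(1024, 1024).cuda()
compiled_default = torch.compile(model)                     # fast to compile
compiled_tuned = torch.compile(model, mode="max-autotune")  # long autotuning pass
```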

Some references

cloneofsimo commented 2 months ago

@sayakpaul

sayakpaul commented 1 month ago

How about using fused linear and fused layer norm from xformers?
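
A hedged sketch of what that swap might look like; the import path below matches older xformers releases and is an assumption here, as these modules have moved around between versions:

```python
import torch.nn as nn

# Assumed import path (older xformers releases exposed Triton-fused ops
# under xformers.triton) -- verify against the installed version.
try:
    from xformers.triton import FusedLayerNorm
except ImportError:
    FusedLayerNorm = nn.LayerNorm  # fall back to the eager implementation

norm = FusedLayerNorm(1024)  # same normalized_shape usage as nn.LayerNorm
```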

Additionally, could you share a training command / pointers to the lines of code showing how you're using torch.compile()? That would help me verify a few things.

Also, just checking -- you're doing bf16 or fp16 training?
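
For context on why the dtype matters, a minimal bf16 autocast step (model and loss are stand-ins): bf16 needs no GradScaler, unlike fp16, which also changes what the compiled/fused graphs look like.

```python
import torch

model = torch.nn.Linear(1024, 1024).cuda()
x = torch.randn(8, 1024, device="cuda")

# bf16 autocast: no loss scaling required, fp32 master weights untouched.
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    loss = model(x).float().pow(2).mean()
loss.backward()
```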