Hi, did you see the torch.2.1 branch? That branch is preparing for a 2.0 release.
I have a more advanced version of the 2.0 model in a private branch that compiles without graph breaks (and is even faster with modified inductor settings). I'm a bit preoccupied for the next 1-2 weeks, but I'll try to push it to GitHub after that.
So while this is a great contribution, it doesn't make much sense for me to merge it now and then overwrite it with a different implementation in two weeks. I'll leave the pull request open until then, so people who are interested in the 2.0 version can already find it here.
Makes sense. Looking forward to switching my experiments to the faster, more advanced 2.0 model.
The new release is now live; see https://github.com/JonasGeiping/cramming/releases/tag/Torch2.1.
In my tests (back in April) I didn't find `PyTorchFlashMultiHeadAttention` to be better than the compiled attention implementation. Let me know if you ever test this again and find something to the contrary :)
This PR adds support for using PyTorch 2.0's `torch.compile` during the pretraining task. `torch.compile` is enabled via a new argument in `_default.yaml` and is turned off by default. Using `torch.compile` during pretraining allows for larger batch sizes with the same model settings, and the two combined yield a ~9% increase in throughput on a consumer GPU compared to PyTorch 1.13.
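For readers unfamiliar with the mechanism, here is a minimal sketch of how a config flag can gate compilation; the helper and flag names are placeholders, not the PR's actual code:

```python
import torch
from torch import nn

def maybe_compile(model: nn.Module, compile_model: bool = False) -> nn.Module:
    """Illustrative helper: optionally wrap the pretraining model with torch.compile."""
    if compile_model:
        # Graph capture is handled by TorchDynamo and code generation by Inductor;
        # compilation happens lazily on the first forward/backward pass.
        model = torch.compile(model)
    return model
```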
This PR also adds `PyTorchFlashMultiHeadAttention`, a new attention module which uses the beta FlashAttention implementation behind PyTorch's `scaled_dot_product_attention`. `scaled_dot_product_attention` doesn't cause any CUDA graph breaks when compiled, leading to an additional ~5% improvement, for a total of a ~15% increase in training throughput over PyTorch 1.13's eager mode with FlashAttention.
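For reference, here is a minimal sketch of an attention module built on the public `scaled_dot_product_attention` API, assuming a standard fused-QKV layout; this is illustrative and not necessarily how `PyTorchFlashMultiHeadAttention` is implemented:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SDPASelfAttention(nn.Module):
    """Illustrative self-attention block using F.scaled_dot_product_attention."""

    def __init__(self, hidden_size: int, num_heads: int, dropout: float = 0.0):
        super().__init__()
        assert hidden_size % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.qkv = nn.Linear(hidden_size, 3 * hidden_size)
        self.out = nn.Linear(hidden_size, hidden_size)
        self.dropout = dropout

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, S, H = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        # (batch, seq, hidden) -> (batch, heads, seq, head_dim), the layout SDPA expects
        q, k, v = (t.view(B, S, self.num_heads, self.head_dim).transpose(1, 2) for t in (q, k, v))
        # Dispatches to the fused FlashAttention kernel when the backend supports it
        # (e.g. fp16/bf16 on CUDA without an explicit mask) and compiles without graph breaks.
        y = F.scaled_dot_product_attention(q, k, v, dropout_p=self.dropout if self.training else 0.0)
        y = y.transpose(1, 2).reshape(B, S, H)
        return self.out(y)
```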
There are still multiple graph breaks in `ScriptableLMForPreTraining` which this PR does not attempt to resolve. Here is the output from `torch._dynamo.explain`:
This PR also does not add `torch.compile` support for the evaluation script. Currently, `torch.compile` appears to recompile the evaluation model every training step; I have not looked into why this is occurring. I have only tested this PR in a single-GPU setting.
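For context, here is a minimal sketch of how the graph-break report mentioned above can be generated, using a toy model and input; note that the `explain` call convention changed after PyTorch 2.0:

```python
import torch
import torch._dynamo as dynamo

# Toy stand-in; a real report would pass the cramming model and an actual batch.
model = torch.nn.Linear(8, 8)
example_input = torch.randn(2, 8)

# PyTorch 2.0: explain(fn, *example_inputs) returns a tuple whose first element is
# a human-readable summary of graph count and graph-break reasons.
# (Later releases changed this to torch._dynamo.explain(fn)(*example_inputs).)
explanation = dynamo.explain(model, example_input)[0]
print(explanation)
```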
If you want, I can also update the README with `torch.compile` instructions.