JonasGeiping / cramming

Cramming the training of a (BERT-type) language model into limited compute.
MIT License

Add Support for `torch.compile` During Pretraining for ~15% Speedup #19

Closed warner-benjamin closed 1 year ago

warner-benjamin commented 1 year ago

This PR adds support for using PyTorch 2.0's torch.compile during the pretraining task. torch.compile is set via a new argument in _default.yaml, and is turned off by default.
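For illustration, gating compilation on such a flag could look roughly like the sketch below. The flag name (`compile_torch`) and the helper are assumptions for illustration, not the PR's actual code; the PR only states that the option lives in `_default.yaml` and defaults to off.

```python
import torch

# Minimal sketch, not the PR's exact code: gate compilation on a config flag.
# The flag name `compile_torch` is an assumption; the PR only says the option
# lives in _default.yaml and is off by default.
def maybe_compile(model, cfg_impl):
    if cfg_impl.compile_torch:
        # torch.compile requires PyTorch >= 2.0; uses the default inductor backend.
        model = torch.compile(model)
    return model
```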

Using torch.compile during pretraining allows for larger batch sizes with the same model settings; combined, these changes result in a ~9% increase in throughput on a consumer GPU compared to PyTorch 1.13.

This PR also adds PyTorchFlashMultiHeadAttention, a new attention module that uses PyTorch's scaled_dot_product_attention, a beta implementation of FlashAttention. scaled_dot_product_attention doesn't cause any CUDA graph breaks when compiled, leading to an additional ~5% improvement, for a total ~15% increase in training throughput over PyTorch 1.13's eager mode with FlashAttention.
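As a rough sketch of the idea (not the PR's actual PyTorchFlashMultiHeadAttention; the class name, mask handling, and dropout placement are assumptions), a self-attention block built on scaled_dot_product_attention looks like this:

```python
import torch
import torch.nn.functional as F
from torch import nn


class SDPASelfAttention(nn.Module):
    """Illustrative self-attention using F.scaled_dot_product_attention.

    Generic sketch only; names and details are assumptions, not the PR's code.
    """

    def __init__(self, hidden_size: int, num_heads: int, dropout: float = 0.0):
        super().__init__()
        assert hidden_size % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.qkv = nn.Linear(hidden_size, 3 * hidden_size)
        self.proj = nn.Linear(hidden_size, hidden_size)
        self.dropout = dropout

    def forward(self, states, attention_mask=None):
        B, S, H = states.shape
        q, k, v = self.qkv(states).chunk(3, dim=-1)
        # (B, S, H) -> (B, num_heads, S, head_dim)
        q, k, v = (t.view(B, S, self.num_heads, self.head_dim).transpose(1, 2) for t in (q, k, v))
        # Dispatches to the FlashAttention kernel when the inputs allow it
        # (e.g. no explicit attn_mask); otherwise falls back to another kernel.
        out = F.scaled_dot_product_attention(
            q, k, v, attn_mask=attention_mask,
            dropout_p=self.dropout if self.training else 0.0,
        )
        return self.proj(out.transpose(1, 2).reshape(B, S, H))
```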

There are still multiple graph breaks in ScriptableLMForPreTraining which this PR does not attempt to resolve. Here is the output from torch._dynamo.explain (a sketch of how to reproduce this output follows the log):

Dynamo produced 35 graphs with 34 graph break and 117 ops
 Break reasons: 

1. return_value
  File "cramming/architectures/components.py", line 45, in forward
    return self.dropout(self.norm(embeds))

2. call_function UserDefinedObjectVariable(partial) [TensorVariable(), TensorVariable(), TensorVariable(), TensorVariable()] {}
  File "cramming/architectures/components.py", line 137, in forward
    states = self.fn_training(states, self.attn(self.norm1(states), attention_mask), self.alpha, res_scale)

3. return_value
  File "cramming/architectures/fused_layers.py", line 47, in simplified_layer_training
    return simplified_layer_structure(states, outputs, alpha, residual_scale, prob, training=True)

4. call_function UserDefinedObjectVariable(partial) [TensorVariable(), TensorVariable(), TensorVariable(), TensorVariable()] {}
  File "cramming/architectures/components.py", line 138, in <graph break in forward>
    states = self.fn_training(states, self.ffn(self.norm2(states)), self.alpha, res_scale)

5. call_function BuiltinVariable(dict) [] {'loss': TensorVariable(), 'logits': TensorVariable()}
  File "cramming/architectures/scriptable_bert.py", line 201, in <graph break in forward>
    return dict(loss=masked_lm_loss, logits=outputs)
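For reference, output like the log above can be generated roughly as follows, assuming the PyTorch 2.0-era explain signature (it changed in later releases); `model`, `input_ids`, and `attention_mask` are placeholders for the pretraining model and a sample batch:

```python
import torch._dynamo as dynamo

# Assumes the PyTorch 2.0-era API, which returns a tuple of diagnostics.
explanation, out_guards, graphs, ops_per_graph, break_reasons, verbose = dynamo.explain(
    model, input_ids, attention_mask
)
print(explanation)       # e.g. "Dynamo produced N graphs with M graph breaks and K ops"
for reason in break_reasons:
    print(reason)
```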

This PR also does not add torch.compile support for the evaluation script. Currently, torch.compile appears to recompile the evaluation model every training step; I have not looked into why this is occurring.
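One way to confirm such per-step recompilation is a counting backend; this is a generic diagnostic sketch, not part of the PR:

```python
import torch

compile_count = 0

def counting_backend(gm: torch.fx.GraphModule, example_inputs):
    """Custom torch.compile backend that counts (re)compilations and then
    runs the captured graph eagerly."""
    global compile_count
    compile_count += 1
    print(f"compilation #{compile_count}")
    return gm.forward

# Hypothetical usage: wrap the evaluation model with the counting backend
# and watch whether the counter grows on every step.
# eval_model = torch.compile(eval_model, backend=counting_backend)
```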

I have only tested this PR in a single GPU setting.

If you want, I can also update the README with torch.compile instructions.

JonasGeiping commented 1 year ago

Hi, did you see the torch.2.1 branch? That branch is preparing for a 2.0 release.

I have a more advanced version of the 2.0 model in a private branch that compiles without graph breaks (and is even faster with modified inductor settings). I'm a bit preoccupied for the next 1-2 weeks, but I'll try to push this out to GitHub after that.
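For readers wondering what settings of that kind look like: the comment above doesn't name them, but standard torch.compile knobs include `fullgraph` and the Inductor tuning mode. The snippet below is illustrative only, not the private branch's actual configuration:

```python
import torch

# Illustrative torch.compile options; not the settings used in the private branch.
model = torch.compile(
    model,
    fullgraph=True,        # error on graph breaks instead of silently splitting
    mode="max-autotune",   # let Inductor autotune kernels for extra throughput
)
```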

So while this is a great contribution, it doesn't make much sense for me to merge it now and then overwrite it with a different implementation in two weeks. I'll leave the pull request open until then, so people who are interested in the 2.0 version can already find it here.

warner-benjamin commented 1 year ago

Makes sense. Looking forward to switching my experiments to the faster, more advanced 2.0 model.

JonasGeiping commented 1 year ago

The new release is now live, see https://github.com/JonasGeiping/cramming/releases/tag/Torch2.1.

In my tests (back in April), I didn't find PyTorchFlashMultiHeadAttention to be better than the compiled attention implementation. Let me know if you ever test this again and find something to the contrary :)