This PR adds compatibility for BitFit (bias-only fine-tuning). I'd like to try BitFit + MTF to retain multilinguality.
Empirical evidence from this paper:
Note that adapters also add parameters to the model and increase inference complexity in Transformers, so BitFit is the best option imo.
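For reference, a minimal sketch of what BitFit-style freezing looks like in plain PyTorch (hypothetical illustration, not the code in this PR): all parameters are frozen except bias terms, so only biases receive gradients and optimizer states.

```python
import torch.nn as nn

def apply_bitfit(model: nn.Module):
    """Freeze every parameter except biases; return the trainable ones."""
    for name, param in model.named_parameters():
        param.requires_grad = name.endswith("bias")
    return [p for p in model.parameters() if p.requires_grad]

# Toy model just to show the parameter-count effect.
model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4))
trainable = apply_bitfit(model)

total = sum(p.numel() for p in model.parameters())
tuned = sum(p.numel() for p in trainable)
print(f"tuning {tuned}/{total} parameters")  # biases only: 36/676
```

In the real integration the same `requires_grad` filtering has to happen before the optimizer and gradient buffers are built, which is where the memory and grad-norm effects below come from.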
Also see this paper though they don't try BitFit.
Automatic Tests: Happy to add one if we decide to merge this 🤗
Manual Tests:
1 Node, PP=2, TP=2
2 Nodes, PP=2, TP=2
The logs below show the grad norm decreasing as expected, since fewer parameters receive gradients.
I would also expect iteration time to decrease due to less gradient communication, but probably only at higher node counts.
Memory usage also decreases because there are fewer optimizer states to store.
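A quick way to see the optimizer-state saving (hypothetical sketch, assuming a plain PyTorch Adam rather than the distributed optimizer used here): Adam keeps two extra tensors (`exp_avg`, `exp_avg_sq`) per *trainable* parameter, so freezing the weight leaves state only for the bias.

```python
import torch
import torch.nn as nn

model = nn.Linear(1024, 1024)
model.weight.requires_grad = False  # BitFit: freeze the weight, keep the bias

# Build the optimizer over trainable parameters only.
opt = torch.optim.Adam(p for p in model.parameters() if p.requires_grad)
loss = model(torch.randn(2, 1024)).sum()
loss.backward()
opt.step()

# Count elements held in Adam's state: ~2 * 1024 for the bias,
# instead of ~2 * (1024*1024 + 1024) with the weight trainable.
state_elems = sum(t.numel() for s in opt.state.values()
                  for t in s.values() if torch.is_tensor(t))
print(state_elems)
```

The same argument applies per tensor-parallel shard in the 2-node runs below.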
With BitFit, 2 Nodes, PP=2, TP=2
```
[default3]: iteration 2/ 868457 | consumed samples: 384 | consumed tokens: 786432 | elapsed time per iteration (s): 12.86 | learning rate: 6.291E-07 | global batch size: 192 | lm loss: 1.244176E+01 | loss scale: 4096.0 | grad norm: 0.065 | num zeros: 0.0 | number of skipped iterations: 0 | number of nan iterations: 0 | samples per second: 14.925 | TFLOPs: 13.65 |
[default3]: iteration 3/ 868457 | consumed samples: 576 | consumed tokens: 1179648 | elapsed time per iteration (s): 12.62 | learning rate: 9.437E-07 | global batch size: 192 | lm loss: 1.244014E+01 | loss scale: 4096.0 | grad norm: 0.062 | num zeros: 0.0 | number of skipped iterations: 0 | number of nan iterations: 0 | samples per second: 15.209 | TFLOPs: 13.91 |
```
Without BitFit, 2 Nodes, PP=2, TP=2
```
[default3]: iteration 2/ 868457 | consumed samples: 384 | consumed tokens: 786432 | elapsed time per iteration (s): 12.62 | learning rate: 6.291E-07 | global batch size: 192 | lm loss: 1.244176E+01 | loss scale: 4096.0 | grad norm: 0.291 | num zeros: 0.0 | number of skipped iterations: 0 | number of nan iterations: 0 | samples per second: 15.214 | TFLOPs: 13.91 |
[default3]: iteration 3/ 868457 | consumed samples: 576 | consumed tokens: 1179648 | elapsed time per iteration (s): 12.63 | learning rate: 9.437E-07 | global batch size: 192 | lm loss: 1.244006E+01 | loss scale: 4096.0 | grad norm: 0.309 | num zeros: 0.0 | number of skipped iterations: 0 | number of nan iterations: 0 | samples per second: 15.201 | TFLOPs: 13.90 |
```