bigcode-project / Megatron-LM

Ongoing research training transformer models at scale

LM Head FLOPs #78

Open Muennighoff opened 11 months ago

Muennighoff commented 11 months ago

Why are we not multiplying the LM Head FLOPs per iteration by the checkpoint_activations_factor? https://github.com/bigcode-project/Megatron-LM/blob/bd0aaba3492b441d7f186bb1159fc21e1dcd7a72/megatron/utils.py#L253

AFAIK the factor of 4 means 1x forward, 2x backward, and 1x forward, where the last forward is the recomputation needed for checkpointed activations. Don't we also need all 4 for the LM Head? cc @RaymondLi0 @NouamaneTazi
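For context, a minimal sketch (not the repo's code) of how that multiplier breaks down:

```python
# Minimal sketch (not the repo's exact code) of how the multiplier breaks down.
def flops_multiplier(full_recompute: bool) -> int:
    forward = 1                             # 1x: the forward pass
    backward = 2                            # 2x: backward (grads w.r.t. inputs and weights)
    recompute = 1 if full_recompute else 0  # 1x: extra forward to rebuild checkpointed activations
    return forward + backward + recompute   # 4 with full recomputation, 3 otherwise
```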

NouamaneTazi commented 11 months ago

In selective recomputation, we only checkpoint attention, so it shouldn't affect the LM Head. In full recomputation, we assume that the last hidden_states are the LM Head activations, so there is no need to recompute them.

https://github.com/bigcode-project/Megatron-LM/blob/bd0aaba3492b441d7f186bb1159fc21e1dcd7a72/megatron/model/transformer.py#L1905
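For illustration, a simplified sketch of the FLOPs structure being discussed; the function and variable names are placeholders, not the identifiers in megatron/utils.py, and the per-term formulas are simplified (e.g. the attention-score term is omitted):

```python
# Simplified sketch of the per-iteration FLOPs structure under discussion;
# names and terms are illustrative, not the actual code in megatron/utils.py.
def estimate_flops_per_iteration(batch_size, seq_len, num_layers, hidden_size,
                                 vocab_size, full_recompute=True):
    # 3x = fwd (1) + bwd (2); full recomputation adds one more forward.
    checkpoint_activations_factor = 4 if full_recompute else 3

    # Dense matmuls of the transformer blocks (standard 24*B*s*L*h^2-style term).
    transformer_fwd = 24 * batch_size * seq_len * num_layers * hidden_size ** 2

    # LM Head: final logits projection, a (hidden_size x vocab_size) matmul over all tokens.
    lm_head_fwd = 2 * batch_size * seq_len * hidden_size * vocab_size

    return (checkpoint_activations_factor * transformer_fwd  # scaled by the factor
            + lm_head_fwd)                                    # LM Head counted only once here
```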

Muennighoff commented 11 months ago

> In selective recomputation, we only checkpoint attention, so it shouldn't affect the LM Head. In full recomputation, we assume that the last hidden_states are the LM Head activations, so there is no need to recompute them.
>
> https://github.com/bigcode-project/Megatron-LM/blob/bd0aaba3492b441d7f186bb1159fc21e1dcd7a72/megatron/model/transformer.py#L1905

I see, but then don't we at least need to multiply it by 3 to account for the backward pass?
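In other words, the suggested accounting would look roughly like this (a sketch with placeholder names, not the repo's code):

```python
# Sketch only (placeholder names, not the repo's code): under full recomputation the
# LM Head still runs one forward and one backward per iteration, but its input (the
# final hidden_states) is kept, so no extra recompute forward is charged to it.
def lm_head_flops_per_iteration(batch_size, seq_len, hidden_size, vocab_size):
    lm_head_fwd = 2 * batch_size * seq_len * hidden_size * vocab_size  # logits matmul
    lm_head_factor = 3  # 1x forward + 2x backward, no recompute forward
    return lm_head_factor * lm_head_fwd
```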