Open Muennighoff opened 1 year ago
In selective recomputation we only checkpoint attention, so it shouldn't affect the LM Head. In full recomputation we assume that the last hidden_states are the LM Head activations, so there is no need to recompute them.
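As a rough illustration of that argument (a minimal sketch with hypothetical names, not code from this repo): the vocab projection always pays one forward and one backward, and the only open question is whether an extra recompute forward should ever be added.

```python
def lm_head_flops(batch, seq, hidden, vocab, recompute_lm_head_fwd=False):
    """Approximate FLOPs of the vocab projection for one training iteration."""
    # Forward matmul [batch*seq, hidden] @ [hidden, vocab] ~ 2*b*s*h*V FLOPs.
    fwd = 2 * batch * seq * hidden * vocab
    # Backward needs grad w.r.t. input and grad w.r.t. weight, ~2x the forward.
    bwd = 2 * fwd
    # Per the comment above, neither selective nor full recomputation reruns
    # the LM Head forward, so this extra term should normally stay at zero.
    extra = fwd if recompute_lm_head_fwd else 0
    return fwd + bwd + extra
```

The disagreement in this thread is only about the `extra` term and about which overall multiplier the LM Head should get.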
I see, but then don't we at least need to multiply it by 3 to account for the backward pass?
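For reference, the arithmetic behind the ×3 suggestion (standard FLOP accounting for a linear layer, not taken from the repo), with $B$ the batch size, $s$ the sequence length, $h$ the hidden size and $V$ the vocab size:

$$
\text{LM Head FLOPs} \approx \underbrace{2BshV}_{\text{forward}} + \underbrace{4BshV}_{\text{backward: } \nabla X \text{ and } \nabla W} = 3 \times 2BshV .
$$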
Why are we not multiplying the LM Head FLOPs per iteration by the `checkpoint_activations_factor`? https://github.com/bigcode-project/Megatron-LM/blob/bd0aaba3492b441d7f186bb1159fc21e1dcd7a72/megatron/utils.py#L253

AFAIK the factor of 4 means 1 forward, 2 backward & 1 forward, where the last forward is needed to recompute the checkpointed activations. Don't we also need all 4 for the LM Head? cc @RaymondLi0 @NouamaneTazi
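For context, the structure being questioned is roughly the following sketch (hypothetical names and a simplified transformer term; the linked `utils.py` line is the authoritative version):

```python
def approx_flops_per_iteration(batch, seq, layers, hidden, vocab,
                               checkpoint_activations=True):
    # 1 fwd + 2 bwd, plus 1 extra fwd when full activation recomputation is on.
    checkpoint_activations_factor = 4 if checkpoint_activations else 3
    # Forward FLOPs of the transformer blocks (attention + MLP matmuls).
    transformer_fwd = 24 * batch * seq * layers * hidden**2 * (1 + seq / (6 * hidden))
    # Forward FLOPs of the LM Head (vocab projection).
    lm_head_fwd = 2 * batch * seq * hidden * vocab
    # The question is which multiplier the LM Head term should get: the full
    # checkpoint_activations_factor, just 3 (fwd + bwd), or something else.
    lm_head_factor = 3  # assumption in this sketch: fwd + bwd, no recompute
    return (checkpoint_activations_factor * transformer_fwd
            + lm_head_factor * lm_head_fwd)
```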