google-research / tuning_playbook

A playbook for systematically maximizing the performance of deep learning models.

Throughput of gradient accumulation #38

Closed · guofei1989 closed this issue 1 year ago

guofei1989 commented 1 year ago

As stated in https://github.com/google-research/tuning_playbook#determining-the-feasible-batch-sizes-and-estimating-training-throughput: "Gradient accumulation simulates a larger batch size than the hardware can support and therefore does not provide any throughput benefits. It should generally be avoided in applied work."

However:

  1. Gradient accumulation can reduce the number of backward ops, which would benefit throughput.
  2. Gradient accumulation is a practical method for dealing with deficient accelerators, so why should it be avoided in applied work?
georgedahl commented 1 year ago

As I understand what you wrote in (1), I do not believe it is correct. To use a batch size of 2*B with gradient accumulation at a micro-batch size of B, the program will do a forward pass and a backward pass on B examples, and then another forward and backward pass on another B examples.
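
For concreteness, here is a minimal sketch of what that looks like in code (PyTorch-style; `model`, `loss_fn`, `optimizer`, and the micro-batch data are hypothetical placeholders, not code from the playbook). The point to notice is that each micro-batch of B examples still gets its own forward and backward pass; accumulation only changes how often the optimizer update runs:

```python
# Minimal gradient-accumulation sketch (PyTorch-style); all names here are
# illustrative placeholders, not part of the playbook or this repository.
import torch

ACCUM_STEPS = 2  # simulate an effective batch size of 2*B from micro-batches of B

def train_step(micro_batches, model, loss_fn, optimizer):
    """One optimizer update accumulated over ACCUM_STEPS micro-batches of size B."""
    optimizer.zero_grad()
    for inputs, targets in micro_batches:   # len(micro_batches) == ACCUM_STEPS
        outputs = model(inputs)             # forward pass on B examples
        loss = loss_fn(outputs, targets) / ACCUM_STEPS
        loss.backward()                     # backward pass on B examples; gradients accumulate
    optimizer.step()                        # one parameter update for the simulated 2*B batch
```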

Perhaps you mean that using a larger batch size through gradient accumulation can reduce the total number of training steps required. This is true, but it does NOT imply that gradient accumulation provides any speedup. Consider doubling the batch size. The maximum benefit of doubling the batch size is that the number of training steps needed gets cut in half. However, if we use gradient accumulation to simulate the batch size of 2*B with two gradient computations on batches of size B, this will roughly double the time taken per training step, resulting in a total training time that is essentially unchanged.
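
To make that arithmetic explicit, let S be the number of training steps needed at batch size B and t the time of one forward-and-backward pass on B examples (both symbols are just for this illustration). In the best case, the simulated batch size of 2*B halves the number of steps but doubles the cost per step, so the totals match:

$$\underbrace{S \cdot t}_{\text{batch size } B} \;\approx\; \underbrace{\tfrac{S}{2} \cdot 2t}_{\text{batch size } 2B \text{ via accumulation}}$$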

Regarding (2), I don't know what deficiency you are referring to. There are cases where we need gradient accumulation to simulate a larger batch size in order to reproduce a specific result, but in applied work we wouldn't keep using it; for the majority of experiments that aren't for debugging, we would just use a smaller batch size.