Open zhuohan123 opened 2 years ago
> Note that this is an under-estimate of how many batches we can fit:
Do you mean over-estimation? Because we don't count the `intermediate_variables`, the value we computed is higher than the actual value.
Another question: can we just use `max_n_micro_batches = (available_memory - peak_memory - intermediate_size - initial_size) / intermediate_size`? This should be a safe under-estimation, but I don't know whether it is too inaccurate.
No, we count about one extra copy of `intermediate_variables`, so the memory we estimate is more than the actual memory used. So it's an under-estimation of how many batches we can fit.
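To make the counting concrete, here is a toy calculation (all sizes are invented for illustration) comparing the issue's bound, the more conservative bound proposed in the comment above, and the true capacity when `peak_memory` already contains one copy of the intermediates:

```python
# Toy sizes in GB; all numbers are invented for illustration.
available_memory = 16
peak_memory = 6         # profiled peak; already includes ~one copy of the intermediates
intermediate_size = 2   # one micro-batch's intermediate_variables
initial_size = 4        # inputs to update_layers (optimizer states)

# Bound currently used in the issue:
current = (available_memory - peak_memory - initial_size) // intermediate_size

# More conservative bound proposed in the comment above:
proposed = (available_memory - peak_memory - intermediate_size
            - initial_size) // intermediate_size

# True capacity: since peak_memory double-counts one set of intermediates,
# n micro-batches really need
# (peak_memory - intermediate_size) + n * intermediate_size + initial_size.
n = 0
while ((peak_memory - intermediate_size)
       + (n + 1) * intermediate_size + initial_size) <= available_memory:
    n += 1

print(current, proposed, n)  # prints: 3 2 4
```

With these numbers the current formula is one micro-batch short of the true capacity, and the proposed formula is two short, which is why both are safe under-estimations.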
When constructing pipeline stages with Alpa's auto inter-operator parallel algorithm, we need to accurately profile the memory usage of a constructed pipeline stage to determine whether a solution is valid or not. A pipeline stage can be decomposed into the following 3 parts: `forward_layers`, `backward_layers`, and `update_layers` (`apply_grad`). Given `forward_layers`, `backward_layers`, and `update_layers`, the current profiling workflow is:

1. Merge `forward_layers` and `backward_layers` into a `compute_part`. There is a special hook between the forward and backward layers that marks the variables between them as `intermediate_variables`. These variables need to be stored during the execution of `forward_layers` for the execution of `backward_layers` and can only be deleted after `backward_layers` finishes. During pipeline execution, we need to store multiple sets of `intermediate_variables` because there are multiple micro-batches on-the-fly.
2. Compile the `compute_part` and the `update_layers`.
3. Execute the `compute_part` and `update_layers`.
4. Profile the `compute_part`.

Currently, we measure the following memory:
- `peak_memory`: The peak memory reached in the `compute_part`.
- `intermediate_size`: The size of all `intermediate_variables`.
- `initial_size`: The size of the input tensors to `update_layers`; typically the optimizer states.
- `available_memory`: The size of GPU memory.

And we calculate the maximal number of micro-batches we can store on-the-fly with `max_n_micro_batches = (available_memory - peak_memory - initial_size) / intermediate_size`. (And we set `max_n_succ_stages = max_n_micro_batches` per the 1F1B pipeline schedule.)

Note that this is an under-estimate of how many batches we can fit: actually, `peak_memory` already contains one copy of `intermediate_variables`. We don't count this copy because, when profiling for `peak_memory`, the memory reserved for `intermediate_variables` is freed during the backward pass as the variables become inactive.

To fix this issue, there are two solutions:
1. Keep the `intermediate_variables` alive in the `compute_part`, which can force these variables not to be freed.
2. Instead of profiling the `compute_part` with shard parallel, profile the `forward_part` and `backward_part` separately with the pipeshard runtime.

cc @ZYHowell
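For reference, the current bound can be sketched as a small helper; the function name and the handling of the negative case are illustrative assumptions, not Alpa's actual code:

```python
def max_n_micro_batches(available_memory: float,
                        peak_memory: float,
                        intermediate_size: float,
                        initial_size: float) -> int:
    """Micro-batches whose intermediate_variables fit next to the stage.

    Implements the issue's formula
        (available_memory - peak_memory - initial_size) / intermediate_size
    rounded down; a hypothetical helper, not Alpa's actual API.
    """
    free = available_memory - peak_memory - initial_size
    if free < 0:
        return 0  # the stage itself does not fit on the device
    return int(free // intermediate_size)

# Per the 1F1B schedule, the allowed number of successor stages is the same bound.
n = max_n_micro_batches(available_memory=16e9, peak_memory=6e9,
                        intermediate_size=2e9, initial_size=4e9)
max_n_succ_stages = n  # → 3
```

Per the discussion above, this bound is conservative by about one micro-batch, because `peak_memory` already includes one copy of the intermediates.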