alpa-projects / alpa

Training and serving large-scale neural networks with auto parallelization.
https://alpa.ai
Apache License 2.0

[BUG] Inaccurate memory profiling during pipeline stage construction #684

Open zhuohan123 opened 2 years ago

zhuohan123 commented 2 years ago

When constructing pipeline stages with Alpa's auto inter-operator parallel algorithm, we need to accurately profile the memory usage of a constructed pipeline stage to determine whether a solution is valid or not. A pipeline stage can be decomposed into the following 3 parts:

  1. Forward
  2. Backward
  3. Update (apply_grad)

Given forward_layers, backward_layers, and update_layers, the current profiling workflow is:

  1. Merge forward_layers and backward_layers into a compute_part. A special hook between the forward and backward layers marks the variables passed between them as intermediate_variables. These variables are produced while forward_layers executes, are needed by backward_layers, and can only be freed after backward_layers finishes. During pipeline execution, we need to store multiple sets of intermediate_variables because multiple micro-batches are on-the-fly.
  2. Merge the compute_part and the update_layers.
  3. Run auto-sharding pass to shard the stage and get the sharding spec of all tensors.
  4. Decouple the sharded compute_part and update_layers.
  5. Compile the compute_part and profile its compute cost and memory usage.
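As a rough illustration of step 1, the sketch below identifies intermediate_variables as the variables written by forward_layers and read by backward_layers. The Layer class and find_intermediate_variables are hypothetical stand-ins, not Alpa's API. In this toy model, stage inputs that the backward pass also reads (such as x below) are not treated as intermediates.

```python
from dataclasses import dataclass


@dataclass
class Layer:
    """Hypothetical toy layer: names of the variables it reads and writes."""
    name: str
    invars: list
    outvars: list


def find_intermediate_variables(forward_layers, backward_layers):
    """Return the variables written by forward_layers and read by backward_layers.

    These must survive from the forward pass until backward_layers finishes,
    so every micro-batch on-the-fly holds its own copy of them.
    """
    produced = {v for layer in forward_layers for v in layer.outvars}
    consumed = {v for layer in backward_layers for v in layer.invars}
    return sorted(produced & consumed)


# Toy two-layer stage; all names are made up for illustration.
forward_layers = [
    Layer("fwd0", ["x", "w0"], ["h0"]),
    Layer("fwd1", ["h0", "w1"], ["loss"]),
]
backward_layers = [
    Layer("bwd1", ["loss", "h0", "w1"], ["g_h0", "g_w1"]),
    Layer("bwd0", ["g_h0", "x", "w0"], ["g_w0"]),
]
print(find_intermediate_variables(forward_layers, backward_layers))  # ['h0', 'loss']
```

Note that g_h0 is produced and consumed entirely inside the backward pass, so it is not an intermediate variable in this sense.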

Currently, we measure the following memory:

  1. peak_memory: The peak memory achieved in the compute_part.
  2. intermediate_size: The size of all intermediate_variables.
  3. initial_size: The size of the input tensors to update_layers, typically the optimizer states.
  4. available_memory: The size of GPU memory.

We then calculate the maximum number of micro-batches we can store on-the-fly as max_n_micro_batches = (available_memory - peak_memory - initial_size) / intermediate_size, and set max_n_succ_stages = max_n_micro_batches, following the 1F1B pipeline schedule. Note that this is an under-estimate of how many batches we can fit: peak_memory actually already contains one copy of intermediate_variables. We don't count this copy because, when profiling peak_memory, the memory reserved for intermediate_variables is freed during the backward pass as the variables become inactive.
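The estimate above can be written as a small helper. This is a sketch, not Alpa's code: the function names are hypothetical, all sizes are assumed to be in one common unit, and the division is floored (an assumption, since only whole micro-batches can be on-the-fly) and clamped at zero.

```python
def max_n_micro_batches(available_memory: int, peak_memory: int,
                        initial_size: int, intermediate_size: int) -> int:
    """Current estimate from the issue:
    (available_memory - peak_memory - initial_size) / intermediate_size,
    floored to a whole number of micro-batches and clamped at 0 (assumptions).
    """
    if intermediate_size <= 0:
        raise ValueError("intermediate_size must be positive")
    free_memory = available_memory - peak_memory - initial_size
    return max(free_memory // intermediate_size, 0)


def max_n_succ_stages(available_memory: int, peak_memory: int,
                      initial_size: int, intermediate_size: int) -> int:
    """Under the 1F1B schedule, max_n_succ_stages is set equal to max_n_micro_batches."""
    return max_n_micro_batches(available_memory, peak_memory,
                               initial_size, intermediate_size)


# Made-up sizes in GB: 16 GB GPU, 6 GB peak, 2 GB optimizer states, 3 GB per micro-batch.
print(max_n_micro_batches(16, 6, 2, 3))  # 2
```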

To fix this issue, there are two solutions:

  1. Hot fix (fast but dirty): add intermediate_variables to the outputs of compute_part, which forces these variables to stay alive instead of being freed during the backward pass.
  2. Clean fix (clean but slow): Instead of profiling a single compute_part with shard parallel, profile forward_part and backward_part separately with the pipeshard runtime.
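The effect of the hot fix on the measured peak can be illustrated with a toy liveness model. This is a hypothetical sketch, not Alpa's profiler: it tracks only the saved activations (the intermediate_variables) and the weight gradients kept for apply_grad, and ignores parameters and temporary buffers. When activations are freed as soon as the backward pass no longer needs them, the peak may contain only part of the intermediate copy; when they are kept alive as outputs, the peak contains exactly one full copy.

```python
def simulate_peak(act_sizes, grad_sizes, keep_intermediates):
    """Peak live memory of one toy compute_part (forward, then backward).

    act_sizes[i]: activation saved by forward layer i (an intermediate variable).
    grad_sizes[i]: gradient written by backward layer i, kept until apply_grad.
    keep_intermediates: if True, activations are outputs of compute_part
    (the hot fix) and are never freed inside it.
    """
    live = peak = 0
    for a in act_sizes:  # forward: every activation is saved for backward
        live += a
        peak = max(peak, live)
    # backward: layers run in reverse order
    for a, g in zip(reversed(act_sizes), reversed(grad_sizes)):
        live += g  # this layer writes its gradient
        peak = max(peak, live)
        if not keep_intermediates:
            live -= a  # the activation is now dead and gets freed
    return peak


acts, grads = [2, 2, 2], [4, 4, 4]  # made-up sizes; intermediate_size = 6
print(simulate_peak(acts, grads, keep_intermediates=False))  # 14: only 2 of the 6 units of activations
print(simulate_peak(acts, grads, keep_intermediates=True))   # 18: all 6 units of activations + 12 of gradients
```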

cc @ZYHowell

merrymercy commented 2 years ago

"Note that this is an under-estimate of how many batches we can fit:"

Do you mean over-estimation? Because we don't count intermediate_variables, the value we computed is higher than the actual value.

Another question: can we just use max_n_micro_batches = (available_memory - peak_memory - intermediate_size - initial_size) / intermediate_size? This should be a safe underestimation, but I don't know whether it is too inaccurate.
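One way to size the cost of this proposal: with floored division, subtracting one extra intermediate_size from the numerator lowers the result by exactly one micro-batch, because floor((x - s) / s) = floor(x / s) - 1 for any positive s. The helper names below are hypothetical and the sizes are made up.

```python
def current_estimate(available_memory, peak_memory, initial_size, intermediate_size):
    # Formula from the issue description (floored).
    return (available_memory - peak_memory - initial_size) // intermediate_size


def conservative_estimate(available_memory, peak_memory, initial_size, intermediate_size):
    # Proposed formula: also reserve one full copy of intermediate_variables.
    return (available_memory - peak_memory - intermediate_size - initial_size) // intermediate_size


# Made-up sizes in GB.
args = dict(available_memory=16, peak_memory=6, initial_size=2, intermediate_size=3)
print(current_estimate(**args), conservative_estimate(**args))  # 2 1
```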

zhuohan123 commented 2 years ago

No, we count about one extra copy of intermediate_variables, so the memory we estimate is more than the memory actually used. So it's an under-estimation of how many batches we can fit.