google / paxml

Pax is a JAX-based machine learning framework for training large-scale models. Pax allows for advanced and fully configurable experimentation and parallelization, and has demonstrated industry-leading model FLOP utilization rates.
Apache License 2.0

Unexpected Overheads with Activation Checkpointing with Pipeline Parallelism #17

Open abhinavgoel95 opened 1 year ago

abhinavgoel95 commented 1 year ago

We are seeing buggy behavior involving bitcasts and dynamic-update-slices (DUS). When we turn on activation checkpointing (e.g., saving the outputs of projection layers with the SAVE_OUT_PROJ flag in PAXML), we see multiple extra updates and copies.

For example, we want to checkpoint an activation of shape [2,2048,48,128]. However, in the HLO below, the copies have shape [15,1,2,2048,48,128], where 15 is the number of microbatches we use with pipeline parallelism.
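To make the buffer structure concrete, here is a minimal NumPy sketch, assuming the checkpointed activations are stacked along a leading microbatch axis as the HLO shapes suggest. `dynamic_update_slice` follows the documented XLA semantics (start indices are clamped so the update fits inside the operand), and the negative-index wrap mirrors the compare/add/select sequence in the HLO below. `save_microbatch` is a hypothetical helper for illustration, not a Pax or Praxis API, and the shapes are shrunk so the example runs instantly.

```python
import numpy as np


def dynamic_update_slice(operand, update, start_indices):
    """Functional sketch of XLA DynamicUpdateSlice semantics."""
    # XLA clamps each start index so the update fits inside the operand.
    starts = [
        min(max(s, 0), o - u)
        for s, o, u in zip(start_indices, operand.shape, update.shape)
    ]
    # Returning a fresh array copies the whole operand; an in-place DUS
    # (what one would want inside the pipeline loop) avoids this copy.
    out = operand.copy()
    out[tuple(slice(s, s + u) for s, u in zip(starts, update.shape))] = update
    return out


def save_microbatch(buffer, activation, i):
    """Hypothetical helper: write one microbatch's activation into slot i."""
    num_microbatches = buffer.shape[0]
    # Mirrors compare/add/select in the HLO: a negative index wraps around.
    idx = i + num_microbatches if i < 0 else i
    update = activation[np.newaxis, np.newaxis]  # shape [1, 1, *activation.shape]
    starts = (idx,) + (0,) * (buffer.ndim - 1)
    return dynamic_update_slice(buffer, update, starts)


# Small stand-in for the [15, 1, 2, 2048, 48, 128] buffer in the issue.
num_microbatches, act_shape = 4, (2, 3, 4)
buf = np.zeros((num_microbatches, 1) + act_shape, dtype=np.float32)
for i in range(num_microbatches):
    buf = save_microbatch(buf, np.full(act_shape, i + 1, dtype=np.float32), i)
```

Each call touches only one [1, 1, ...] slot logically, but the functional version still pays for a full-buffer copy, which is the same trade-off the issue describes.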

Snippet of HLO:

fusion.549 = (bf16[15,1,2,2048,48,128]{3,5,4,2,1,0}, bf16[15,1,2,2048,48,128]{3,5,4,2,1,0}, ..., kind=kLoop, calls=fused_computation.549, metadata={op_name="pjit(_wrapped_step_fn)/jit(main)/jvp(xformer_lm.apply)/xformer_lm/xformer_lm.compute_predictions/lm/transformer/pipeline/while/body/dynamic_update_slice" source_file="/usr/local/lib/python3.8/dist-packages/flax/core/axes_scan.py" source_line=148}
get-tuple-element.5874 = bf16[15,1,2,2048,48,128]{3,5,4,2,1,0} get-tuple-element(fusion.549), index=0
copy.583 = bf16[15,1,2,2048,48,128]{3,5,4,2,1,0} copy(get-tuple-element.5874)
get-tuple-element.5866 = bf16[15,1,2,2048,48,128]{3,5,4,2,1,0} get-tuple-element(fusion.549), index=1
copy.575 = bf16[15,1,2,2048,48,128]{3,5,4,2,1,0} copy(get-tuple-element.5866)
get-tuple-element.5868 = bf16[15,1,2,2048,48,128]{3,5,4,2,1,0} get-tuple-element(fusion.549), index=2
copy.577 = bf16[15,1,2,2048,48,128]{3,5,4,2,1,0} copy(get-tuple-element.5868)
get-tuple-element.5870 = bf16[15,1,2,2048,48,128]{3,5,4,2,1,0} get-tuple-element(fusion.549), index=3
copy.579 = bf16[15,1,2,2048,48,128]{3,5,4,2,1,0} copy(get-tuple-element.5870)
get-tuple-element.5872 = bf16[15,1,2,2048,48,128]{3,5,4,2,1,0} get-tuple-element(fusion.549), index=4
copy.581 = bf16[15,1,2,2048,48,128]{3,5,4,2,1,0} copy(get-tuple-element.5872)

...

fused_computation.549 {
  param_1.8511 = bf16[15,1,2,2048,48,128]{3,5,4,2,1,0} parameter(1)
  bitcast.52601 = bf16[15,1,2,48,128,2048]{5,4,3,2,1,0} bitcast(param_1.8511)
  param_0.6313 = bf16[2,48,128,2048]{3,2,1,0} parameter(0)
  bitcast.52600 = bf16[1,1,2,48,128,2048]{5,4,3,2,1,0} bitcast(param_0.6313)
  param_2.5901 = s32[] parameter(2)
  constant_7564 = s32[] constant(0)
  compare.3477 = pred[] compare(param_2.5901, constant_7564), direction=LT, metadata={op_name="pjit(_wrapped_step_fn)/jit(main)/jvp(xformer_lm.apply)/xformer_lm/xformer_lm.compute_predictions/lm/transformer/pipeline/while/body/pipeline._scan_fn/pipeline._get_iteration_inputs/jit(remainder)/rem" source_file="/pax/praxis/praxis/layers/pipeline.py" source_line=422}
  constant_11524 = s32[] constant(15)
  add.6580 = s32[] add(param_2.5901, constant_11524), metadata={op_name="pjit(_wrapped_step_fn)/jit(main)/jvp(xformer_lm.apply)/xformer_lm/xformer_lm.compute_predictions/lm/transformer/pipeline/while/body/add" source_file="/pax/praxis/praxis/base_layer.py" source_line=695}
  select.5360 = s32[] select(compare.3477, add.6580, param_2.5901), metadata={op_name="pjit(_wrapped_step_fn)/jit(main)/jvp(xformer_lm.apply)/xformer_lm/xformer_lm.compute_predictions/lm/transformer/pipeline/while/body/select_n" source_file="/usr/local/lib/python3.8/dist-packages/flax/core/axes_scan.py" source_line=148}
  dynamic-update-slice.325 = bf16[15,1,2,48,128,2048]{5,4,3,2,1,0} dynamic-update-slice(bitcast.52601, bitcast.52600, select.5360, constant_7564, constant_7564, /*index=5*/constant_7564, constant_7564, constant_7564), metadata={op_name="pjit(_wrapped_step_fn)/jit(main)/jvp(xformer_lm.apply)/xformer_lm/xformer_lm.compute_predictions/lm/transformer/pipeline/while/body/dynamic_update_slice" source_file="/usr/local/lib/python3.8/dist-packages/flax/core/axes_scan.py" source_line=148}
  bitcast.52599 = bf16[15,1,2,2048,48,128]{3,5,4,2,1,0} bitcast(dynamic-update-slice.325), metadata={op_name="pjit(_wrapped_step_fn)/jit(main)/jvp(xformer_lm.apply)/xformer_lm/xformer_lm.compute_predictions/lm/transformer/pipeline/while/body/dynamic_update_slice" source_file="/usr/local/lib/python3.8/dist-packages/flax/core/axes_scan.py" source_line=148}
  param_4.7770 = bf16[15,1,2,2048,48,128]{3,5,4,2,1,0} parameter(4)
  bitcast.52617.clone.1 = bf16[15,1,2,48,128,2048]{5,4,3,2,1,0} bitcast(param_4.7770)
  param_3.8428 = bf16[2,48,128,2048]{3,2,1,0} parameter(3)
  bitcast.52616.clone.1 = bf16[1,1,2,48,128,2048]{5,4,3,2,1,0} bitcast(param_3.8428)
  dynamic-update-slice.333.clone.1 = bf16[15,1,2,48,128,2048]{5,4,3,2,1,0} dynamic-update-slice(bitcast.52617.clone.1, bitcast.52616.clone.1, select.5360, constant_7564, constant_7564, /*index=5*/constant_7564, constant_7564, constant_7564), metadata={op_name="pjit(_wrapped_step_fn)/jit(main)/jvp(xformer_lm.apply)/xformer_lm/xformer_lm.compute_predictions/lm/transformer/pipeline/while/body/dynamic_update_slice" source_file="/usr/local/lib/python3.8/dist-packages/flax/core/axes_scan.py" source_line=148}
  ...
  ROOT tuple.356 = (bf16[15,1,2,2048,48,128]{3,5,4,2,1,0}, bf16[15,1,2,2048,48,128]{3,5,4,2,1,0}, bf16[15,1,2,2048,48,128]{3,5,4,2,1,0}, bf16[15,1,2,2048,48,128]{3,5,4,2,1,0}, bf16[15,1,2,2048,48,128]{3,5,4,2,1,0}) tuple(bitcast.52599, bitcast.52615.clone.1, bitcast.52611.clone.1, bitcast.52607.clone.1, bitcast.52603.clone.1)
}
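The bitcasts in fused_computation.549 appear to be pure reinterpretations of memory: XLA writes layouts minor-to-major, so {3,5,4,2,1,0} makes the 2048 dimension innermost, which is exactly the row-major order of [15,1,2,48,128,2048]. The sketch below checks this with a simplified rule written for this thread (same memory order once size-1 dimensions are dropped); it is an assumption-laden illustration, not XLA's own bitcast logic.

```python
def physical_dims(dims, minor_to_major):
    """Dimension sizes in memory order, outermost (most major) first."""
    # XLA layouts are written minor-to-major, so reverse them.
    return tuple(dims[d] for d in reversed(minor_to_major))


def is_simple_bitcast(src_dims, src_layout, dst_dims, dst_layout):
    """True if both shapes have the same memory order, ignoring size-1 dims.

    Sufficient but not necessary: reshapes that split or merge dimensions
    can also be bitcasts, which this simplified check does not detect.
    """
    def squeeze(ds):
        return tuple(d for d in ds if d != 1)

    return squeeze(physical_dims(src_dims, src_layout)) == squeeze(
        physical_dims(dst_dims, dst_layout)
    )


# bitcast.52601: [15,1,2,2048,48,128]{3,5,4,2,1,0} -> [15,1,2,48,128,2048]{5,4,3,2,1,0}
print(physical_dims((15, 1, 2, 2048, 48, 128), (3, 5, 4, 2, 1, 0)))  # (15, 1, 2, 48, 128, 2048)
print(is_simple_bitcast((15, 1, 2, 2048, 48, 128), (3, 5, 4, 2, 1, 0),
                        (15, 1, 2, 48, 128, 2048), (5, 4, 3, 2, 1, 0)))  # True
# bitcast.52600: [2,48,128,2048]{3,2,1,0} -> [1,1,2,48,128,2048]{5,4,3,2,1,0}
print(is_simple_bitcast((2, 48, 128, 2048), (3, 2, 1, 0),
                        (1, 1, 2, 48, 128, 2048), (5, 4, 3, 2, 1, 0)))  # True
```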

It seems that a single large buffer of shape [15,1,2,2048,48,128] holds the activations for all microbatches. In each microbatch iteration, we only need to update one row of this buffer (of shape [2,2048,48,128]). Instead, XLA reads the entire buffer, performs the update, and then copies the whole buffer back. We see this problem in our profiles: the time spent on device-to-device (D2D) copies (copy.575 through copy.583) is much larger than expected for the amount of data that should be copied. Currently, activation checkpointing accounts for 5% to 8% of the overall run time for a GPT-3-style model.
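A back-of-the-envelope estimate shows why the copies dominate. This assumes the five copies (copy.575 through copy.583) run once per iteration of the pipeline while loop, that each reads and writes its full bf16 buffer once, and it ignores the traffic of the fusion itself.

```python
from math import prod

BF16_BYTES = 2
MIB = 2**20

slice_shape = (2, 2048, 48, 128)          # one microbatch's checkpointed activation
buffer_shape = (15, 1, 2, 2048, 48, 128)  # stacked buffer for all 15 microbatches
num_buffers = 5                           # copy.575, copy.577, ..., copy.583

slice_bytes = prod(slice_shape) * BF16_BYTES
buffer_bytes = prod(buffer_shape) * BF16_BYTES

# Ideal in-place update: write one slice into each buffer.
ideal_bytes = num_buffers * slice_bytes
# Observed pattern (assumed): each copy reads and writes an entire buffer.
copy_bytes = num_buffers * 2 * buffer_bytes

print(slice_bytes // MIB, buffer_bytes // MIB)  # 48 720
print(ideal_bytes // MIB, copy_bytes // MIB)    # 240 7200
print(copy_bytes // ideal_bytes)                # 30
```

Under these assumptions the ratio is 2 × (number of microbatches), so the overhead grows with the microbatch count rather than staying proportional to the size of one checkpointed activation.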

Our current understanding: the copies appear because bitcast is treated as computing a new value (like a convert or sqrt), so a new tensor must be used in each loop iteration and a copy of each DUS result must be made. This should be fixable by having the dataflow analysis treat bitcast as an aliasing operation instead of as computing a new value. I think there is an option in the dataflow analysis that configures how bitcasts are treated. In XLA TPU, this option is set to true, so bitcasts are treated simply as aliasing operations.
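The hypothesis above can be illustrated with a toy buffer-sharing model. This is a hypothetical sketch, not XLA's dataflow analysis or copy insertion: the loop body is reduced to one chain mirroring fused_computation.549 (parameter, bitcast, dynamic-update-slice, bitcast), the DUS is assumed to reuse its operand's buffer, and the loop is assumed to need a copy whenever the body's result does not end up in the parameter's buffer.

```python
# Hypothetical toy IR: op name -> (opcode, buffer-carrying operand or None).
BODY = {
    "param": ("parameter", None),
    "bitcast_in": ("bitcast", "param"),
    "dus": ("dynamic-update-slice", "bitcast_in"),
    "bitcast_out": ("bitcast", "dus"),
}


def buffer_of(ops, name, bitcast_is_alias):
    """Name of the op whose buffer holds the value produced by `name`."""
    opcode, operand = ops[name]
    if opcode == "dynamic-update-slice":
        # Assumption: the DUS is done in place in operand 0's buffer.
        return buffer_of(ops, operand, bitcast_is_alias)
    if opcode == "bitcast" and bitcast_is_alias:
        # Aliasing bitcast: same bytes, same buffer.
        return buffer_of(ops, operand, bitcast_is_alias)
    # Parameters, and bitcasts treated as new values, own a fresh buffer.
    return name


def loop_needs_copy(ops, root, param, bitcast_is_alias):
    """Toy rule: the body must return its result in the parameter's buffer."""
    return buffer_of(ops, root, bitcast_is_alias) != buffer_of(ops, param, bitcast_is_alias)


print(loop_needs_copy(BODY, "bitcast_out", "param", bitcast_is_alias=False))  # True
print(loop_needs_copy(BODY, "bitcast_out", "param", bitcast_is_alias=True))   # False
```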

Would someone be able to look into this?

I am attaching a link to the HLO: https://drive.google.com/drive/folders/1fYUsqfDgYRRpgOklE-k7qx_5ixkJzKPD?usp=sharing

akuegel commented 1 year ago

The option is set the same way for XLA GPU (both TPU and GPU use the default value for this flag), so it cannot be fixed that easily. There was also some doubt about whether it can be fixed at all, other than by avoiding the reshape bitcast (which might be added on the model side).

Quoting from the chat: "A reshape of a DUS is no longer the same tensor. A copy might be needed. Not always though."
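As a loose analogy in NumPy (not XLA), the same reshape can either alias its input or force a copy, depending on how the data are laid out in memory:

```python
import numpy as np

a = np.arange(24, dtype=np.float32).reshape(2, 3, 4)

# Reshape that only regroups contiguous dimensions: NumPy returns a view.
flat = a.reshape(6, 4)
print(np.shares_memory(a, flat))  # True

# Same kind of reshape after a transpose: the elements are no longer stored
# in the requested order, so NumPy must materialize a copy.
t = a.transpose(0, 2, 1)  # shape (2, 4, 3), a non-contiguous view of a
merged = t.reshape(8, 3)
print(np.shares_memory(t, merged))  # False
```

In the HLO above the reshape happens to be layout-preserving (the first case), which is why avoiding or preserving that property on the model side is the suggested workaround.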