
[RFC] Add LayerSkip to AO #633

Open jcaip opened 1 month ago

jcaip commented 1 month ago

Tracker issue for adding LayerSkip to AO.

This is a training and inference optimization that is similar to layer-wise pruning. It's particularly interesting for LLM inference because it combines very cleanly with speculative decoding to provide up to a 1.86x speedup.

@mostafaelhoushi is interested in adding this to torchtune and in upstreaming a subset of the code to ao; see here for more details. In particular, he'd like to do this without having to alter the module definition.

This is attractive because this part of LayerSkip is not unique to LLMs and can be used for other models. (@mostafaelhoushi to fill out with relevant results).

What is being proposed:

For LayerSkip, there is a training recipe and an inference recipe:

gau-nernst commented 1 month ago

Layer dropout during training looks like some form of Stochastic Depth. Some related implementations:

A glance at the LayerSkip paper suggests that they mask each sample in a batch independently. You probably need some tricks to see speedups? The torchtune PR implements it by indexing a subset of the batch, applying the function, and writing the result back. Curious to see whether the extra overhead is outweighed by the reduced computation during training.
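
For concreteness, here is a minimal sketch of the index / apply / write-back pattern described above, assuming a plain PyTorch layer that takes a single batched input; the function name and signature are placeholders for illustration, not the torchtune implementation.

```python
import torch
import torch.nn as nn

def per_sample_layer_dropout(layer: nn.Module, x: torch.Tensor, drop_prob: float,
                             training: bool = True) -> torch.Tensor:
    """Run `layer` only on the samples that are kept; skipped samples pass
    through unchanged (identity). Assumes dim 0 of `x` is the batch dim."""
    if not training or drop_prob == 0.0:
        return layer(x)
    # Per-sample keep mask: True -> run the layer, False -> skip it.
    keep = torch.rand(x.shape[0], device=x.device) >= drop_prob
    out = x.clone()
    if keep.any():
        # Index the kept subset, apply the layer, and write the result back.
        out[keep] = layer(x[keep])
    return out
```

Whether this beats running the layer on the full batch and masking the output depends on the drop rate versus the overhead of the gather/scatter and extra kernel launches, which is exactly the trade-off raised above.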

jcaip commented 1 month ago

Yup, the layer dropout aspect of LayerSkip is basically a version of stochastic depth. That's part of the reason I'm interested in having it in AO: a generic stochastic depth function / module would be useful outside of just LLMs (a rough sketch of what that could look like is at the end of this comment).

IIRC from talking to Mostafa, masking + rewriting is faster, but the speedups mostly come from the self-speculative decoding part of the technique.

@mostafaelhoushi can you share some benchmarks about the layer dropout implementation specifically when you update the issue? Thanks.
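
To make the "generic stochastic depth module" idea concrete, here is a rough sketch of what a wrapper-based API could look like; `LayerDropoutWrapper` and `apply_layer_dropout_` are hypothetical names, not an existing ao API, and the wrapper deliberately does not handle extra per-sample arguments such as attention masks.

```python
import torch
import torch.nn as nn

class LayerDropoutWrapper(nn.Module):
    """Skips the wrapped module per-sample with probability `p` during
    training, leaving the wrapped module's definition untouched."""

    def __init__(self, module: nn.Module, p: float):
        super().__init__()
        assert 0.0 <= p < 1.0
        self.module = module
        self.p = p

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if not self.training or self.p == 0.0:
            return self.module(x)
        # Per-sample keep mask over the batch dimension.
        keep = torch.rand(x.shape[0], device=x.device) >= self.p
        out = x.clone()
        if keep.any():
            out[keep] = self.module(x[keep])
        return out

def apply_layer_dropout_(model: nn.Module, p: float, filter_fn) -> None:
    """Swap submodules selected by `filter_fn` for LayerDropoutWrapper in
    place, e.g. filter_fn=lambda m, name: isinstance(m, nn.TransformerEncoderLayer)."""
    for name, child in model.named_children():
        if filter_fn(child, name):
            setattr(model, name, LayerDropoutWrapper(child, p))
        else:
            apply_layer_dropout_(child, p, filter_fn)
```

The in-place module-swap style with a filter function roughly mirrors how other ao transforms are applied, and it avoids touching the model's module definitions, which matches the constraint mentioned in the issue description.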

mostafaelhoushi commented 1 month ago

Sorry for the delay from my side.

Other Papers

I would like to mention other papers or models that used layer dropout (aka stochastic depth):

Other Implementations

Benchmark Results

On torchtune, I ran this command on a single A100 GPU:

$ tune run --nproc_per_node 1 full_finetune_distributed --config llama3/8B_full output_dir=$CKPT_PATH checkpointer.checkpoint_dir=$CKPT_PATH/original checkpointer.output_dir=$CKPT_PATH tokenizer.path=$CKPT_PATH/original/tokenizer.model batch_size=16

and got these measurements:

| Maximum Dropout | Dropout Scale Across Layers | Time to Reach 50 Iterations | Speedup |
|---|---|---|---|
| None | - | 01 min 32 sec | 1x |
| 0.2 | Uniform | 01 min 23 sec | 1.07x |
| 0.3 | Uniform | 01 min 17 sec | 1.19x |
| 0.5 | Uniform | 01 min 05 sec | 1.42x |
| 0.5 | Linear | TBD | TBD |
| 0.2 | Exponential | 01 min 30 sec | 1.02x |
| 0.5 | Exponential | 01 min 22 sec | 1.12x |
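
For reference, here is a small sketch of how the "Dropout Scale Across Layers" column could translate into per-layer probabilities; the exact uniform / linear / exponential formulas used in the runs above are my assumption, and the snippet is only meant to illustrate rates growing with depth up to the maximum dropout value.

```python
def layer_dropout_schedule(max_p: float, num_layers: int, scale: str) -> list[float]:
    """Dropout probability per layer index. 'uniform' uses max_p for every
    layer; 'linear' and 'exponential' grow from 0 at the first layer to
    max_p at the last layer."""
    probs = []
    for layer_idx in range(num_layers):
        frac = layer_idx / max(num_layers - 1, 1)  # depth fraction in [0, 1]
        if scale == "uniform":
            p = max_p
        elif scale == "linear":
            p = max_p * frac
        elif scale == "exponential":
            p = max_p * (2.0 ** frac - 1.0)  # 0 at frac=0, max_p at frac=1
        else:
            raise ValueError(f"unknown scale: {scale}")
        probs.append(p)
    return probs

# Example: per-layer rates for a 32-layer model with maximum dropout 0.2
# layer_dropout_schedule(0.2, 32, "exponential")
```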

I also want to tag @danthe3rd as he guided me to implement the per-sample layer dropout and he has implemented it for Dinov2.