rycolab / differentiable-subset-pruning

differentiable-subset-pruning

This repository accompanies the paper Differentiable Subset Pruning of Transformer Heads.

Dependencies

BERT on MNLI

Install the Transformers library adapted from Hugging Face:

cd transformers
pip install -e .

Then install other dependencies:

cd examples
pip install -r requirements.txt

Download the MNLI dataset:

cd pruning
python ../../utils/download_glue_data.py --data_dir data/ --task MNLI

Joint Pruning

Joint DSP

export TASK_NAME=MNLI
python run_dsp.py \
    --data_dir data/$TASK_NAME/ \
    --model_name_or_path bert-base-uncased \
    --task_name $TASK_NAME \
    --do_train \
    --do_eval \
    --max_seq_length 128 \
    --per_device_train_batch_size 32 \
    --learning_rate 2e-5 \
    --num_train_epochs 3.0 \
    --output_dir dsp_output/ \
    --cache_dir cache/ \
    --save_steps 10000 \
    --num_of_heads 12 \
    --pruning_lr 0.5 \
    --annealing \
    --initial_temperature 1000 \
    --final_temperature 1e-8 \
    --cooldown_steps 25000 \
    --joint_pruning 

where num_of_heads is the number of attention heads to keep unpruned, and pruning_lr, initial_temperature, final_temperature, and cooldown_steps are the hyperparameters for DSP.
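
To give a feel for how the annealing flags interact, the sketch below decays a temperature from initial_temperature to final_temperature over cooldown_steps. The geometric (log-linear) shape and the function name annealed_temperature are assumptions made for illustration; the schedule actually implemented in run_dsp.py may differ.

```python
import math


def annealed_temperature(step, initial_temperature=1000.0,
                         final_temperature=1e-8, cooldown_steps=25000):
    """Hypothetical schedule: interpolate the log-temperature linearly."""
    if cooldown_steps <= 0:
        return final_temperature
    # Fraction of the cooldown completed so far, clamped to [0, 1].
    frac = min(max(step / cooldown_steps, 0.0), 1.0)
    log_t = ((1.0 - frac) * math.log(initial_temperature)
             + frac * math.log(final_temperature))
    return math.exp(log_t)


if __name__ == "__main__":
    for step in (0, 12500, 25000, 30000):
        print(step, annealed_temperature(step))
```

Under this assumed schedule the temperature stays at final_temperature once cooldown_steps has passed, so the remaining updates train with a nearly discrete head selection.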

You can also use the straight-through estimator:

export TASK_NAME=MNLI
python run_dsp.py \
    --data_dir data/$TASK_NAME/ \
    --model_name_or_path bert-base-uncased \
    --task_name $TASK_NAME \
    --do_train \
    --do_eval \
    --max_seq_length 128 \
    --per_device_train_batch_size 32 \
    --learning_rate 2e-5 \
    --num_train_epochs 3.0 \
    --output_dir dsp_output/ \
    --cache_dir cache/ \
    --save_steps 10000 \
    --num_of_heads 12 \
    --pruning_lr 0.5 \
    --use_ste \
    --joint_pruning 
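
For readers unfamiliar with the straight-through estimator, the following is a minimal sketch of the general idea: the forward pass uses a hard top-k mask over head scores, while the backward pass substitutes the gradient of a smooth surrogate. The sigmoid surrogate and both function names are assumptions for illustration, not the code behind --use_ste.

```python
import math


def hard_top_k_mask(scores, k):
    """Forward pass: 1.0 for the k highest-scoring heads, 0.0 for the rest."""
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    keep = set(order[:k])
    return [1.0 if i in keep else 0.0 for i in range(len(scores))]


def ste_backward(scores, grad_wrt_mask, temperature=1.0):
    """Backward pass: act as if the mask had been sigmoid(score / temperature).

    The hard threshold has zero gradient almost everywhere, so the
    straight-through estimator borrows the derivative of a smooth surrogate.
    """
    grads = []
    for s, g in zip(scores, grad_wrt_mask):
        p = 1.0 / (1.0 + math.exp(-s / temperature))
        grads.append(g * p * (1.0 - p) / temperature)
    return grads


if __name__ == "__main__":
    scores = [0.3, 2.0, -1.0, 1.2]
    print(hard_top_k_mask(scores, k=2))
    print(ste_backward(scores, grad_wrt_mask=[1.0, 1.0, 1.0, 1.0]))
```

The model only ever sees a discrete mask, yet every score still receives a gradient signal, which is what lets the head scores be trained with --pruning_lr.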

Voita et al.

The baseline of Voita et al., 2019 can be reproduced with:

export TASK_NAME=MNLI
python run_voita.py \
    --model_name_or_path bert-base-uncased \
    --task_name $TASK_NAME \
    --do_train \
    --do_eval \
    --data_dir data/$TASK_NAME/ \
    --max_seq_length 128 \
    --per_device_train_batch_size 32 \
    --learning_rate 2e-5 \
    --num_train_epochs 3.0 \
    --output_dir voita_output/ \
    --cache_dir cache/ \
    --save_steps 10000 \
    --l0_penalty 0.0015 \
    --joint_pruning \
    --pruning_lr 0.1 

where l0_penalty is used to indirectly control the number of heads.
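
The reason l0_penalty controls the head count only indirectly is that Voita et al. attach a stochastic hard concrete gate to each head and penalize the expected number of open gates, rather than fixing that number. A minimal sketch of such a penalty follows. The constants come from the hard concrete formulation of Louizos et al., 2018; treating them as the values used by run_voita.py, and the function names, are assumptions.

```python
import math

# Stretch limits and temperature of the hard concrete distribution
# (assumed values; not read from run_voita.py).
GAMMA, ZETA, BETA = -0.1, 1.1, 2.0 / 3.0


def prob_gate_open(log_alpha):
    """Probability that a hard concrete gate with this parameter is non-zero."""
    shifted = log_alpha - BETA * math.log(-GAMMA / ZETA)
    return 1.0 / (1.0 + math.exp(-shifted))


def l0_regularizer(log_alphas, l0_penalty):
    """Term added to the task loss: l0_penalty times the expected open gates."""
    return l0_penalty * sum(prob_gate_open(a) for a in log_alphas)


if __name__ == "__main__":
    gates = [3.0, 0.0, -3.0]  # toy per-head gate parameters
    print(l0_regularizer(gates, l0_penalty=0.0015))
```

A larger l0_penalty pushes more gate parameters down and so closes more heads, but the final count depends on how this term trades off against the task loss.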

Pipelined Pruning

We first need to fine-tune BERT on MNLI:

export TASK_NAME=MNLI
python ../text-classification/run_glue.py \
    --model_name_or_path bert-base-uncased \
    --task_name $TASK_NAME \
    --do_train \
    --do_eval \
    --data_dir data/$TASK_NAME/ \
    --max_seq_length 128 \
    --per_device_train_batch_size 32 \
    --learning_rate 2e-5 \
    --num_train_epochs 3.0 \
    --output_dir output/ \
    --cache_dir cache/ \
    --save_steps 10000 

Pipelined DSP

export TASK_NAME=MNLI
python run_dsp.py \
    --data_dir data/$TASK_NAME/ \
    --model_name_or_path output/ \
    --task_name $TASK_NAME \
    --do_train \
    --do_eval \
    --max_seq_length 128 \
    --per_device_train_batch_size 32 \
    --learning_rate 2e-5 \
    --num_train_epochs 1.0 \
    --output_dir dsp_output/pipelined/ \
    --cache_dir cache/ \
    --save_steps 10000 \
    --num_of_heads 12 \
    --pruning_lr 0.5 \
    --annealing \
    --initial_temperature 1000 \
    --final_temperature 1e-8 \
    --cooldown_steps 8333 
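
The cooldown_steps values in the two BERT recipes appear to scale with training length: 25000 steps for 3 epochs in the joint setting and 8333 for 1 epoch here. The back-of-the-envelope check below is not part of the repository. It assumes the MNLI training split has 392,702 examples and that training runs on a single device, so that one step consumes 32 examples.

```python
import math

MNLI_TRAIN_EXAMPLES = 392_702  # assumed size of the GLUE MNLI training split
BATCH_SIZE = 32                # --per_device_train_batch_size, one device assumed


def cooldown_fraction(cooldown_steps, num_epochs):
    """Share of all optimizer steps spent before the temperature bottoms out."""
    steps_per_epoch = math.ceil(MNLI_TRAIN_EXAMPLES / BATCH_SIZE)
    return cooldown_steps / (steps_per_epoch * num_epochs)


if __name__ == "__main__":
    print("joint    :", round(cooldown_fraction(25000, 3), 3))
    print("pipelined:", round(cooldown_fraction(8333, 1), 3))
```

Under these assumptions both recipes spend roughly the same share of training in the cooldown phase, which is a useful rule of thumb when changing the number of epochs or the batch size.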

Michel et al.

The baseline of Michel et al., 2019 can be reproduced with:

export TASK_NAME=MNLI
python run_michel.py \
    --data_dir data/$TASK_NAME/ \
    --model_name_or_path output/ \
    --task_name $TASK_NAME \
    --output_dir michel_output/ \
    --cache_dir cache/ \
    --tokenizer_name bert-base-uncased \
    --exact_pruning \
    --num_of_heads 12

where num_of_heads is the number of heads to prune (mask) at each step.
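
The "prune at each step" behaviour can be pictured as repeatedly masking the least important heads that are still active. The sketch below shows one such step on toy importance scores. How run_michel.py computes importance is not shown in this README, so the scores, the (layer, head) keys, and the function name are all hypothetical.

```python
def mask_least_important(importance, already_masked, num_to_mask):
    """Return the set of masked heads after one pruning step.

    importance: dict mapping (layer, head) -> importance score
    already_masked: set of (layer, head) pairs masked in earlier steps
    num_to_mask: how many additional heads to mask in this step
    """
    candidates = [h for h in importance if h not in already_masked]
    candidates.sort(key=lambda h: importance[h])  # least important first
    return set(already_masked) | set(candidates[:num_to_mask])


if __name__ == "__main__":
    toy_importance = {(0, 0): 0.9, (0, 1): 0.1, (1, 0): 0.4, (1, 1): 0.2}
    masked = set()
    for step in range(2):
        masked = mask_least_important(toy_importance, masked, num_to_mask=1)
        print("step", step, "masked:", sorted(masked))
```

In an iterative scheme the importance scores would normally be re-estimated between steps, because masking one head can change how much the others matter.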

Enc-Dec on IWSLT

Install the adapted fairseq library:

cd fairseq
pip install -e .

Download and prepare the IWSLT dataset:

cd examples/pruning
./prepare-iwslt.sh

Preprocess the dataset:

export TEXT=iwslt14.tokenized.de-en
fairseq-preprocess --source-lang de --target-lang en \
    --trainpref $TEXT/train --validpref $TEXT/valid --testpref $TEXT/test \
    --destdir data-bin/ \
    --workers 20

Joint Pruning

Joint DSP

export SAVE_DIR=dsp_checkpoints/

python run_dsp.py \
    data-bin/ \
    --arch transformer_iwslt_de_en --share-decoder-input-output-embed \
    --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
    --lr 5e-4 --lr-scheduler inverse_sqrt --warmup-updates 4000 \
    --dropout 0.3 --weight-decay 0.0001 \
    --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
    --max-tokens 4096 \
    --max-epoch 60 \
    --eval-bleu \
    --eval-bleu-args '{"beam": 5, "max_len_a": 1.2, "max_len_b": 10}' \
    --eval-bleu-detok moses \
    --eval-bleu-remove-bpe \
    --best-checkpoint-metric bleu --maximize-best-checkpoint-metric \
    --save-dir $SAVE_DIR/ \
    --num-of-heads 4 \
    --pruning-lr 0.2 \
    --joint-pruning \
    --annealing \
    --initial-temperature 0.1 \
    --final-temperature 1e-8 \
    --cooldown-steps 15000 

python generate_dsp.py \
    data-bin \
    -s de -t en \
    --path $SAVE_DIR/checkpoint_best.pt \
    --quiet \
    --batch-size 32 --beam 5 --remove-bpe

where num-of-heads is the number of attention heads to keep unpruned, and pruning-lr, initial-temperature, final-temperature, and cooldown-steps are the hyperparameters for DSP.
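
To illustrate how a fixed head budget and a temperature can work together, here is a sketch of one known relaxation of top-k selection: k successive softmaxes, each suppressing what the previous ones picked, optionally with Gumbel noise added to the scores. It is offered as an example of this family of techniques under stated assumptions; the function names are hypothetical and run_dsp.py may implement the selection differently.

```python
import math
import random


def softmax(xs):
    m = max(xs)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def relaxed_top_k(logits, k, temperature, rng=None):
    """Soft k-hot vector over heads; approaches a hard top-k as temperature -> 0.

    If rng is given, Gumbel noise is added so that the subset is sampled.
    """
    r = list(logits)
    if rng is not None:
        r = [x - math.log(-math.log(max(rng.random(), 1e-12))) for x in r]
    khot = [0.0] * len(r)
    for _ in range(k):
        p = softmax([x / temperature for x in r])
        khot = [a + b for a, b in zip(khot, p)]
        # Push down entries in proportion to how strongly they were just picked.
        r = [x + math.log(max(1.0 - q, 1e-30)) for x, q in zip(r, p)]
    return khot


if __name__ == "__main__":
    head_scores = [3.0, 1.0, 2.0, 0.0]
    print(relaxed_top_k(head_scores, k=2, temperature=1.0))
    print(relaxed_top_k(head_scores, k=2, temperature=0.01))
    print(relaxed_top_k(head_scores, k=2, temperature=0.01, rng=random.Random(0)))
```

The entries always sum to k, so the budget holds at every temperature. A high temperature spreads the weight across many heads, while a low one concentrates it on exactly k of them.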

You can also use the straight-through estimator:

export SAVE_DIR=dsp_checkpoints/

python run_dsp.py \
    data-bin/ \
    --arch transformer_iwslt_de_en --share-decoder-input-output-embed \
    --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
    --lr 5e-4 --lr-scheduler inverse_sqrt --warmup-updates 4000 \
    --dropout 0.3 --weight-decay 0.0001 \
    --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
    --max-tokens 4096 \
    --max-epoch 60 \
    --eval-bleu \
    --eval-bleu-args '{"beam": 5, "max_len_a": 1.2, "max_len_b": 10}' \
    --eval-bleu-detok moses \
    --eval-bleu-remove-bpe \
    --best-checkpoint-metric bleu --maximize-best-checkpoint-metric \
    --save-dir $SAVE_DIR/ \
    --num-of-heads 4 \
    --pruning-lr 0.2 \
    --joint-pruning \
    --use-ste 

python generate_dsp.py \
    data-bin \
    -s de -t en \
    --path $SAVE_DIR/checkpoint_best.pt \
    --quiet \
    --batch-size 32 --beam 5 --remove-bpe

Voita et al.

The baseline of Voita et al., 2019 can be reproduced with:

export SAVE_DIR=voita_checkpoints/

python run_voita.py \
    data-bin/ \
    --arch transformer_iwslt_de_en --share-decoder-input-output-embed \
    --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
    --lr 5e-4 --lr-scheduler inverse_sqrt --warmup-updates 4000 \
    --dropout 0.3 --weight-decay 0.0001 \
    --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
    --max-tokens 4096 \
    --max-epoch 60 \
    --keep-best-checkpoints 1 \
    --l0-penalty 200 \
    --save-dir $SAVE_DIR/

python generate_voita.py \
    data-bin \
    -s de -t en \
    --path $SAVE_DIR/checkpoint_last.pt \
    --quiet \
    --batch-size 32 --beam 5 --remove-bpe

where l0-penalty is used to indirectly control the number of heads.

Pipelined Pruning

We first need to train the model on IWSLT:

fairseq-train \
    data-bin/ \
    --arch transformer_iwslt_de_en --share-decoder-input-output-embed \
    --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
    --lr 5e-4 --lr-scheduler inverse_sqrt --warmup-updates 4000 \
    --dropout 0.3 --weight-decay 0.0001 \
    --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
    --max-tokens 4096 \
    --max-epoch 60 \
    --save-dir checkpoints/ 

Pipelined DSP

export SAVE_DIR=dsp_checkpoints/pipelined/

python run_dsp.py \
    data-bin/ \
    --arch transformer_iwslt_de_en --share-decoder-input-output-embed \
    --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
    --lr 5e-4 --lr-scheduler inverse_sqrt --warmup-updates 4000 \
    --dropout 0.3 --weight-decay 0.0001 \
    --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
    --max-tokens 4096 \
    --reset-optimizer \
    --restore-file checkpoints/checkpoint_last.pt \
    --max-epoch 61 \
    --save-dir $SAVE_DIR/ \
    --num-of-heads 4 \
    --pruning-lr 0.2 \
    --annealing \
    --initial-temperature 0.1 \
    --final-temperature 1e-8 \
    --cooldown-steps 250 

python generate_dsp.py \
    data-bin \
    -s de -t en \
    --path $SAVE_DIR/checkpoint_last.pt \
    --quiet \
    --batch-size 32 --beam 5 --remove-bpe

Michel et al.

The baseline of Michel et al., 2019 can be reproduced with:

python run_michel.py \
    data-bin \
    -s de -t en \
    -a transformer_iwslt_de_en \
    --restore-file checkpoints/checkpoint_last.pt \
    --optimizer adam --adam-betas '(0.9,0.98)' \
    --reset-optimizer \
    --batch-size 32 --beam 5 --remove-bpe \
    --num-of-heads 4 \
    --exact-pruning 

where num-of-heads is the number of heads to prune (mask) at each step.