frankaging / Causal-Distill

The Codebase for Causal Distillation for Language Models (NAACL '22)
MIT License

Pre-Training questions #1

Closed: stefan-it closed this issue 2 years ago

stefan-it commented 2 years ago

Hi @frankaging ,

thanks for releasing the code and paper for your causal distillation approach :hugs:

I have some basic questions regarding the distillation process:

I would like to train new distilled models for some of my previously trained models (such as DistilBERTurk or German DistilBERT), so I would like to know which implementation I should use. For example, there is a current dev branch with some more recent changes.

In the current README, the following command is used for causal distillation:

```bash
CUDA_VISIBLE_DEVICES=9,4 python causal_train.py \
--force \
--n_gpu 2 \
--is_wandb \
--log_interval 10 \
--student_type distilbert \
--student_config ./training_configs/distilbert-base-uncased-small.json \
--student_pretrained_weights ./distillation_checkpoints/bert-base-uncased_num_layer_3.pth \
--teacher_type bert \
--teacher_name bert-base-uncased \
--neuron_mapping ./training_configs/single_middle.nm \
--mlm --alpha_ce 0.25 --alpha_mlm 0.25 --alpha_cos 0.25 --alpha_clm 0.0 --alpha_causal 0.25 \
--freeze_pos_embs \
--dump_path ./results/ \
--data_file ./wikitext-15M/binarized_text.train.bert-base-uncased.pickle \
--token_counts ./wikitext-15M/binarized_text.train.token_counts.bert-base-uncased.pickle \
--seed 42 \
--gradient_accumulation_steps 50 \
--n_epoch 3 \
--batch_size 5
```
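The `--alpha_*` flags weight the individual training objectives. Assuming `causal_train.py` combines them as a plain weighted sum, as the HuggingFace DistilBERT distillation code it is adapted from does, the total loss behaves roughly like the sketch below. The function `combine_losses`, the short objective names, and the loss values are hypothetical and only illustrate the arithmetic; the repository's actual code may differ.

```python
# Hypothetical sketch: combine per-objective losses using the --alpha_* weights.
# Assumes a plain weighted sum, as in HuggingFace's DistilBERT distiller;
# causal_train.py may implement this differently.
from typing import Dict


def combine_losses(losses: Dict[str, float], alphas: Dict[str, float]) -> float:
    """Return sum(alpha_k * loss_k) over objectives whose weight is nonzero."""
    total = 0.0
    for name, alpha in alphas.items():
        if alpha == 0.0:
            continue  # e.g. --alpha_clm 0.0 switches the causal LM objective off
        total += alpha * losses[name]
    return total


# Weights taken from the README command above.
alphas = {"ce": 0.25, "mlm": 0.25, "cos": 0.25, "clm": 0.0, "causal": 0.25}
# Made-up loss values, only to show the weighting (no "clm" entry needed).
losses = {"ce": 2.0, "mlm": 4.0, "cos": 0.4, "causal": 1.6}
print(combine_losses(losses, alphas))
```

With four objectives at 0.25 each, every active term contributes equally, and `--alpha_clm 0.0` removes the causal LM term entirely.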

This raised a few questions: it seems that only 2 GPUs are used, whereas the paper mentions 4 TITAN GPUs. The effective batch size per device is 50 x 5 = 250 (gradient accumulation steps x batch size), so the total effective batch size used for training is 2 x 250 = 500. Could you please share the hyperparameters you used for the model in the paper? :thinking:
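The numbers above depend on whether `--batch_size` counts examples per GPU or across all GPUs. A small sketch that computes both readings; the function name and its `per_device` switch are hypothetical and are not flags of `causal_train.py`.

```python
# Hypothetical helper: effective batch size under two readings of --batch_size.
def effective_batch_size(batch_size: int, grad_accum_steps: int, n_gpu: int,
                         per_device: bool) -> int:
    """Number of examples that contribute to one optimizer step.

    per_device=True:  batch_size is the batch on each GPU, so GPUs multiply it.
    per_device=False: batch_size is already split across all GPUs.
    """
    per_step = batch_size * n_gpu if per_device else batch_size
    return per_step * grad_accum_steps


# README command: --batch_size 5 --gradient_accumulation_steps 50 --n_gpu 2
print(effective_batch_size(5, 50, 2, per_device=True))   # 500
print(effective_batch_size(5, 50, 2, per_device=False))  # 250
```

The two readings differ by exactly a factor of `--n_gpu`, so the question of which one the script uses matters when comparing against the paper's setup.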

Did you run any experiments with fp16?

Many thanks in advance!

frankaging commented 2 years ago

Thanks for raising the question.

The training command in the README needs a proper update! Here is an example of the command we currently use:

```bash
python causal_train.py \
--force \
--n_gpu 4 \
--log_interval 10 \
--student_type distilbert \
--student_config ./training_configs/distilbert-base-uncased.json \
--student_pretrained_weights ./distillation_checkpoints/bert-base-uncased_num_layer_6.pth \
--teacher_type bert \
--teacher_name bert-base-uncased \
--neuron_mapping ./training_configs/single_middle_layer_6.nm \
--mlm --alpha_ce 0.25 --alpha_mlm 0.25 --alpha_cos 0.25 --alpha_clm 0.0 --alpha_causal 0.25 \
--interchange_prop 0.3 --interchange_max_token -1 --interchange_consecutive_only \
--freeze_pos_embs \
--dump_path ./results/ \
--data_file ./wikitext-dataset/binarized_text.train.bert-base-uncased.pickle \
--token_counts ./wikitext-dataset/binarized_text.train.token_counts.bert-base-uncased.pickle \
--seed 42 \
--n_epoch 3 \
--gradient_accumulation_steps 6 \
--batch_size 40
```

As you can see, the effective batch size is 6 * 40 = 240. I modified the distillation code provided by HuggingFace so that the 40 here is not per device but across all devices (i.e., each device gets 40 / 4 = 10). I believe this is right, but correct me if I am wrong. We are about to update this codebase soon (as you may have noticed, I am working on the dev branch a lot, and I will probably merge it sometime this week). I will also address this confusion.

For now, I suggest keeping the effective batch size close to 240. If you have more or less GPU capacity, you can tune your per-device batch size up or down accordingly.
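One way to follow this advice is to fix the target of 240 and derive `--gradient_accumulation_steps` from whatever `--batch_size` fits in memory. A minimal sketch, assuming `--batch_size` counts examples across all devices as described above; the helper `accumulation_steps_for` is hypothetical and not part of the repository.

```python
# Hypothetical helper: pick accumulation steps so the effective batch stays near a target.
# Assumes --batch_size already covers all GPUs (not per device).
def accumulation_steps_for(target: int, batch_size: int) -> int:
    """Smallest number of accumulation steps whose effective batch reaches target."""
    if batch_size <= 0 or target <= 0:
        raise ValueError("target and batch_size must be positive")
    return -(-target // batch_size)  # ceiling division without floats


for bs in (40, 24, 64):
    steps = accumulation_steps_for(240, bs)
    print(f"--batch_size {bs} --gradient_accumulation_steps {steps}"
          f"  -> effective batch {bs * steps}")
```

Rounding up means a batch size that does not divide 240 evenly (such as 64) lands slightly above the target (256) rather than below it.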

No, we are not currently running experiments with fp16.

stefan-it commented 2 years ago

Hi @frankaging ,

thanks for that :hugs: I was able to train a model on my corpus with it, so I'm closing this issue :)