kiddyboots216 / CommEfficient

PyTorch for benchmarking communication-efficient distributed SGD optimization algorithms
71 stars 20 forks

Help with the PersonaChat experiment command #3

Closed — greeneggsandyaml closed this issue 3 years ago

greeneggsandyaml commented 3 years ago

Hello authors, thank you for your great paper -- I'll be citing it in my upcoming work.

I am trying to run your code, specifically the PersonaChat experiments. Would you be able to provide the exact set of commands used to generate the results in Figure 5 (i.e. ~15 ppl)? I have read the paper and all of your code (you did a lot of work on this!), and I wanted to confirm the command for the run. It would be extremely helpful if you were to share your commands.

At the moment I am simply trying to replicate the uncompressed results. My current command is:

python gpt2_train.py \
--dataset_name PERSONA \
--local_momentum 0 \
--dataset_dir ./dataset/personachat \
--mode uncompressed \
--seed 42 \
--local_batch_size -1 \
--num_results_train 1 --num_results_val 2 \
--num_epochs 1 \
--valid_batch_size 4 \
--port 6239 \
--num_workers 4 --num_devices 4

Is this right? Have I missed a hyperparameter that might be important?

I tried changing the learning rate to 0.16 with --lr_scale 0.16, but I quickly get NaNs. Should I set lm_coef or mc_coef to something other than the default? The results I'm getting are quite a bit worse than the paper's, so I'm trying to find the source of the discrepancy.

Thank you so much for your help! I really do appreciate it.

kiddyboots216 commented 3 years ago

Here's my training script, maybe this can help. Let me know if you have any questions.

python gpt2_train.py \
    --dataset_dir /data/ashwineep/datasets/persona_chat/ \
    --dataset_name PERSONA \
    --model_checkpoint gpt2 \
    --num_results_train 1 \
    --num_results_val 2 \
    --lm_coef=2.0 \
    --max_history=2 \
    --num_candidates=4 \
    --personality_permutations=2 \
    --valid_batch_size 8 \
    --train_dataloader_workers 4 \
    --val_dataloader_workers 4 \
    --num_devices 5 \
    --microbatch_size 4 \
    --mode $1 \
    --error_type $2 \
    --lr_scale $3 \
    --num_epochs=$4 \
    $5 \
    --num_workers $6 \
    --local_batch_size $7 \
    --k $8 \
    --num_rows $9 \
    --num_cols ${10} \
    --local_momentum ${11} \
    --virtual_momentum ${12} \
    --max_grad_norm ${13} \
    --num_fedavg_epochs ${14} \
    --fedavg_batch_size ${15} \
    --port ${16} \
    --seed ${17}
greeneggsandyaml commented 3 years ago

Thank you for the quick reply!

Do you know which arguments you passed to that script to get the uncompressed results? (The command above still has unspecified positional arguments.)

I tried running that exact command with --mode uncompressed --error_type none --num_epochs 1 --num_workers 4 --local_batch_size -1 --local_momentum 0, but I believe it is not working as well as it should. For clarity, my full command is:

python gpt2_train.py \
    --dataset_dir ./dataset/personachat \
    --dataset_name PERSONA \
    --model_checkpoint gpt2 \
    --num_results_train 1 \
    --num_results_val 2 \
    --lm_coef 2.0 \
    --max_history 2 \
    --num_candidates 4 \
    --personality_permutations 2 \
    --valid_batch_size 8 \
    --train_dataloader_workers 4 \
    --val_dataloader_workers 4 \
    --num_devices 5 \
    --microbatch_size 4 \
    --mode uncompressed \
    --error_type none \
    --num_epochs 1 \
    --num_workers 4 \
    --lr_scale 2e-5 \
    --local_batch_size -1 \
    --local_momentum 0

I'm also unsure about lr_scale: the default (4e-2) diverges, so I assume I should use a lower learning rate (e.g. 2e-5). When I do, I am able to train but do not achieve the expected performance.

Thank you again!

kiddyboots216 commented 3 years ago

An lr_scale of 0.16 should work for gpt2, with local momentum 0, virtual momentum 0.9, and max grad norm 1.
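Plugging those values into the command from the previous comment, the full uncompressed invocation would presumably look like the following (the dataset path is a placeholder, and leaving out the sketch-related flags like --k/--num_rows/--num_cols for uncompressed mode is an assumption based on this thread, not confirmed by the authors):

```shell
python gpt2_train.py \
    --dataset_dir ./dataset/personachat \
    --dataset_name PERSONA \
    --model_checkpoint gpt2 \
    --num_results_train 1 \
    --num_results_val 2 \
    --lm_coef 2.0 \
    --max_history 2 \
    --num_candidates 4 \
    --personality_permutations 2 \
    --valid_batch_size 8 \
    --train_dataloader_workers 4 \
    --val_dataloader_workers 4 \
    --num_devices 5 \
    --microbatch_size 4 \
    --mode uncompressed \
    --error_type none \
    --num_epochs 1 \
    --num_workers 4 \
    --lr_scale 0.16 \
    --local_batch_size -1 \
    --local_momentum 0 \
    --virtual_momentum 0.9 \
    --max_grad_norm 1
```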

greeneggsandyaml commented 3 years ago

Okay thanks -- I will try it right now and get back to you when it finishes running!

greeneggsandyaml commented 3 years ago

Thank you for the help! I was able to reproduce the uncompressed results :) I will ask again if I have additional questions or trouble with future experiments.

I really appreciate your responsiveness on this thread!