facebookresearch / ParlAI

A framework for training and evaluating AI models on a variety of openly available dialogue datasets.
https://parl.ai
MIT License

BB2 perplexity (PPL) scores on validation sets #4821

Closed: mailong25 closed this issue 1 year ago

mailong25 commented 1 year ago

I am trying to further fine-tune the released BB2_3B model on the MSC + WOI datasets. I expected the PPL scores on the validation sets not to change much after fine-tuning. Surprisingly, I got a big PPL improvement on the CONVAI2 task, while the PPL scores for the other tasks stayed about the same. I then noticed that the PPL of the original BB2 on the CONVAI2 task is much worse than on the other tasks. Here are the evaluation results of the original BB2 model on the validation set:

   task                         ppl    token_acc  token_em
   WizInternetWizardTeacher     11.13  .4704      0
   msc:Session1Self             13.28  .4292      0
   msc_dialogue_2               9.315  .4765      0
   msc_dialogue_3               9.665  .4745      0
   msc_dialogue_4               9.462  .4746      0
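For reading these numbers: perplexity is conventionally the exponential of the mean per-token negative log-likelihood, so a PPL can be turned back into a per-token loss with a natural log. The Python sketch below assumes ParlAI's `ppl` follows that standard definition (natural-log loss averaged over target tokens); the helper names are illustrative, not ParlAI APIs.

```python
import math


def perplexity(token_nlls):
    """Perplexity = exp(mean per-token negative log-likelihood, in nats)."""
    if not token_nlls:
        raise ValueError("need at least one token")
    return math.exp(sum(token_nlls) / len(token_nlls))


def loss_from_ppl(ppl):
    """Invert the definition: mean per-token loss (nats) = ln(ppl)."""
    return math.log(ppl)


# Baseline validation PPLs copied from the table above.
baseline_ppl = {
    "WizInternetWizardTeacher": 11.13,
    "msc:Session1Self": 13.28,
    "msc_dialogue_2": 9.315,
    "msc_dialogue_3": 9.665,
    "msc_dialogue_4": 9.462,
}

for task, ppl in baseline_ppl.items():
    print(f"{task:26s} ppl={ppl:7.3f}  loss={loss_from_ppl(ppl):.3f} nats/token")
```

Because of the exponential, a gap that looks large in PPL corresponds to a smaller gap in per-token loss.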

Here are the evaluation results after fine-tuning for 3000 steps (total: 24000 samples, 4000 samples/task):

   task                         ppl    token_acc  token_em
   WizInternetWizardTeacher     11.03  .4864      0
   msc:Session1Self             10.72  .4734      0
   msc_dialogue_2               9.263  .4791      0
   msc_dialogue_3               9.614  .4753      0
   msc_dialogue_4               9.443  .4747      0
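To put a number on "big improvement" versus "about the same", this small Python sketch computes the relative PPL change per task using only the values in the two validation tables. It is plain arithmetic, not a ParlAI utility.

```python
# Validation PPL before and after the 3000-step fine-tune, copied from the tables.
before = {
    "WizInternetWizardTeacher": 11.13,
    "msc:Session1Self": 13.28,
    "msc_dialogue_2": 9.315,
    "msc_dialogue_3": 9.665,
    "msc_dialogue_4": 9.462,
}
after = {
    "WizInternetWizardTeacher": 11.03,
    "msc:Session1Self": 10.72,
    "msc_dialogue_2": 9.263,
    "msc_dialogue_3": 9.614,
    "msc_dialogue_4": 9.443,
}


def relative_change(old, new):
    """Signed relative change; a negative value means PPL went down (improved)."""
    return (new - old) / old


for task in before:
    change = relative_change(before[task], after[task])
    print(f"{task:26s} {before[task]:7.3f} -> {after[task]:7.3f} ({change:+.1%})")
```

On these numbers msc:Session1Self drops by roughly 19%, while each of the other four tasks moves by less than 1%.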

As you can see, there is a big PPL improvement on the msc:Session1Self task (CONVAI2). I am not really sure what is happening there. Maybe the original BB2 model was not trained on CONVAI2 samples? Or maybe the BB2 model had not yet converged when training ended? Here is my fine-tuning command:

parlai train_model -dp ../data \
--model projects.blenderbot2.agents.blenderbot2:BlenderBot2FidAgent \
--num_epochs 50 \
--query-model bert_from_parlai_rag --generation-model bart \
--rag-model-type token --rag-retriever-type search_engine --search_server None \
--dpr-model-file zoo:hallucination/bart_rag_token/model \
--gold-document-titles-key __select-docs-titles__ --insert-gold-docs True --model-parallel False \
--inference beam --learningrate 5e-06 \
--memory-key personas --memory-decoder-beam-min-length 3 \
--search-query-generator-model-file zoo:blenderbot2/query_generator/model --search-query-generator-beam-min-length 2 \
--knowledge-access-method memory_only \
--fp16 True --fp16-impl mem_efficient --force-fp16-tokens True --lr-scheduler-patience 100 \
--embedding-size 2560 --ffn-size 10240 --dropout 0.0 --attention-dropout 0.0 --n-heads 32 --learn-positional-embeddings False \
--embeddings-scale True --n-positions 128 --variant prelayernorm --activation relu --n-encoder-layers 2 --n-decoder-layers 24 \
--generation-model transformer/generator --beam-size 10 --beam-min-length 20 --beam-context-block-ngram 4 --beam-block-ngram 4 \
--history-add-global-end-token end --dict-tokenizer bytelevelbpe --dict-file ../data/models/blenderbot2/blenderbot2_3B/model.dict \
--bpe-vocab ../data/models/blenderbot2/blenderbot2_3B/model.dict-vocab.json --bpe-merge ../data/models/blenderbot2/blenderbot2_3B/model.dict-merges.txt \
--beam-block-full-context False --warmup-updates 100 --skip-generation True --checkpoint-activations True \
--truncate 128 --text-truncate 128 --label-truncate 128 --min-doc-token-length 64 --max-doc-token-length 64 --validation-metric ppl \
--n-docs 30 --n-ranked-doc-chunks 2 --splitted-chunk-length 64 --doc-chunks-ranker head --save-every-n-secs 480000 --validation_every_n_secs 720000 --log-every-n-steps 180000 \
--optimizer mem_eff_adam --task  msc:Session1Self:is_convai2_session_level=False,msc:SessionBaseMsc:session_id=2,msc:SessionBaseMsc:session_id=3,msc:SessionBaseMsc:session_id=4,wizard_of_internet \
--multitask-weights stochastic --datatype train:stream  \
--update-freq 1 --save-after-valid True --validation_every_n_steps 3000 --log-every-n-secs 100000 --batchsize 8 --init-model ../data/models/blenderbot2/blenderbot2_3B/model  --memory-decoder-model-file ''
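For context on how the training examples get split across the five tasks in `--task`, here is a minimal Python sketch of size-proportional task sampling. It assumes that `--multitask-weights stochastic` weights each task by its number of episodes (our reading of ParlAI's multitask teacher, worth checking against the ParlAI source), and it simplifies by drawing a task per example rather than per episode. The episode counts are HYPOTHETICAL placeholders, not the real dataset sizes, and `sample_task_counts` is an illustrative helper, not a ParlAI function.

```python
import random
from collections import Counter


def sample_task_counts(task_sizes, num_examples, seed=0):
    """Assign each of `num_examples` examples to a task, choosing the task with
    probability proportional to its size. Simplification: a real multitask
    teacher picks a task per episode, not per example."""
    rng = random.Random(seed)
    tasks = list(task_sizes)
    weights = [task_sizes[t] for t in tasks]
    return Counter(rng.choices(tasks, weights=weights, k=num_examples))


# 3000 steps (reported above) x --batchsize 8 x --update-freq 1.
steps, batchsize, update_freq = 3000, 8, 1
num_examples = steps * batchsize * update_freq

# HYPOTHETICAL episode counts for the five --task entries; not the real sizes.
hypothetical_sizes = {
    "msc:Session1Self:is_convai2_session_level=False": 8000,
    "msc:SessionBaseMsc:session_id=2": 4000,
    "msc:SessionBaseMsc:session_id=3": 4000,
    "msc:SessionBaseMsc:session_id=4": 1000,
    "wizard_of_internet": 7000,
}

counts = sample_task_counts(hypothetical_sizes, num_examples)
for task, n in sorted(counts.items()):
    print(f"{task:48s} {n:6d} examples")
```

With equal weights the expected share would be 24000 / 5 = 4800 examples per task; with unequal sizes the split tracks dataset size instead.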
mojtaba-komeili commented 1 year ago

Thanks a lot for your work on fine-tuning BB2 further. It is interesting to see the extra improvement you achieved on the main BB2 tasks. I don't have a firm answer to the question here, other than that the hyperparameters you used ended up helping the model fine-tune better.

klshuster commented 1 year ago

The released BB2 model is also fine-tuned on safety datasets to incorporate baked-in safety measures. Fine-tuning strictly on MSC and WoI will move it closer to that distribution.

github-actions[bot] commented 1 year ago

This issue has not had activity in 30 days. Please feel free to reopen if you have more issues. You may apply the "never-stale" tag to prevent this from happening.