facebookresearch / ParlAI

A framework for training and evaluating AI models on a variety of openly available dialogue datasets.
https://parl.ai
MIT License

Cannot Reproduce Results in the “Beyond Goldfish Memory: Long-Term Open-Domain Conversation” paper #4534

Closed by lorafei 2 years ago

lorafei commented 2 years ago

Hi!

I am trying to reproduce the MSC 2.7B (RAG) results from the “Beyond Goldfish Memory: Long-Term Open-Domain Conversation” paper. I used the same hyperparameters as the released checkpoint zoo:msc/summscrag3B/model, except that I used the original msc instead of summsc. However, I cannot reproduce the performance reported in the paper for the MSC 2.7B (RAG) model. My validation and test results are:

/home/sysadmin/fei/ParlAI/parlai/core/torch_generator_agent.py:1754: UserWarning: __floordiv__ is deprecated, and its behavior will change in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values. To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor').
  hyp_ids = best_idxs // voc_size
loading: data/msc/msc/msc_dialogue/session_2
loading: data/msc/msc/msc_dialogue/session_3
loading: data/msc/msc/msc_dialogue/session_4
05:33:12 | eval completed in 71447.01s
05:33:12 | valid:
                     accuracy  bleu-4  clen  ctpb  ctps  ctrunc  ctrunclen  exps   exs    f1  gen_n_toks  gpu_mem  llen  loss       lr  ltpb  ltps   ltrunc  ltrunclen  \
   all                      0 .006831 661.2  4579 204.7   .2335      56.02 .3568 25492 .1925       24.32   .08567 27.47 2.264 9.95e-07 211.4 9.452 .0003391     .00564
   msc:Session1Self         0   .0082 185.6                   0          0        7801 .1846        22.8          14.41 2.214                             0          0
   msc_dialogue_2           0 .006917 426.1                   0          0        5897 .1967       24.66          31.52 2.288                      .0006783     .01543
   msc_dialogue_3           0 .006388 818.8               .1278       15.5        5890 .1957       24.82          31.57 2.272                      .0001698    .005433
   msc_dialogue_4           0 .005818  1214               .8061      208.6        5904 .1930          25          32.38 2.283                      .0005081    .001694
                      ppl  token_acc  token_em  total_train_updates  tpb   tps
   all              9.628      .4776  .0001282                 8748 4790 214.2
   msc:Session1Self  9.15      .4926  .0005128
   msc_dialogue_2   9.855      .4732         0
   msc_dialogue_3   9.701      .4733         0
   msc_dialogue_4   9.804      .4712         0

05:33:12 | creating task(s): msc
05:33:12 | WARNING: Test set not included. Setting datatype to valid.
05:33:12 | running eval: test
05:33:12 | loading normalized fbdialog data: data/ConvAI2/valid_self_original.txt
05:33:12 | loading fbdialog data: data/ConvAI2/valid_self_original.txt
/home/sysadmin/fei/ParlAI/parlai/core/torch_generator_agent.py:1754: UserWarning: __floordiv__ is deprecated, and its behavior will change in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values. To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor').
  hyp_ids = best_idxs // voc_size
loading: data/msc/msc/msc_dialogue/session_2
loading: data/msc/msc/msc_dialogue/session_3
loading: data/msc/msc/msc_dialogue/session_4
01:32:30 | eval completed in 71957.31s
01:32:30 | test:
                     accuracy  bleu-4  clen  ctpb  ctps  ctrunc  ctrunclen  exps   exs    f1  gen_n_toks  gpu_mem  llen  loss       lr  ltpb  ltps   ltrunc  ltrunclen  \
   all                      0 .006415 644.2  4498 200.3   .2246      51.51 .3558 25604 .1922       24.45   .08574 28.19 2.246 9.95e-07 217.2 9.674 .0003369    .003158
   msc:Session1Self         0 .008226 185.6                   0          0        7801 .1844        22.8          14.41 2.214                             0          0
   msc_dialogue_2           0 .006193 396.2                   0          0        5939 .1970       24.74          30.93 2.238                      .0005051    .005051
   msc_dialogue_3           0 .006103 790.1              .08406       8.65        5924 .1953       25.07          32.88 2.257                      .0003376    .002194
   msc_dialogue_4           0 .005137  1205               .8145      197.4        5940 .1922        25.2          34.55 2.276                      .0005051    .005387
                      ppl  token_acc  token_em  total_train_updates  tpb  tps
   all              9.454      .4814  .0001704                 8748 4715  210
   msc:Session1Self 9.151      .4925  .0005128
   msc_dialogue_2   9.374      .4829         0
   msc_dialogue_3   9.555      .4760  .0001688
   msc_dialogue_4   9.734      .4744         0

The results reported in the paper are:

[image: results from the paper]

My training script is:

MODEL_FILE=/home/sysadmin/fei/ParlAI/log/msc/BlenderBot2RagAgent_param_same_as_summsc_rag3B
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 parlai train_model -dp data \
--model projects.msc.agents.memory_agent:MemoryLongRagAgent \
-mf ${MODEL_FILE}/model --task msc --num_epochs 5 --include_last_session False \
--generation-model transformer_variant/generator \
--memory-decoder-model-file "" --memory-key full_text \
--save-every-n-secs 1800 --log_every_n_secs 60 --datatype train:stream  \
--embeddings-scale True --variant prelayernorm --split-lines false --learn-positional-embeddings false \
--dict-tokenizer bytelevelbpe --dict-file zoo:blender/blender_3B/model.dict \
--bpe-vocab /home/sysadmin/fei/ParlAI/data/models/blender/blender_3B/model.dict-vocab.json \
--bpe-merge /home/sysadmin/fei/ParlAI/data/models/blender/blender_3B/model.dict-merges.txt \
--query-model bert_from_parlai_rag --memory_reader_model bert_from_parlai_rag \
--rag-model-type token --previous-persona-type raw_history \
--dpr-model-file zoo:hallucination/bart_rag_token/model \
--gold-document-titles-key select-docs-titles --insert-gold-docs True \
--beam-min-length 20 --beam-context-block-ngram 3 --beam-block-ngram 3 --beam-block-full-context False --beam-size 10 \
--inference beam --optimizer mem_eff_adam --learningrate 1e-06 --lr-scheduler-patience 3 --model-parallel True \
--knowledge-access-method memory_only --batchsize 8 --dropout 0.1 --attention-dropout 0.0 \
--min-doc-token-length 64 --max-doc-token-length 128 \
--fp16 True --fp16-impl mem_efficient --force-fp16-tokens false \
--tensorboard-log true --tensorboard-logdir ${MODEL_FILE}/logs/tensorboard \
--init_BST3B_model /home/sysadmin/fei/ParlAI/data/models/blender/blender_3B/model -o /home/sysadmin/fei/ParlAI/parlai/opt_presets/arch/blenderbot_3B --lr_scheduler reduceonplateau \
--validation-metric-mode min --validation-every-n-epochs 0.25 --validation-max-exs 1000 --validation-metric ppl --validation-patience 10 --validation_max_exs 20000  \
--dynamic-batching full --truncate 1024 --eval_dynamic_batching off \
--memory_writer_model_file zoo:hallucination/multiset_dpr/hf_bert_base.cp \
--max_train_steps 10000 -n_positions 1024 --n-positions-init 128 --n-docs 6 \
--text-truncate 1024 --label-truncate 128 --history_add_global_end_token end --warmup_updates 200 \
--max-memories 80

jxmsML commented 2 years ago

Hi, note that the MSC-RAG model uses a few hyperparameters that are not exactly the same as in your training command:

  1. We split the task: -t msc becomes -t msc:Session1Self:is_convai2_session_level=True,msc,msc:SessionBaseMsc:session_id=2,msc:SessionBaseMsc:session_id=3,msc:SessionBaseMsc:session_id=4
  2. We use --min-doc-token-length 128 instead of 64 and --max-doc-token-length 256 instead of 128, since the model needs to deal with raw history
  3. We set --retriever-ignore-phrases persona:,__his__ and --memory-extractor-phrase persona:,__his__, together with --memory-delimiter <SPECIALTOKEN> and --previous-session-delimiter <SPECIALTOKEN>, to make sure the correct memory documents are retrieved at the session level
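
For readers following along, the three differences listed above can be collected into a single set of overrides. This is only a sketch: the flag names and values are copied from the list, `MSC_TASKS` and `MSC_RAG_OVERRIDES` are hypothetical variable names introduced here for illustration, and `<SPECIALTOKEN>` is reproduced exactly as it appears in this thread (the actual delimiter value is not stated here).

```shell
#!/usr/bin/env bash
# Sketch only: overrides taken from the list above, not a verified training recipe.
# MSC_TASKS and MSC_RAG_OVERRIDES are hypothetical names, not part of ParlAI.

# Item 1: the split task specification used in place of "--task msc".
MSC_TASKS="msc:Session1Self:is_convai2_session_level=True,msc,msc:SessionBaseMsc:session_id=2,msc:SessionBaseMsc:session_id=3,msc:SessionBaseMsc:session_id=4"

MSC_RAG_OVERRIDES=(
  --task "${MSC_TASKS}"
  # Item 2: longer documents, since the memories are raw history.
  --min-doc-token-length 128
  --max-doc-token-length 256
  # Item 3: phrases and delimiters for session-level memory retrieval.
  # "<SPECIALTOKEN>" is copied verbatim from the thread and must stay quoted,
  # otherwise the shell treats "<" and ">" as redirections.
  --retriever-ignore-phrases "persona:,__his__"
  --memory-extractor-phrase "persona:,__his__"
  --memory-delimiter "<SPECIALTOKEN>"
  --previous-session-delimiter "<SPECIALTOKEN>"
)

# Print one argument per line so the overrides can be inspected before training.
printf '%s\n' "${MSC_RAG_OVERRIDES[@]}"
```

The array would then be passed as `"${MSC_RAG_OVERRIDES[@]}"` to the `parlai train_model` command from the original post, in place of its `--task msc`, `--min-doc-token-length 64` and `--max-doc-token-length 128` flags.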

Please let me know if you can reproduce similar numbers.

github-actions[bot] commented 2 years ago

This issue has not had activity in 30 days. Please feel free to reopen if you have more issues. You may apply the "never-stale" tag to prevent this from happening.