lucidrains / reformer-pytorch

Reformer, the efficient Transformer, in Pytorch
MIT License

Eval loss not consistent in multiple iterations #82

Closed singhay closed 4 years ago

singhay commented 4 years ago

Hi,

I'm saving multiple checkpoints while training the model, and I save the final model as well.

When evaluating I call model.eval() and wrap the forward passes in torch.no_grad(). This is where it gets weird: if I evaluate all my checkpoints back to back, the loss for the final model is different from what I get when evaluating just that one model.
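For context, evaluating the same weights twice under model.eval() + torch.no_grad() should give bit-identical losses, so differing numbers point to state differing between runs (e.g. different weights being loaded) rather than eval-time randomness. A minimal sketch with a stand-in model (nn.Linear here is an assumption, not the actual Reformer setup):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(4, 2)              # stand-in for the trained model
data = torch.randn(8, 4)             # stand-in eval batch
target = torch.randn(8, 2)
criterion = nn.MSELoss()

def evaluate(model):
    model.eval()                     # disable dropout / batchnorm updates
    with torch.no_grad():            # no autograd bookkeeping
        return criterion(model(data), target).item()

loss_a = evaluate(model)
loss_b = evaluate(model)
assert loss_a == loss_b              # same weights + eval mode => same loss
```

If back-to-back evaluations of the *same* checkpoint ever disagree in a setup like this, the weights (or the data pipeline) are not actually identical between the two runs.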

Here's how to run the script

python -m train_reformer --do_train --train_data_file ./data.json --do_eval --eval_data_file ./data.json --tokenizer_path ./data/tokenizer/60k --num_train_epochs 80 --output_dir ./output/lm --per_gpu_train_batch_size 2 --gradient_accumulation_steps 1 --overwrite_output_dir --logging_steps 10 --save_steps 20 --learning_rate 1e-2 --evaluate_during_training --per_gpu_eval_batch_size 1 --eval_all_checkpoints

To run eval only, remove the --do_train flag. To run a single eval instead of all checkpoints, also remove --eval_all_checkpoints.

Single eval

04/11/2020 14:43:21 - INFO - __main__ -   ***** Eval results  *****
04/11/2020 14:43:21 - INFO - __main__ -     loss = 3.2982592284679413
04/11/2020 14:43:21 - INFO - __main__ -     perplexity = 27.06548309326172

Multiple eval, where all checkpoints are evaluated. The loss for the last checkpoint should match the single-eval run above, but it doesn't:

04/11/2020 15:07:22 - INFO - __main__ -   Evaluate the following checkpoints: ['./output/lm/checkpoint-40', './output/lm/checkpoint-60', './output/lm/checkpoint-80', './output/lm']
04/11/2020 15:07:22 - INFO - __main__ -   Evaluate the following checkpoint: ./output/lm/checkpoint-40
04/11/2020 15:07:23 - INFO - __main__ -   ***** Running evaluation checkpoint-40 *****
04/11/2020 15:07:23 - INFO - __main__ -     Num examples = 2
04/11/2020 15:07:23 - INFO - __main__ -     Batch size = 1
Evaluating: 100%|██████████| 2/2 [00:02<00:00,  1.02s/it]
04/11/2020 15:07:25 - INFO - __main__ -   ***** Eval results checkpoint-40 *****
04/11/2020 15:07:25 - INFO - __main__ -     loss = 6.260987758636475
04/11/2020 15:07:25 - INFO - __main__ -     perplexity = 523.7360229492188
04/11/2020 15:07:25 - INFO - __main__ -   Evaluate the following checkpoint: ./output/lm/checkpoint-60
04/11/2020 15:07:26 - INFO - __main__ -   ***** Running evaluation checkpoint-60 *****
04/11/2020 15:07:26 - INFO - __main__ -     Num examples = 2
04/11/2020 15:07:26 - INFO - __main__ -     Batch size = 1
Evaluating: 100%|██████████| 2/2 [00:01<00:00,  1.01it/s]
04/11/2020 15:07:28 - INFO - __main__ -   ***** Eval results checkpoint-60 *****
04/11/2020 15:07:28 - INFO - __main__ -     loss = 2.44088876247406
04/11/2020 15:07:28 - INFO - __main__ -     perplexity = 11.483243942260742
04/11/2020 15:07:28 - INFO - __main__ -   Evaluate the following checkpoint: ./output/lm/checkpoint-80
04/11/2020 15:07:29 - INFO - __main__ -   ***** Running evaluation checkpoint-80 *****
04/11/2020 15:07:29 - INFO - __main__ -     Num examples = 2
04/11/2020 15:07:29 - INFO - __main__ -     Batch size = 1
Evaluating: 100%|██████████| 2/2 [00:02<00:00,  1.05s/it]
04/11/2020 15:07:31 - INFO - __main__ -   ***** Eval results checkpoint-80 *****
04/11/2020 15:07:31 - INFO - __main__ -     loss = 1.8967646956443787
04/11/2020 15:07:31 - INFO - __main__ -     perplexity = 6.664299011230469
04/11/2020 15:07:31 - INFO - __main__ -   Evaluate the following checkpoint: ./output/lm
04/11/2020 15:07:33 - INFO - __main__ -   ***** Running evaluation  *****
04/11/2020 15:07:33 - INFO - __main__ -     Num examples = 2
04/11/2020 15:07:33 - INFO - __main__ -     Batch size = 1
Evaluating: 100%|██████████| 2/2 [00:02<00:00,  1.01s/it]
04/11/2020 15:07:35 - INFO - __main__ -   ***** Eval results  *****
04/11/2020 15:07:35 - INFO - __main__ -     loss = 1.1265813708305359
04/11/2020 15:07:35 - INFO - __main__ -     perplexity = 3.0850918292999268

Here's my toy dataset of two records

{"tokens":[["Wednesday",",","August 21, 2019","11:27 AM"],["Clinical","Certification"]]}
{"tokens":[["Reason","For","Exam","low","bakc","pain","with","radicular","symptoms","(","worsening",")",";","Radiculopathy",",",">","6","wks","conservative","tx",",","persistent","sx","Order","Information"]]}

source gist

lucidrains commented 4 years ago

@singhay I think this is a problem with the script, not the library. You should log to the console the path of each checkpoint as it is loaded, and reconcile that with the single-eval run.
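The debugging step suggested here could look like the sketch below: log the checkpoint path just before loading its weights, so the multi-checkpoint loop can be lined up against the single-eval run. The function names, the `model.pt` filename, and the demo at the bottom are all assumptions for illustration, not the actual script's code:

```python
import logging
import os
import tempfile

import torch
import torch.nn as nn

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def run_eval(model, data, target):
    # deterministic eval: fixed weights + eval mode + no_grad
    model.eval()
    with torch.no_grad():
        return nn.functional.mse_loss(model(data), target).item()

def evaluate_checkpoints(model, checkpoint_paths, data, target):
    results = {}
    for path in checkpoint_paths:
        # log the exact file being loaded so runs can be reconciled
        logger.info("Loading checkpoint weights from %s", path)
        model.load_state_dict(torch.load(path))
        results[path] = run_eval(model, data, target)
    return results

# Demo: save two checkpoints with different weights, then evaluate both.
torch.manual_seed(0)
model = nn.Linear(4, 2)                      # stand-in for the real model
data, target = torch.randn(8, 4), torch.randn(8, 2)

tmp = tempfile.mkdtemp()
paths = []
for i in range(2):
    p = os.path.join(tmp, f"checkpoint-{i}", "model.pt")  # assumed layout
    os.makedirs(os.path.dirname(p))
    torch.save(model.state_dict(), p)
    paths.append(p)
    with torch.no_grad():
        model.weight.add_(0.1)               # perturb so checkpoints differ

results = evaluate_checkpoints(model, paths, data, target)
```

Since each pass reloads weights from disk, repeating `evaluate_checkpoints` over the same paths must reproduce the same losses; if the logged path for the final entry differs from the one used in the single-eval run, that would explain the mismatch.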