alexa / bort

Repository for the paper "Optimal Subarchitecture Extraction for BERT"
Apache License 2.0

Accuracy during fine-tuning is very low (only 0.68) #2

Closed: waugustus closed this issue 3 years ago

waugustus commented 3 years ago

Hi, I tried to fine-tune the model on MRPC with the run_finetune.sh script, but the accuracy is very low (0.68 on the dev set).

Here is the log:

INFO:root:18:19:05 Namespace(accumulate=None, batch_size=8, dataset='openwebtext_ccnews_stories_books_cased', dev_batch_size=8, dropout=0.1, dtype='float32', early_stop=200, epochs=1, epsilon=1e-06, gpu=1, init='uniform', log_interval=10, lr=5e-05, max_len=512, model_parameters=None, momentum=0.9, multirc_test_location='/home/ec2-user/.mxnet/datasets/superglue_multirc/test.jsonl', no_distributed=False, only_inference=False, output_dir='./output_dir', pretrained_parameters='model/bort.params', prob=0.5, race_dataset_location=None, ramp_up_epochs=1, record_dev_location='/home/ec2-user/.mxnet/datasets/superglue_record/val.jsonl', record_test_location='/home/ec2-user/.mxnet/datasets/superglue_record/test.jsonl', seed=2, task_name='MRPC', training_steps=None, use_scheduler=True, warmup_ratio=0.45, weight_decay=110.0)
INFO:root:18:19:05 get_bort_model: bort_4_8_768_1024
INFO:root:18:19:08 loading Bort params from model/bort.params
INFO:root:18:19:08 Processing dataset...
INFO:root:18:19:11 Now we are doing Bort classification training on gpu(0)!
INFO:root:18:19:11 training steps=458
INFO:root:18:19:12 [Epoch 1 Batch 10/465] loss=3.2699, lr=0.0000021845, metrics:f1:0.5835,accuracy:0.4800
INFO:root:18:19:12 [Epoch 1 Batch 20/465] loss=3.5952, lr=0.0000046117, metrics:f1:0.5684,accuracy:0.4774
INFO:root:18:19:13 [Epoch 1 Batch 30/465] loss=3.3002, lr=0.0000070388, metrics:f1:0.6084,accuracy:0.5149
INFO:root:18:19:13 [Epoch 1 Batch 40/465] loss=2.6364, lr=0.0000094660, metrics:f1:0.6187,accuracy:0.5302
INFO:root:18:19:13 [Epoch 1 Batch 50/465] loss=2.9594, lr=0.0000118932, metrics:f1:0.6443,accuracy:0.5494
INFO:root:18:19:14 [Epoch 1 Batch 60/465] loss=2.4208, lr=0.0000143204, metrics:f1:0.6581,accuracy:0.5621
INFO:root:18:19:14 [Epoch 1 Batch 70/465] loss=3.3903, lr=0.0000167476, metrics:f1:0.6549,accuracy:0.5586
INFO:root:18:19:15 [Epoch 1 Batch 80/465] loss=2.5813, lr=0.0000191748, metrics:f1:0.6504,accuracy:0.5606
INFO:root:18:19:15 [Epoch 1 Batch 90/465] loss=2.2408, lr=0.0000216019, metrics:f1:0.6447,accuracy:0.5610
INFO:root:18:19:16 [Epoch 1 Batch 100/465] loss=3.1120, lr=0.0000240291, metrics:f1:0.6551,accuracy:0.5675
INFO:root:18:19:16 [Epoch 1 Batch 110/465] loss=2.4501, lr=0.0000264563, metrics:f1:0.6541,accuracy:0.5647
INFO:root:18:19:16 [Epoch 1 Batch 120/465] loss=2.6082, lr=0.0000288835, metrics:f1:0.6571,accuracy:0.5645
INFO:root:18:19:17 [Epoch 1 Batch 130/465] loss=2.4734, lr=0.0000313107, metrics:f1:0.6672,accuracy:0.5741
INFO:root:18:19:17 [Epoch 1 Batch 140/465] loss=2.2288, lr=0.0000337379, metrics:f1:0.6645,accuracy:0.5740
INFO:root:18:19:18 [Epoch 1 Batch 150/465] loss=1.6799, lr=0.0000361650, metrics:f1:0.6641,accuracy:0.5758
INFO:root:18:19:18 [Epoch 1 Batch 160/465] loss=1.1061, lr=0.0000385922, metrics:f1:0.6703,accuracy:0.5810
INFO:root:18:19:18 [Epoch 1 Batch 170/465] loss=1.4413, lr=0.0000410194, metrics:f1:0.6712,accuracy:0.5835
INFO:root:18:19:19 [Epoch 1 Batch 180/465] loss=1.2923, lr=0.0000434466, metrics:f1:0.6684,accuracy:0.5810
INFO:root:18:19:19 [Epoch 1 Batch 190/465] loss=1.9684, lr=0.0000458738, metrics:f1:0.6627,accuracy:0.5780
INFO:root:18:19:20 [Epoch 1 Batch 200/465] loss=1.6337, lr=0.0000483010, metrics:f1:0.6620,accuracy:0.5772
INFO:root:18:19:20 [Epoch 1 Batch 210/465] loss=1.9206, lr=0.0000494048, metrics:f1:0.6632,accuracy:0.5771
INFO:root:18:19:21 [Epoch 1 Batch 220/465] loss=1.5550, lr=0.0000474206, metrics:f1:0.6655,accuracy:0.5782
INFO:root:18:19:21 [Epoch 1 Batch 230/465] loss=1.5174, lr=0.0000454365, metrics:f1:0.6647,accuracy:0.5760
INFO:root:18:19:21 [Epoch 1 Batch 240/465] loss=1.6342, lr=0.0000434524, metrics:f1:0.6563,accuracy:0.5698
INFO:root:18:19:22 [Epoch 1 Batch 250/465] loss=1.6304, lr=0.0000414683, metrics:f1:0.6521,accuracy:0.5669
INFO:root:18:19:22 [Epoch 1 Batch 260/465] loss=1.5732, lr=0.0000394841, metrics:f1:0.6523,accuracy:0.5681
INFO:root:18:19:23 [Epoch 1 Batch 270/465] loss=0.9988, lr=0.0000375000, metrics:f1:0.6479,accuracy:0.5661
INFO:root:18:19:23 [Epoch 1 Batch 280/465] loss=1.8495, lr=0.0000355159, metrics:f1:0.6485,accuracy:0.5673
INFO:root:18:19:23 [Epoch 1 Batch 290/465] loss=1.0105, lr=0.0000335317, metrics:f1:0.6523,accuracy:0.5702
INFO:root:18:19:24 [Epoch 1 Batch 300/465] loss=0.8022, lr=0.0000315476, metrics:f1:0.6535,accuracy:0.5708
INFO:root:18:19:24 [Epoch 1 Batch 310/465] loss=0.8974, lr=0.0000295635, metrics:f1:0.6546,accuracy:0.5713
INFO:root:18:19:25 [Epoch 1 Batch 320/465] loss=0.9764, lr=0.0000275794, metrics:f1:0.6527,accuracy:0.5698
INFO:root:18:19:25 [Epoch 1 Batch 330/465] loss=0.8853, lr=0.0000255952, metrics:f1:0.6521,accuracy:0.5692
INFO:root:18:19:25 [Epoch 1 Batch 340/465] loss=0.9318, lr=0.0000236111, metrics:f1:0.6521,accuracy:0.5687
INFO:root:18:19:26 [Epoch 1 Batch 350/465] loss=0.9023, lr=0.0000216270, metrics:f1:0.6548,accuracy:0.5702
INFO:root:18:19:26 [Epoch 1 Batch 360/465] loss=0.8698, lr=0.0000196429, metrics:f1:0.6545,accuracy:0.5697
INFO:root:18:19:27 [Epoch 1 Batch 370/465] loss=0.9013, lr=0.0000176587, metrics:f1:0.6552,accuracy:0.5698
INFO:root:18:19:27 [Epoch 1 Batch 380/465] loss=0.8277, lr=0.0000156746, metrics:f1:0.6550,accuracy:0.5698
INFO:root:18:19:28 [Epoch 1 Batch 390/465] loss=0.7523, lr=0.0000136905, metrics:f1:0.6591,accuracy:0.5732
INFO:root:18:19:28 [Epoch 1 Batch 400/465] loss=0.8378, lr=0.0000117063, metrics:f1:0.6621,accuracy:0.5759
INFO:root:18:19:28 [Epoch 1 Batch 410/465] loss=0.8365, lr=0.0000097222, metrics:f1:0.6633,accuracy:0.5783
INFO:root:18:19:29 [Epoch 1 Batch 420/465] loss=0.8266, lr=0.0000077381, metrics:f1:0.6610,accuracy:0.5779
INFO:root:18:19:29 [Epoch 1 Batch 430/465] loss=0.7012, lr=0.0000057540, metrics:f1:0.6627,accuracy:0.5794
INFO:root:18:19:30 [Epoch 1 Batch 440/465] loss=0.8099, lr=0.0000037698, metrics:f1:0.6640,accuracy:0.5798
INFO:root:18:19:30 [Epoch 1 Batch 450/465] loss=0.8413, lr=0.0000017857, metrics:f1:0.6645,accuracy:0.5798
INFO:root:18:19:30 [Epoch 1 Batch 460/465] loss=0.8424, lr=0.0000001000, metrics:f1:0.6651,accuracy:0.5799
INFO:root:18:19:31 Now we are doing evaluation on dev with gpu(0).
INFO:root:18:19:31 [Batch 10/51] loss=0.6030, metrics:f1:0.8086,accuracy:0.7000
INFO:root:18:19:31 [Batch 20/51] loss=0.6176, metrics:f1:0.8028,accuracy:0.6875
INFO:root:18:19:31 [Batch 30/51] loss=0.5927, metrics:f1:0.8101,accuracy:0.6958
INFO:root:18:19:31 [Batch 40/51] loss=0.6228, metrics:f1:0.8073,accuracy:0.6906
INFO:root:18:19:31 [Batch 50/51] loss=0.6260, metrics:f1:0.8025,accuracy:0.6850
INFO:root:18:19:31 epoch: 0; validation metrics:f1:0.8019,accuracy:0.6838
INFO:root:18:19:31 Time cost=0.61s, throughput=670.77 samples/s
INFO:root:18:19:32 params saved in: ./output_dir/model_bort_MRPC_0.params
INFO:root:18:19:32 Time cost=20.66s
INFO:root:18:19:32 Best model at epoch 0. Validation metrics:f1:0.8019,accuracy:0.6838,average:0.7429
INFO:root:18:19:32 Now we are doing testing on test with gpu(0).
INFO:root:18:19:34 Time cost=2.26s, throughput=764.32 samples/s
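The `lr` values in this log trace a linear warmup followed by a linear decay. The sketch below is not the repository's scheduler; it is a reconstruction that matches the logged numbers under these assumptions: warmup lasts `int(training_steps * warmup_ratio)` = 206 steps, the decay runs to zero over the remaining 252 steps, log batch N corresponds to 0-indexed step N - 1, and the rate is clamped to a 1e-7 floor (inferred only from the last logged value). The function name `linear_warmup_decay_lr` and the `min_lr` parameter are hypothetical.

```python
def linear_warmup_decay_lr(step, total_steps, base_lr, warmup_ratio, min_lr=1e-7):
    """Learning rate at a 0-indexed optimizer step (reconstruction, not Bort's code).

    Linear warmup from 0 to base_lr over int(total_steps * warmup_ratio) steps,
    then linear decay toward 0 over the remaining steps. The min_lr floor is an
    assumption inferred from the final lr printed in the log.
    """
    warmup_steps = int(total_steps * warmup_ratio)
    if step < warmup_steps:
        lr = base_lr * step / warmup_steps
    else:
        decay_steps = total_steps - warmup_steps
        lr = base_lr * (1.0 - (step - warmup_steps) / decay_steps)
    return max(lr, min_lr)


if __name__ == "__main__":
    # Values from the Namespace/log: lr=5e-05, warmup_ratio=0.45, training steps=458.
    for batch in (10, 200, 210, 450, 460):
        lr = linear_warmup_decay_lr(batch - 1, 458, 5e-5, 0.45)
        print(f"batch {batch}: lr={lr:.10f}")
```

Under these assumptions the printed rates agree with the logged ones for those batches. With `epochs=1` and `warmup_ratio=0.45`, the rate only peaks around batch 207 of 465, so nearly half of this run is spent warming up.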

adewynter commented 3 years ago

Hi!

Yes, that behavior is (sort of) expected. The downside of such a highly optimized model is that it is very small, so conventional fine-tuning doesn't work that well. Also, it looks like you only ran it for a single epoch (epochs=1). Bort normally needs anywhere between 50 and 100 epochs for "small" problems, and up to 1000 for large ones. I'd just ignore the large ones if you are only trying it out; MRPC is a great choice for this!
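Running many more epochs usually goes together with keeping the epoch that scores best on the dev set; the log above already reports a per-epoch params file and a "Best model at epoch" line. The loop below is only a framework-agnostic sketch of that pattern with an added early-stopping rule. `train_one_epoch`, `evaluate`, `patience`, and the dev scores in the usage example are hypothetical and are not part of the Bort code.

```python
from typing import Callable, List, Tuple


def fine_tune_with_best_checkpoint(
    train_one_epoch: Callable[[int], None],
    evaluate: Callable[[], float],
    max_epochs: int,
    patience: int,
) -> Tuple[int, float, List[float]]:
    """Train for up to max_epochs, tracking the epoch with the best dev score.

    Stops early once `patience` consecutive epochs fail to beat the best score.
    Returns (best_epoch, best_score, per-epoch dev scores).
    """
    best_epoch, best_score = -1, float("-inf")
    history: List[float] = []
    epochs_without_improvement = 0
    for epoch in range(max_epochs):
        train_one_epoch(epoch)
        score = evaluate()
        history.append(score)
        if score > best_score:
            best_epoch, best_score = epoch, score
            epochs_without_improvement = 0
            # A real script would save the model parameters for this epoch here.
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                break
    return best_epoch, best_score, history


if __name__ == "__main__":
    # Made-up dev accuracies, one per epoch, only to exercise the loop.
    fake_scores = iter([0.68, 0.72, 0.75, 0.74, 0.76, 0.75, 0.75, 0.74])
    result = fine_tune_with_best_checkpoint(
        train_one_epoch=lambda epoch: None,
        evaluate=lambda: next(fake_scores),
        max_epochs=8,
        patience=2,
    )
    print(result)
```

With `patience=2`, the run above stops after the seventh epoch and keeps epoch 4 as the best one. For Bort, `max_epochs` would be in the 50 to 100 range suggested here for small problems.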

Also, as we mention in the README.md file, it is highly recommended that you implement Agora, the fine-tuning algorithm. Unfortunately, we have no plans at the moment to release that part of the code, BUT it is described in one of the papers (the second one) linked in the README file.