ZhiGroup / Med-BERT

Med-BERT, contextualized embedding model for structured EHR data

Low accuracy of Mask LM problem #14

Closed zhuzitong closed 1 year ago

zhuzitong commented 1 year ago

Hi there,

I was attempting to implement Med-BERT using real data collected from our local hospital system, but encountered an issue: the accuracy of the Masked Language Modeling (MLM) task in the pre-training phase was very low. Have you and your team encountered this problem? Here are the details:

Here is the data information: vocabulary size 37214, data size 25000.

Here is the config file:

{
    "vocab_size": 37214, 
    "hidden_size": 192,
    "num_hidden_layers":6,
    "num_attention_heads": 6,
    "intermediate_size": 64,
    "hidden_act": "gelu",
    "hidden_dropout_prob": 0.1,
    "attention_probs_dropout_prob": 0.1,
    "max_position_embeddings": 512,
    "type_vocab_size": 1000,
    "initializer_range": 0.02
}
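
For scale, here is a rough baseline check (a sketch, not output from the repo's scripts): with a vocabulary of 37214 codes, uniform random guessing on the MLM task would score about 0.003%, so the reported ~3.2% is well above chance, yet still consistent with the model leaning on only a few very frequent codes, as the constant predictions further down suggest.

# Rough MLM-accuracy baselines for this vocabulary (a sketch; the 3% share
# of the most frequent code is an assumption -- measure it on your own
# tokenized data).
vocab_size = 37214

chance_accuracy = 1.0 / vocab_size
print(f"uniform-guess accuracy: {chance_accuracy:.6%}")   # ~0.0027%

most_frequent_code_share = 0.03   # hypothetical frequency of the top code
print(f"always-predict-top-code accuracy: {most_frequent_code_share:.1%}")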

Here are the pre-training run settings:

"max_seq_length", 128,
"max_predictions_per_seq", 4,
"train_batch_size", 32,
"eval_batch_size", 8,
"learning_rate", 5e-5,
"num_train_steps", 5000,
"num_warmup_steps", 1000,
"save_checkpoints_steps", 2000,
"iterations_per_loop", 1000,
"max_eval_steps", 1000,

Here are the eval results from the pre-training run on the training data:

I0116 20:09:54.144452 140620173866816 run_EHRpretraining_QA2Seq.py:519] ***** Eval results *****
INFO:tensorflow:    global_step = 5000
INFO:tensorflow:    loss = 7.54507
INFO:tensorflow:    masked_lm_accuracy = 0.032298923
INFO:tensorflow:    masked_lm_auc = 0.5747309
INFO:tensorflow:    masked_lm_loss = 7.5184646
INFO:tensorflow:    next_sentence_accuracy = 0.99
INFO:tensorflow:    next_sentence_auc = 0.9859998
INFO:tensorflow:    next_sentence_loss = 0.026682172

Here are the eval results from the pre-training run on the validation data:

I0126 10:53:58.949093 139623278356288 valid_run_EHRpretraining_QA2Seq.py:517] ***** Eval results *****
INFO:tensorflow:    global_step = 5000
INFO:tensorflow:    loss = 9.093228
INFO:tensorflow:    masked_lm_accuracy = 0.030574257
INFO:tensorflow:    masked_lm_auc = 0.5872079
INFO:tensorflow:    masked_lm_loss = 7.6137023
INFO:tensorflow:    next_sentence_accuracy = 0.7071875
INFO:tensorflow:    next_sentence_auc = 0.6559397
INFO:tensorflow:    next_sentence_loss = 1.4798392

Here are some values printed from the MLM task:

original masked_lm_log_probs[[-9.17470741 -4.94546 -8.45899...]...]
masked_lm_predictions[278 278 278...]
masked_lm_ids[[826 766 6220...]...]
masked_lm_ids[[1214 27 315...]...]
masked_lm_ids[[1214 27 315...]...]
original masked_lm_log_probs[[-9.18631744 -4.91880751 -8.44931221...]...]
masked_lm_predictions[278 278 278...]
masked_lm_ids[[1631 1379 51...]...]
masked_lm_ids[[1631 1379 51...]...]
original masked_lm_log_probs[[-9.2189 -4.90491104 -8.4063...]...]
masked_lm_predictions[278 278 278...]
masked_lm_ids[[780 2577 2196...]...]
original masked_lm_log_probs[[-9.2146368 -4.90652227 -8.40298271...]...]
masked_lm_predictions[278 278 278...]
masked_lm_ids[[780 2577 2196...]...]
masked_lm_ids[[11616 35223 125...]...]
masked_lm_ids[[11616 35223 125...]...]
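
One quick diagnostic worth running here (a sketch, not code from this repo): masked_lm_predictions appears to be the same id (278) at every shown position, which often indicates the MLM head has collapsed onto the most frequent code. Counting code frequencies in the tokenized pre-training data would confirm whether 278 is indeed the majority class:

from collections import Counter

def most_frequent_codes(sequences, top_k=5):
    # `sequences` is assumed to be an iterable of lists of integer code ids,
    # i.e. the same tokenized visit sequences fed to the pre-training script.
    counts = Counter(code for seq in sequences for code in seq)
    return counts.most_common(top_k)

# Hypothetical usage -- load_tokenized_visits is a placeholder for however
# you read back your own pre-training input:
# sequences = load_tokenized_visits("pretraining_data.pkl")
# print(most_frequent_codes(sequences))   # is id 278 at the top?
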
lrasmy commented 1 year ago

Hi Zhuzitong,

Yes, that is common. It correlates strongly with your vocabulary size and the size of the pre-training dataset.

The key measure for Med-BERT-like pretrained models is the performance boost they lead to on different downstream tasks.

So I'd monitor the pretraining, and once the MLM loss starts to plateau, frequently test checkpoints on downstream tasks.
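
(For illustration only, not code from the Med-BERT repo: one simple way to implement that plateau check is to compare the last few recorded masked_lm_loss values.)

def has_plateaued(mlm_losses, window=3, tol=0.01):
    # `mlm_losses` holds masked_lm_loss values from successive evaluation
    # points; declare a plateau when the last `window` evals improved the
    # loss by less than `tol` in total.
    if len(mlm_losses) < window + 1:
        return False
    return mlm_losses[-window - 1] - mlm_losses[-1] < tol

# e.g. has_plateaued([8.1, 7.6, 7.55, 7.545, 7.542, 7.541]) -> True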

Please let me know if you have any further questions.