ZhiGroup / Med-BERT

Med-BERT, contextualized embedding model for structured EHR data
Apache License 2.0

Possible mismatch between source code and the research paper on Prolonged LOS pre-training task #3

Closed · Maitreyapatel closed this issue 4 years ago

Maitreyapatel commented 4 years ago

Hi there,

I was reading the paper and found that the source code appears to contradict an important statement in it.

According to the Med-BERT section of the paper, the following point is made on page 7:

In BERT, the token [CLS] was used mainly to summarize the information from the two sentences; however, EHR sequences are usually much longer; e.g., a sequence may contain 10 or more visits, and simply using one token will inevitably lead to huge information loss. Therefore, for both the Prolonged LOS task and the downstream disease-prediction tasks where the information of the whole sequence is usually needed, we added a feed-forward layer (FFL) to average the outputs from all of the visits to represent a sequence, instead of using only a single token.
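For anyone reading along, here is a minimal sketch of my understanding of what that passage describes. This is not the authors' code; the function name and the assumption that the encoder output has shape `[batch_size, seq_length, hidden_size]` (as in the Google BERT code) are my own:

```python
import tensorflow as tf  # TF 1.x style, matching the BERT-based pretraining code

def build_sequence_representation(sequence_output, hidden_size):
    """Average the encoder outputs over all positions of the patient sequence,
    then pass the pooled vector through a feed-forward layer (FFL)."""
    # sequence_output: [batch_size, seq_length, hidden_size]
    avg_output = tf.reduce_mean(sequence_output, axis=1)  # [batch_size, hidden_size]
    pooled = tf.layers.dense(avg_output, hidden_size, activation=tf.tanh)  # the FFL
    return pooled
```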

However, the source code still uses the first token for predictions instead of using an FFL to average all of the outputs, even though the data preparation does not add a [CLS] token. As a result, the code at the location below uses the first token, which can be any arbitrary ICD-9/10 code depending on the input patient. [A comment at the same location also mentions this.]

https://github.com/ZhiGroup/Med-BERT/blob/ccf65acd175b50b1de10e6fdc65f34fe69e0ccbf/Pretraining%20Code/modeling.py#L227
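For comparison, the linked line appears to follow the standard Google BERT pooler pattern, which keeps only position 0 of the encoder output. The snippet below is only a sketch of that pattern, not code copied from this repository, and it assumes the same shapes as the sketch above:

```python
import tensorflow as tf  # TF 1.x style

def build_first_token_representation(sequence_output, hidden_size):
    """BERT-style pooling: keep only position 0 of the encoder output."""
    # Without a [CLS] token in the input, position 0 is simply whichever
    # diagnosis code happens to come first in the patient's sequence.
    first_token_tensor = tf.squeeze(sequence_output[:, 0:1, :], axis=1)  # [batch_size, hidden_size]
    return tf.layers.dense(first_token_tensor, hidden_size, activation=tf.tanh)
```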

Is this repository still under construction? Or is there another function that addresses this issue?

If not, then using a single first token (i.e., an ICD-9/10 code) might add noise to the system. Is there any other analysis associated with this that would make it easier to understand? The validation loss at each epoch on the Prolonged LOS task would probably also help.

Any answers would be helpful; thank you in advance!

lrasmy commented 4 years ago

Hi Maitreya,

Thanks for pinpointing this issue. Yes, there is another pretraining code version that makes sure we use the full sequence, and not just the first token, for the LOS pretraining task: the file run_EHRpretraining_QA2Seq.py. The initial version followed the original [CLS] pretraining code available in Google BERT, and therefore it gave more weight to the first diagnosis in the patient record, presumably making the pre-training classification task (LOS) more difficult. The second version is the more appropriate implementation, as it uses the whole patient sequence information. We compared the two options, and the performance of the pre-trained models on the downstream tasks was similar. We are clarifying this in the paper revision.

Thank you