bytedance / dplm

Official Implementation of DPLM (ICML'24) - Diffusion Language Models Are Versatile Protein Learners
https://bytedance.github.io/dplm/
Apache License 2.0

Result Mismatch when dealing with HumanPPI prediction #8

Closed hhhhh789 closed 1 month ago

hhhhh789 commented 2 months ago

Since the authors provide the best checkpoints for the representation learning tasks, I followed them and reached similar results for DPLM 650M. However, I found the HumanPPI accuracy to be much higher than reported in the paper (by roughly 5 to 10 percentage points):

DPLM paper: 80.98% for 150M, 86.41% for 650M

DPLM reproduced: 89.44% for 150M, 91.66% for 650M

Although higher results look better, I'm wondering whether a bug in the current code is causing this discrepancy. For the 650M result, I just ran the run_test script, loading the checkpoint from Hugging Face. For the 150M result, we fine-tuned the model as suggested in the README, starting from the DPLM 150M model and evaluating after 20 epochs.

In addition, I wanted to confirm that all HumanPPI results were obtained with the frozen-backbone option turned on. It's odd that none of the other tasks, such as Thermostability and MetalIonBinding, freeze the backbone. Is there a reason for this?

wxy-nlp commented 2 months ago

Hi @hhhhh789, thank you for reproducing our results. This very high result is indeed strange, and after debugging I found the reason. The "freeze_backbone" option in config/HumanPPI/dplm.yaml is set to true, so model/dplm/dplm_ppi_model.py gets the hidden state from the self.get_hidden_state method, https://github.com/bytedance/dplm/blob/e102c3893d74da09c1bf884cd20bca7d37612c43/model/dplm/dplm_ppi_model.py#L35-L37 which averages all hidden states of the input sequence into a single hidden state: https://github.com/bytedance/dplm/blob/e102c3893d74da09c1bf884cd20bca7d37612c43/model/dplm/base.py#L157-L159

However, when "freeze_backbone" is set to false, model/dplm/dplm_ppi_model.py just takes the hidden state of the <cls> token as the final hidden state. https://github.com/bytedance/dplm/blob/e102c3893d74da09c1bf884cd20bca7d37612c43/model/dplm/dplm_ppi_model.py#L35-L40 My hypothesis is that the average of all hidden states contains more information than the <cls> token alone, which is why the result is so high. When I set "freeze_backbone" to false, I get 87.22% for the DPLM 650M checkpoint provided on Hugging Face.
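To make the difference between the two code paths concrete, here is a minimal pure-Python sketch of the two pooling strategies. It is not the repository's code: the function names cls_pool and mean_pool, the toy values, and the assumptions that <cls> sits at position 0 and that padding is excluded from the average are all illustrative only.

```python
from typing import List

Vector = List[float]


def cls_pool(hidden_states: List[Vector]) -> Vector:
    """Return the hidden state at position 0 (assumed here to be <cls>)."""
    return list(hidden_states[0])


def mean_pool(hidden_states: List[Vector], mask: List[int]) -> Vector:
    """Average the hidden states of every position where mask == 1."""
    kept = [h for h, m in zip(hidden_states, mask) if m]
    if not kept:
        raise ValueError("mask selects no positions")
    dim = len(kept[0])
    return [sum(h[d] for h in kept) / len(kept) for d in range(dim)]


# Toy sequence (hypothetical values): <cls>, three residues, one padding
# position, with a hidden size of 2.
hidden = [[1.0, 0.0], [2.0, 2.0], [4.0, 0.0], [1.0, 2.0], [9.0, 9.0]]
mask = [1, 1, 1, 1, 0]  # the last position is padding and is ignored

cls_repr = cls_pool(hidden)          # uses only the first position
mean_repr = mean_pool(hidden, mask)  # uses all unmasked positions

print(cls_repr)   # [1.0, 0.0]
print(mean_repr)  # [2.0, 1.0]
```

The two representations differ for the same hidden states, so a classification head can score differently depending on which one it receives at inference time.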

The checkpoints on Hugging Face and the DPLM 650M results in our paper were obtained with freeze_backbone set to false, taking the representation of the <cls> token only. More recently I discovered that the HumanPPI task suffers from overfitting, so I set freeze_backbone to true to mitigate this. Interestingly, I had never noticed before that a checkpoint trained on the <cls> representation alone can achieve higher results when the average of all representations is used during inference. I am not sure whether this phenomenon is common.