VITA-Group / HotProtein

[ICLR 2023] "HotProtein: A Novel Framework for Protein Thermostability Prediction and Editing" by Tianlong Chen*, Chengyue Gong*, Daniel Jesus Diaz, Xuxi Chen, Jordan Tyler Wells, Qiang Liu, Zhangyang Wang, Andrew Ellington, Alex Dimakis, Adam Klivans
MIT License

Upgrade hotprotein framework from ESM1 to ESM2, and train it with deepspeed #2

Open YaoQ opened 9 months ago

YaoQ commented 9 months ago

Thank you for your great work; I am very interested in it.

I have managed to upgrade the HotProtein framework from ESM1 to ESM2 (this is the repo). I fine-tune the ESM2 models with d1_1_classification.pkl and the HP-S2C2 data, then test the mean accuracy over the 10 folds of the HP-S2C2 dataset.

I also fine-tune the ESM2 models and test them on the HP-S dataset.

Then I have several questions:

  1. Why do you train the ESM-1b backbone and the linear head with different losses, optimizers, and schedulers? I checked the attributes of the ESM-1b module: requires_grad is False for each layer's weights, which means that when we fine-tune for classification we only train the linear layer (see the parameter-inspection sketch after this list).

  2. I tried to train HotProtein with DeepSpeed; it now supports distributed training on multiple GPUs in FP16, but the accuracy is worse. I have only tested esm2_t6_8M on the HP-S2C2 dataset: the best accuracy for d1_1_classification with DeepSpeed is only 84.26%, whereas fine-tuning esm2_t6_8M for classification without DeepSpeed reaches 90.36% (see the config sketch at the end of this comment).
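
For reference, this is a minimal sketch of the kind of check I mean for question 1; the model here is just a stand-in for the loaded ESM backbone plus the linear classification head, not the actual fine-tuning script.

```python
# Minimal sketch: report which parameters of a model are trainable.
# The stand-in model below mimics a frozen backbone plus a trainable head.
import torch
import torch.nn as nn

def summarize_trainable(model: nn.Module) -> None:
    trainable, frozen = 0, 0
    for name, param in model.named_parameters():
        if param.requires_grad:
            trainable += param.numel()
            print(f"trainable: {name} {tuple(param.shape)}")
        else:
            frozen += param.numel()
    print(f"{trainable:,} trainable / {frozen:,} frozen parameters")

# Example with a stand-in model: a frozen "backbone" and a trainable head.
backbone = nn.Linear(1280, 1280)
for p in backbone.parameters():
    p.requires_grad = False
head = nn.Linear(1280, 2)
summarize_trainable(nn.Sequential(backbone, head))
```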

If you have any suggestion, let me know. Thank you very much.
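
For context, this is roughly the shape of the DeepSpeed FP16 setup I mean. It is only a sketch: the batch size, learning rate, and loss-scaling values are placeholders rather than the exact settings I used, and the model is a stand-in for the ESM2 backbone plus linear head.

```python
# Sketch of a DeepSpeed FP16 setup (launch with `deepspeed train.py`).
# All numeric values below are placeholders, not the settings I used.
import deepspeed
import torch.nn as nn

model = nn.Linear(1280, 2)  # stand-in for the ESM2 backbone + linear head

ds_config = {
    "train_micro_batch_size_per_gpu": 8,
    "gradient_accumulation_steps": 1,
    "optimizer": {"type": "AdamW", "params": {"lr": 1e-5}},
    "fp16": {
        "enabled": True,
        "loss_scale": 0,            # 0 enables dynamic loss scaling
        "initial_scale_power": 16,
    },
}

# Only parameters with requires_grad=True are handed to DeepSpeed.
model_engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    model_parameters=[p for p in model.parameters() if p.requires_grad],
    config=ds_config,
)
```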

imSeaton commented 9 months ago

Hi, thank you for your work integrating ESM2. I have some questions:

  1. Is the key component SAP in the paper only provided as pretrained model parameters, such that it is difficult to integrate into ESM2?

  2. Why is the accuracy on the bigger dataset HP-S lower than on the small dataset HP-S2C2? Does that mean there is more of an overfitting problem on the small dataset, even though the sequences share less than 50% identity according to the authors?

I would appreciate your reply.

YaoQ commented 9 months ago
  1. SAP: According to the HotProtein paper, this is how the SAP model is trained: "We use AlphaFoldDB (Jumper et al., 2021) for SAP protein pre-training. We filter the data with sequence length and data quality and finally get 270K data. ESM-1B (Rives et al., 2019) and ESM-IF (Hsu et al., 2022) backbones are used to process the sequence and 3D-coordinate inputs, and an average pooling layer is applied to the final-layer token representations of the ESM models to get protein embeddings. A momentum encoder with τ = 1.0, momentum encoder coefficient α = 0.9999, and a memory bank of size 65,536 is used, and the model is trained for 4 epochs with the AdamW optimizer, weight decay 10⁻¹², batch size 512, and an initial learning rate of 10⁻⁶ decayed with a OneCycle (Smith & Topin, 2019) scheduler."

So SAP is only compatible with ESM-1B; when I train the classification model, I just use the ESM2 pretrained backbone without SAP.
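
For anyone who wants to reproduce that recipe, here is a generic sketch of the momentum-encoder update plus the AdamW/OneCycle setup described in the quote. The encoder module is a placeholder, not the authors' SAP code, and the memory bank and contrastive loss are only noted in comments.

```python
# Generic sketch of the SAP pre-training recipe quoted above:
# a MoCo-style momentum encoder (EMA update), AdamW, and OneCycleLR.
# The encoder below is a placeholder, not the ESM backbone itself.
import copy
import torch
import torch.nn as nn

def build_momentum_encoder(encoder: nn.Module) -> nn.Module:
    momentum_encoder = copy.deepcopy(encoder)
    for p in momentum_encoder.parameters():
        p.requires_grad = False      # updated only by EMA, never by the optimizer
    return momentum_encoder

@torch.no_grad()
def ema_update(encoder: nn.Module, momentum_encoder: nn.Module, alpha: float = 0.9999):
    # momentum-encoder coefficient alpha = 0.9999, as in the quoted recipe
    for p, p_m in zip(encoder.parameters(), momentum_encoder.parameters()):
        p_m.mul_(alpha).add_(p, alpha=1.0 - alpha)

encoder = nn.Linear(1280, 256)       # stand-in for ESM backbone + average pooling
momentum_encoder = build_momentum_encoder(encoder)

# ~270K sequences, batch size 512, 4 epochs, lr 1e-6, weight decay 1e-12, OneCycle.
steps_per_epoch, epochs = 270_000 // 512, 4
optimizer = torch.optim.AdamW(encoder.parameters(), lr=1e-6, weight_decay=1e-12)
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer, max_lr=1e-6, epochs=epochs, steps_per_epoch=steps_per_epoch)
# A 65,536-entry memory bank (FIFO queue of momentum embeddings) and the
# contrastive loss with temperature tau = 1.0 are omitted in this sketch.
```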

  2. HP-S — Total: 182,305; Train: 145,844

    • label 0: 5,120
    • label 1: 28,002
    • label 2: 24,286
    • label 3: 63,126
    • label 4: 25,310

Test: 36,461

HP-S2C5 — Total: 1,040; Train: 936; Test: 104

As you can see, HP-S is much bigger than the HP-S2C5 dataset. According to my training results, using a larger ESM2 model can extract more sequence information and thereby improve accuracy.
You are right: when I train esm2_t33_650M on HP-S2C5 it reaches 93.75%, but when I test that model on HP-S the accuracy is only 38.92%.

imSeaton commented 9 months ago

Hi, I'm also in Shenzhen. I took a quick look at your repo yesterday and found that you only copied the ESM code over without capturing the key point of this paper. The ESM model parameters are indeed frozen, but a LoRA layer used for fine-tuning is added to the q and v projections of every attention layer (these layers are learnable; Hu E J, Shen Y, Wallis P, et al. LoRA: Low-rank adaptation of large language models. arXiv preprint arXiv:2106.09685, 2021), and some sparse layers are also learnable. I suggest you look at esm/module/Transformers and esm/sparse_multihead_attention/SparseMultiheadAttention in this repository. Pay attention to the rank parameter and the use_sparse parameter; these are all related to fine-tuning. Good luck!
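
To make that concrete, here is a minimal, generic LoRA sketch: a frozen linear projection with a trainable low-rank update, in the spirit of the q/v adapters described above. It is illustrative only, not the repo's actual Transformers/SparseMultiheadAttention code, and the rank and alpha values are arbitrary.

```python
# Minimal, generic LoRA sketch: frozen pretrained projection + trainable
# low-rank update (Hu et al., 2021). Illustrative only; not the repo's
# SparseMultiheadAttention implementation. rank/alpha are arbitrary.
import math
import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    def __init__(self, base: nn.Linear, rank: int = 8, alpha: float = 16.0):
        super().__init__()
        self.base = base
        for p in self.base.parameters():
            p.requires_grad = False                  # pretrained weights stay frozen
        self.lora_a = nn.Parameter(torch.zeros(rank, base.in_features))
        self.lora_b = nn.Parameter(torch.zeros(base.out_features, rank))
        nn.init.kaiming_uniform_(self.lora_a, a=math.sqrt(5))   # B starts at zero
        self.scaling = alpha / rank

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # frozen projection + scaled low-rank trainable update
        return self.base(x) + (x @ self.lora_a.T @ self.lora_b.T) * self.scaling

# Example: wrap a stand-in q-projection; only the LoRA factors are trainable.
q_proj = LoRALinear(nn.Linear(1280, 1280), rank=4)
print(sum(p.numel() for p in q_proj.parameters() if p.requires_grad))
```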