sokrypton / ColabDesign

Making Protein Design accessible to all via Google Colab!

converting other MPNN checkpoints for support in ColabDesign and JAX #165

Open adrienchaton opened 10 months ago

adrienchaton commented 10 months ago

Hi everyone and thanks for the great work you are sharing!

It is awesome that ProteinMPNN is integrated within ColabDesign. I assume you converted the original PyTorch checkpoints to JAX ... or did you retrain the models?

If you converted the checkpoints, could you please share the conversion method, so that I can also convert other ProteinMPNN models that were fine-tuned from the original checkpoint using the original PyTorch code?

I have this one in mind: https://zenodo.org/records/8164693. It would be awesome to be able to convert the abmpnn.pt checkpoint to a .pkl compatible with ColabDesign ...

Any help would be much appreciated, thanks!

adrienchaton commented 9 months ago

@sokrypton trying my luck ... any chance one of you has a conversion script for ProteinMPNN checkpoints, from the official PyTorch repo to JAX, for use in your (great) ColabDesign repo? Thanks!

sokrypton commented 9 months ago

@adrienchaton I apologize for the delay. Was travelling. Here is the script we used to convert weights: https://github.com/sokrypton/ColabDesign/tree/main/mpnn/convert_weights

adrienchaton commented 9 months ago

@sokrypton thanks a lot, I will run that and upload the weights.

I didn't see this script when I was digging through https://github.com/sokrypton/ColabDesign/tree/main/colabdesign/mpnn

In the meantime I had figured out that it is a matter of matching dict keys (with nested w, b) and converting the parameters to numpy arrays. Still, you are saving me some work to get it right.

Much appreciated and best end of year wishes to you.
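
For readers following along, the idea described above (matching dict keys, nesting `w`/`b`, converting to numpy arrays) can be sketched roughly as follows. This is only an illustration under stated assumptions, not the script linked in this thread: the function names, the `weight`/`bias` to `w`/`b` renaming, and the example key `enc.W1` are hypothetical, and the real target key names depend on how the JAX model names its modules. The transpose reflects that `torch.nn.Linear` stores weights as (out, in) while Haiku-style linear layers expect (in, out); embedding tables are also 2-D and would need to be listed in `skip_transpose`.

```python
import numpy as np

# Hypothetical leaf renaming: PyTorch "weight"/"bias" -> nested "w"/"b".
RENAME = {"weight": "w", "bias": "b"}


def to_numpy(x):
    """Return x as a NumPy array; torch tensors are detected by duck typing."""
    if hasattr(x, "detach"):  # looks like a torch.Tensor
        x = x.detach().cpu().numpy()
    return np.asarray(x)


def convert_state_dict(state_dict, transpose_linear=True, skip_transpose=()):
    """Turn a flat {"module.weight": tensor} mapping into {"module": {"w": array}}.

    skip_transpose lists module names whose 2-D weights must keep their layout
    (for example embedding tables).
    """
    nested = {}
    for name, value in state_dict.items():
        # Split "enc.W1.weight" into module "enc.W1" and leaf "weight".
        module, _, leaf = name.rpartition(".")
        arr = to_numpy(value)
        if (transpose_linear and leaf == "weight" and arr.ndim == 2
                and module not in skip_transpose):
            arr = arr.T  # (out, in) -> (in, out)
        nested.setdefault(module, {})[RENAME.get(leaf, leaf)] = arr
    return nested


if __name__ == "__main__":
    # NumPy arrays stand in for tensors loaded from a .pt checkpoint.
    fake = {"enc.W1.weight": np.zeros((2, 3)), "enc.W1.bias": np.zeros(2)}
    params = convert_state_dict(fake)
    print({k: {n: a.shape for n, a in v.items()} for k, v in params.items()})
```

After a conversion like this, it is worth comparing the outputs of the original and converted models on the same input before trusting the new checkpoint.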

sokrypton commented 9 months ago

I just uploaded the script 🤪 Tell me if you run into any issues.

adrienchaton commented 9 months ago

@sokrypton thanks for clarifying. I didn't check the commit history, but I was surprised that I hadn't seen the script before asking. It worked flawlessly and the checkpoint is running correctly. I didn't run assertions against the original pt checkpoint and code, but sequence recovery on IF against Ab backbones is high (mostly >80% at low temperature), so it should be correct. Would you like to add this checkpoint to the repo?

AFAIK, there is only one other IF model for Abs, fine-tuned from the ESM one by the same lab (oxpig). AntiFold feels maybe more interesting (possibly leaning a bit less towards the germline), but the MPNN and ColabDesign workflows are broader and much more interesting as protein design tools. Bottom line, having both is great for me, and I would imagine it's a relevant addition to the repo.
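
As a point of reference, the sequence recovery figure quoted above is usually computed as the fraction of positions where the designed residue matches the native one. A minimal sketch, assuming two already aligned sequences of equal length; the function name and the example sequences are hypothetical and not taken from the checkpoint or the repo:

```python
def sequence_recovery(native, designed):
    """Fraction of positions where the designed residue equals the native one."""
    if len(native) != len(designed):
        raise ValueError("sequences must have the same length")
    if not native:
        raise ValueError("sequences must not be empty")
    matches = sum(a == b for a, b in zip(native, designed))
    return matches / len(native)


if __name__ == "__main__":
    # Toy sequences, for illustration only.
    print(sequence_recovery("ACDE", "ACDF"))
```

High recovery is a useful sanity check, but it does not replace a direct numerical comparison against the original PyTorch model.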

sokrypton commented 9 months ago

Awesome! Tell me if you wanna contribute the converted weights, would be happy to add them to the repo :D

adrienchaton commented 9 months ago

Hi @sokrypton sorry for the delayed answer, sounds good!

FYI the fine-tuned pt weights are shared here https://zenodo.org/records/8164693 under the Creative Commons Attribution 4.0 license. The original MPNN code and models are MIT-licensed. Both are quite open licenses, but we probably still want to mention them.

I have already used the converted checkpoint quite a bit and the outputs seem correct. How would you like me to contribute it?

Let me know and I will prepare a pull-request. Cheers

adrienchaton commented 9 months ago

In case that's helpful to anyone, for now the converted checkpoint is attached here! Feel free to bring it into the repo however you wish, or let me know if you need anything from my end. Thanks again for the conversion script!

abmpnn.pkl.zip