zrqiao / NeuralPLexer

NeuralPLexer: State-specific protein-ligand complex structure prediction with a multi-scale deep generative model
https://doi.org/10.1038/s42256-024-00792-z
BSD 3-Clause Clear License

torch.cuda.OutOfMemoryError even for a small protein #32

Open venkata-ramaswamy-cresset opened 2 months ago

venkata-ramaswamy-cresset commented 2 months ago

Description

I am trying to generate a protein-ligand complex for a small protein (536 residues) and got torch.cuda.OutOfMemoryError when running with --cuda. To check whether it works on CPUs only, I ran the same command without --cuda, and it worked fine. I also tried reducing the chunk size, the number of samples and the number of steps to lower memory consumption (changing --n-samples 16 --chunk-size 4 --num-steps=40 to --n-samples 1 --chunk-size 1 --num-steps=1). With those settings I no longer get torch.cuda.OutOfMemoryError, and the job runs successfully with --cuda.

Could you please advise what I could do to fix this issue?

As another example, I also tried building a protein-ligand complex for a bigger protein (~3000 residues), and NeuralPLexer crashed with memory errors both with and without --cuda.

In case it helps, my workstation has an NVIDIA GeForce GTX 1660 Ti GPU (6144 MB), 24 CPU cores and 32 GB of RAM.
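A rough back-of-envelope helps explain why the ~3000-residue case is so much harder than the 536-residue one. This is a sketch under stated assumptions, not a measurement of NeuralPLexer. It assumes that per-residue features grow linearly, pair representations grow with N², and attention over a pair grid grows with N³ in the number of residues N. It also ignores ligand atoms and all constant factors. `relative_cost` is a hypothetical helper.

```python
def relative_cost(n_small: int, n_large: int) -> dict[str, float]:
    """Memory of the larger input relative to the smaller one under
    simple linear, quadratic and cubic scaling assumptions."""
    ratio = n_large / n_small
    return {
        "linear": ratio,          # e.g. per-residue features
        "quadratic": ratio ** 2,  # e.g. an N x N pair representation
        "cubic": ratio ** 3,      # e.g. attention logits over an N x N pair grid
    }


if __name__ == "__main__":
    # The two receptor sizes discussed in this issue.
    for kind, factor in relative_cost(536, 3000).items():
        print(f"{kind:>9}: {factor:6.1f}x")
```

Under these assumptions, going from 536 to ~3000 residues multiplies quadratic terms by roughly 31x and cubic terms by roughly 175x. Reducing --n-samples or --chunk-size cannot offset a factor of that size.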

What I Did

The command I used for generating the complex for the small protein (536 residues):

neuralplexer-inference --task=batched_structure_sampling --input-receptor DERETWSGKVDFLLSVIGFAVDLANVWRFPYLCYKNGGGAFLVPYGIMLAVGGIPLFYMELALGQHNRKGAITCWGRLVPLFKGIGYAVVLIAFYVDFYYNVIIAWSLRFFFASFTNSLPWTSCNNIWNTPNCRPFEGHVEGFQSAASEYFNRYILELNRSEGIHDLGAIKWDMALCLLIVYLICYFSLWKGISTSGKVVWFTALFPYAVLLILLIRGLTLPGSFLGIQYYLTPNFSAIYKAEVWVDAATQVFFSLGPGFGVLLAYASYNKYHNNVYKDALLTSFINSATSFIAGFVIFSVLGYMAHTLGVRIEDVATEGPGLVFVVYPAAIATMPASTFWALIFFMMLATLGLDSSFGGSEAIITALSDEFPKIKRNRELFVAGLFSLYFVVGLASCTQGGFYFFHLLDRYAAGYSILVAVFFEAIAVSWIYGTNRFSEDIRDMIGFPPGRYWQVCWRFVAPIFLLFITVYGLIGYEPLTYADYVYPSWANALGWCIAGSSVVMIPAVAIFKLLSTPGSLRQRFTILTTPWRDQQ --input-ligand 41X_model_ed.sdf --out-path /data/venkat/test_4xnx_41x --model-checkpoint ./complex_structure_prediction.ckpt --n-samples 16 --chunk-size 4 --num-steps=40 --sampler=langevin_simulated_annealing --cuda

The ligand SDF file, which contains no coordinate information (I added a .txt extension so I could upload it here): 41X_model_ed.sdf.txt

I got this error:

Lightning automatically upgraded your loaded checkpoint from v1.7.0 to v1.9.5. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint --file complex_structure_prediction.ckpt`
Could not load pretrained MHT weights, skipping
/home/venkat/miniconda3/envs/neuralplexer_dev/lib/python3.10/site-packages/pytorch_lightning/core/saving.py:251: UserWarning: Found keys that are in the model state dict but not in the checkpoint: ['atnum2vdw_uff', 'plm.embed_tokens.weight', 'plm.layers.0.self_attn.k_proj.weight', 'plm.layers.0.self_attn.k_proj.bias', 'plm.layers.0.self_attn.v_proj.weight', 'plm.layers.0.self_attn.v_proj.bias', 'plm.layers.0.self_attn.q_proj.weight', 'plm.layers.0.self_attn.q_proj.bias', 'plm.layers.0.self_attn.out_proj.weight', 'plm.layers.0.self_attn.out_proj.bias', 'plm.layers.0.self_attn.rot_emb.inv_freq', 'plm.layers.0.self_attn_layer_norm.weight', 'plm.layers.0.self_attn_layer_norm.bias', 'plm.layers.0.fc1.weight', 'plm.layers.0.fc1.bias', 'plm.layers.0.fc2.weight', 'plm.layers.0.fc2.bias', 'plm.layers.0.final_layer_norm.weight', 'plm.layers.0.final_layer_norm.bias', 'plm.layers.1.self_attn.k_proj.weight', 'plm.layers.1.self_attn.k_proj.bias', 'plm.layers.1.self_attn.v_proj.weight', 'plm.layers.1.self_attn.v_proj.bias', 'plm.layers.1.self_attn.q_proj.weight', 'plm.layers.1.self_attn.q_proj.bias', 'plm.layers.1.self_attn.out_proj.weight', 'plm.layers.1.self_attn.out_proj.bias', 'plm.layers.1.self_attn.rot_emb.inv_freq', 'plm.layers.1.self_attn_layer_norm.weight', 'plm.layers.1.self_attn_layer_norm.bias', 'plm.layers.1.fc1.weight', 'plm.layers.1.fc1.bias', 'plm.layers.1.fc2.weight', 'plm.layers.1.fc2.bias', 'plm.layers.1.final_layer_norm.weight', 'plm.layers.1.final_layer_norm.bias', 'plm.layers.2.self_attn.k_proj.weight', 'plm.layers.2.self_attn.k_proj.bias', 'plm.layers.2.self_attn.v_proj.weight', 'plm.layers.2.self_attn.v_proj.bias', 'plm.layers.2.self_attn.q_proj.weight', 'plm.layers.2.self_attn.q_proj.bias', 'plm.layers.2.self_attn.out_proj.weight', 'plm.layers.2.self_attn.out_proj.bias', 'plm.layers.2.self_attn.rot_emb.inv_freq', 'plm.layers.2.self_attn_layer_norm.weight', 'plm.layers.2.self_attn_layer_norm.bias', 'plm.layers.2.fc1.weight', 'plm.layers.2.fc1.bias', 
'plm.layers.2.fc2.weight', 'plm.layers.2.fc2.bias', 'plm.layers.2.final_layer_norm.weight', 'plm.layers.2.final_layer_norm.bias', 'plm.layers.3.self_attn.k_proj.weight', 'plm.layers.3.self_attn.k_proj.bias', 'plm.layers.3.self_attn.v_proj.weight', 'plm.layers.3.self_attn.v_proj.bias', 'plm.layers.3.self_attn.q_proj.weight', 'plm.layers.3.self_attn.q_proj.bias', 'plm.layers.3.self_attn.out_proj.weight', 'plm.layers.3.self_attn.out_proj.bias', 'plm.layers.3.self_attn.rot_emb.inv_freq', 'plm.layers.3.self_attn_layer_norm.weight', 'plm.layers.3.self_attn_layer_norm.bias', 'plm.layers.3.fc1.weight', 'plm.layers.3.fc1.bias', 'plm.layers.3.fc2.weight', 'plm.layers.3.fc2.bias', 'plm.layers.3.final_layer_norm.weight', 'plm.layers.3.final_layer_norm.bias', 'plm.layers.4.self_attn.k_proj.weight', 'plm.layers.4.self_attn.k_proj.bias', 'plm.layers.4.self_attn.v_proj.weight', 'plm.layers.4.self_attn.v_proj.bias', 'plm.layers.4.self_attn.q_proj.weight', 'plm.layers.4.self_attn.q_proj.bias', 'plm.layers.4.self_attn.out_proj.weight', 'plm.layers.4.self_attn.out_proj.bias', 'plm.layers.4.self_attn.rot_emb.inv_freq', 'plm.layers.4.self_attn_layer_norm.weight', 'plm.layers.4.self_attn_layer_norm.bias', 'plm.layers.4.fc1.weight', 'plm.layers.4.fc1.bias', 'plm.layers.4.fc2.weight', 'plm.layers.4.fc2.bias', 'plm.layers.4.final_layer_norm.weight', 'plm.layers.4.final_layer_norm.bias', 'plm.layers.5.self_attn.k_proj.weight', 'plm.layers.5.self_attn.k_proj.bias', 'plm.layers.5.self_attn.v_proj.weight', 'plm.layers.5.self_attn.v_proj.bias', 'plm.layers.5.self_attn.q_proj.weight', 'plm.layers.5.self_attn.q_proj.bias', 'plm.layers.5.self_attn.out_proj.weight', 'plm.layers.5.self_attn.out_proj.bias', 'plm.layers.5.self_attn.rot_emb.inv_freq', 'plm.layers.5.self_attn_layer_norm.weight', 'plm.layers.5.self_attn_layer_norm.bias', 'plm.layers.5.fc1.weight', 'plm.layers.5.fc1.bias', 'plm.layers.5.fc2.weight', 'plm.layers.5.fc2.bias', 'plm.layers.5.final_layer_norm.weight', 
'plm.layers.5.final_layer_norm.bias', 'plm.layers.6.self_attn.k_proj.weight', 'plm.layers.6.self_attn.k_proj.bias', 'plm.layers.6.self_attn.v_proj.weight', 'plm.layers.6.self_attn.v_proj.bias', 'plm.layers.6.self_attn.q_proj.weight', 'plm.layers.6.self_attn.q_proj.bias', 'plm.layers.6.self_attn.out_proj.weight', 'plm.layers.6.self_attn.out_proj.bias', 'plm.layers.6.self_attn.rot_emb.inv_freq', 'plm.layers.6.self_attn_layer_norm.weight', 'plm.layers.6.self_attn_layer_norm.bias', 'plm.layers.6.fc1.weight', 'plm.layers.6.fc1.bias', 'plm.layers.6.fc2.weight', 'plm.layers.6.fc2.bias', 'plm.layers.6.final_layer_norm.weight', 'plm.layers.6.final_layer_norm.bias', 'plm.layers.7.self_attn.k_proj.weight', 'plm.layers.7.self_attn.k_proj.bias', 'plm.layers.7.self_attn.v_proj.weight', 'plm.layers.7.self_attn.v_proj.bias', 'plm.layers.7.self_attn.q_proj.weight', 'plm.layers.7.self_attn.q_proj.bias', 'plm.layers.7.self_attn.out_proj.weight', 'plm.layers.7.self_attn.out_proj.bias', 'plm.layers.7.self_attn.rot_emb.inv_freq', 'plm.layers.7.self_attn_layer_norm.weight', 'plm.layers.7.self_attn_layer_norm.bias', 'plm.layers.7.fc1.weight', 'plm.layers.7.fc1.bias', 'plm.layers.7.fc2.weight', 'plm.layers.7.fc2.bias', 'plm.layers.7.final_layer_norm.weight', 'plm.layers.7.final_layer_norm.bias', 'plm.layers.8.self_attn.k_proj.weight', 'plm.layers.8.self_attn.k_proj.bias', 'plm.layers.8.self_attn.v_proj.weight', 'plm.layers.8.self_attn.v_proj.bias', 'plm.layers.8.self_attn.q_proj.weight', 'plm.layers.8.self_attn.q_proj.bias', 'plm.layers.8.self_attn.out_proj.weight', 'plm.layers.8.self_attn.out_proj.bias', 'plm.layers.8.self_attn.rot_emb.inv_freq', 'plm.layers.8.self_attn_layer_norm.weight', 'plm.layers.8.self_attn_layer_norm.bias', 'plm.layers.8.fc1.weight', 'plm.layers.8.fc1.bias', 'plm.layers.8.fc2.weight', 'plm.layers.8.fc2.bias', 'plm.layers.8.final_layer_norm.weight', 'plm.layers.8.final_layer_norm.bias', 'plm.layers.9.self_attn.k_proj.weight', 'plm.layers.9.self_attn.k_proj.bias', 
'plm.layers.9.self_attn.v_proj.weight', 'plm.layers.9.self_attn.v_proj.bias', 'plm.layers.9.self_attn.q_proj.weight', 'plm.layers.9.self_attn.q_proj.bias', 'plm.layers.9.self_attn.out_proj.weight', 'plm.layers.9.self_attn.out_proj.bias', 'plm.layers.9.self_attn.rot_emb.inv_freq', 'plm.layers.9.self_attn_layer_norm.weight', 'plm.layers.9.self_attn_layer_norm.bias', 'plm.layers.9.fc1.weight', 'plm.layers.9.fc1.bias', 'plm.layers.9.fc2.weight', 'plm.layers.9.fc2.bias', 'plm.layers.9.final_layer_norm.weight', 'plm.layers.9.final_layer_norm.bias', 'plm.layers.10.self_attn.k_proj.weight', 'plm.layers.10.self_attn.k_proj.bias', 'plm.layers.10.self_attn.v_proj.weight', 'plm.layers.10.self_attn.v_proj.bias', 'plm.layers.10.self_attn.q_proj.weight', 'plm.layers.10.self_attn.q_proj.bias', 'plm.layers.10.self_attn.out_proj.weight', 'plm.layers.10.self_attn.out_proj.bias', 'plm.layers.10.self_attn.rot_emb.inv_freq', 'plm.layers.10.self_attn_layer_norm.weight', 'plm.layers.10.self_attn_layer_norm.bias', 'plm.layers.10.fc1.weight', 'plm.layers.10.fc1.bias', 'plm.layers.10.fc2.weight', 'plm.layers.10.fc2.bias', 'plm.layers.10.final_layer_norm.weight', 'plm.layers.10.final_layer_norm.bias', 'plm.layers.11.self_attn.k_proj.weight', 'plm.layers.11.self_attn.k_proj.bias', 'plm.layers.11.self_attn.v_proj.weight', 'plm.layers.11.self_attn.v_proj.bias', 'plm.layers.11.self_attn.q_proj.weight', 'plm.layers.11.self_attn.q_proj.bias', 'plm.layers.11.self_attn.out_proj.weight', 'plm.layers.11.self_attn.out_proj.bias', 'plm.layers.11.self_attn.rot_emb.inv_freq', 'plm.layers.11.self_attn_layer_norm.weight', 'plm.layers.11.self_attn_layer_norm.bias', 'plm.layers.11.fc1.weight', 'plm.layers.11.fc1.bias', 'plm.layers.11.fc2.weight', 'plm.layers.11.fc2.bias', 'plm.layers.11.final_layer_norm.weight', 'plm.layers.11.final_layer_norm.bias', 'plm.layers.12.self_attn.k_proj.weight', 'plm.layers.12.self_attn.k_proj.bias', 'plm.layers.12.self_attn.v_proj.weight', 'plm.layers.12.self_attn.v_proj.bias', 
'plm.layers.12.self_attn.q_proj.weight', 'plm.layers.12.self_attn.q_proj.bias', 'plm.layers.12.self_attn.out_proj.weight', 'plm.layers.12.self_attn.out_proj.bias', 'plm.layers.12.self_attn.rot_emb.inv_freq', 'plm.layers.12.self_attn_layer_norm.weight', 'plm.layers.12.self_attn_layer_norm.bias', 'plm.layers.12.fc1.weight', 'plm.layers.12.fc1.bias', 'plm.layers.12.fc2.weight', 'plm.layers.12.fc2.bias', 'plm.layers.12.final_layer_norm.weight', 'plm.layers.12.final_layer_norm.bias', 'plm.layers.13.self_attn.k_proj.weight', 'plm.layers.13.self_attn.k_proj.bias', 'plm.layers.13.self_attn.v_proj.weight', 'plm.layers.13.self_attn.v_proj.bias', 'plm.layers.13.self_attn.q_proj.weight', 'plm.layers.13.self_attn.q_proj.bias', 'plm.layers.13.self_attn.out_proj.weight', 'plm.layers.13.self_attn.out_proj.bias', 'plm.layers.13.self_attn.rot_emb.inv_freq', 'plm.layers.13.self_attn_layer_norm.weight', 'plm.layers.13.self_attn_layer_norm.bias', 'plm.layers.13.fc1.weight', 'plm.layers.13.fc1.bias', 'plm.layers.13.fc2.weight', 'plm.layers.13.fc2.bias', 'plm.layers.13.final_layer_norm.weight', 'plm.layers.13.final_layer_norm.bias', 'plm.layers.14.self_attn.k_proj.weight', 'plm.layers.14.self_attn.k_proj.bias', 'plm.layers.14.self_attn.v_proj.weight', 'plm.layers.14.self_attn.v_proj.bias', 'plm.layers.14.self_attn.q_proj.weight', 'plm.layers.14.self_attn.q_proj.bias', 'plm.layers.14.self_attn.out_proj.weight', 'plm.layers.14.self_attn.out_proj.bias', 'plm.layers.14.self_attn.rot_emb.inv_freq', 'plm.layers.14.self_attn_layer_norm.weight', 'plm.layers.14.self_attn_layer_norm.bias', 'plm.layers.14.fc1.weight', 'plm.layers.14.fc1.bias', 'plm.layers.14.fc2.weight', 'plm.layers.14.fc2.bias', 'plm.layers.14.final_layer_norm.weight', 'plm.layers.14.final_layer_norm.bias', 'plm.layers.15.self_attn.k_proj.weight', 'plm.layers.15.self_attn.k_proj.bias', 'plm.layers.15.self_attn.v_proj.weight', 'plm.layers.15.self_attn.v_proj.bias', 'plm.layers.15.self_attn.q_proj.weight', 
'plm.layers.15.self_attn.q_proj.bias', 'plm.layers.15.self_attn.out_proj.weight', 'plm.layers.15.self_attn.out_proj.bias', 'plm.layers.15.self_attn.rot_emb.inv_freq', 'plm.layers.15.self_attn_layer_norm.weight', 'plm.layers.15.self_attn_layer_norm.bias', 'plm.layers.15.fc1.weight', 'plm.layers.15.fc1.bias', 'plm.layers.15.fc2.weight', 'plm.layers.15.fc2.bias', 'plm.layers.15.final_layer_norm.weight', 'plm.layers.15.final_layer_norm.bias', 'plm.layers.16.self_attn.k_proj.weight', 'plm.layers.16.self_attn.k_proj.bias', 'plm.layers.16.self_attn.v_proj.weight', 'plm.layers.16.self_attn.v_proj.bias', 'plm.layers.16.self_attn.q_proj.weight', 'plm.layers.16.self_attn.q_proj.bias', 'plm.layers.16.self_attn.out_proj.weight', 'plm.layers.16.self_attn.out_proj.bias', 'plm.layers.16.self_attn.rot_emb.inv_freq', 'plm.layers.16.self_attn_layer_norm.weight', 'plm.layers.16.self_attn_layer_norm.bias', 'plm.layers.16.fc1.weight', 'plm.layers.16.fc1.bias', 'plm.layers.16.fc2.weight', 'plm.layers.16.fc2.bias', 'plm.layers.16.final_layer_norm.weight', 'plm.layers.16.final_layer_norm.bias', 'plm.layers.17.self_attn.k_proj.weight', 'plm.layers.17.self_attn.k_proj.bias', 'plm.layers.17.self_attn.v_proj.weight', 'plm.layers.17.self_attn.v_proj.bias', 'plm.layers.17.self_attn.q_proj.weight', 'plm.layers.17.self_attn.q_proj.bias', 'plm.layers.17.self_attn.out_proj.weight', 'plm.layers.17.self_attn.out_proj.bias', 'plm.layers.17.self_attn.rot_emb.inv_freq', 'plm.layers.17.self_attn_layer_norm.weight', 'plm.layers.17.self_attn_layer_norm.bias', 'plm.layers.17.fc1.weight', 'plm.layers.17.fc1.bias', 'plm.layers.17.fc2.weight', 'plm.layers.17.fc2.bias', 'plm.layers.17.final_layer_norm.weight', 'plm.layers.17.final_layer_norm.bias', 'plm.layers.18.self_attn.k_proj.weight', 'plm.layers.18.self_attn.k_proj.bias', 'plm.layers.18.self_attn.v_proj.weight', 'plm.layers.18.self_attn.v_proj.bias', 'plm.layers.18.self_attn.q_proj.weight', 'plm.layers.18.self_attn.q_proj.bias', 
'plm.layers.18.self_attn.out_proj.weight', 'plm.layers.18.self_attn.out_proj.bias', 'plm.layers.18.self_attn.rot_emb.inv_freq', 'plm.layers.18.self_attn_layer_norm.weight', 'plm.layers.18.self_attn_layer_norm.bias', 'plm.layers.18.fc1.weight', 'plm.layers.18.fc1.bias', 'plm.layers.18.fc2.weight', 'plm.layers.18.fc2.bias', 'plm.layers.18.final_layer_norm.weight', 'plm.layers.18.final_layer_norm.bias', 'plm.layers.19.self_attn.k_proj.weight', 'plm.layers.19.self_attn.k_proj.bias', 'plm.layers.19.self_attn.v_proj.weight', 'plm.layers.19.self_attn.v_proj.bias', 'plm.layers.19.self_attn.q_proj.weight', 'plm.layers.19.self_attn.q_proj.bias', 'plm.layers.19.self_attn.out_proj.weight', 'plm.layers.19.self_attn.out_proj.bias', 'plm.layers.19.self_attn.rot_emb.inv_freq', 'plm.layers.19.self_attn_layer_norm.weight', 'plm.layers.19.self_attn_layer_norm.bias', 'plm.layers.19.fc1.weight', 'plm.layers.19.fc1.bias', 'plm.layers.19.fc2.weight', 'plm.layers.19.fc2.bias', 'plm.layers.19.final_layer_norm.weight', 'plm.layers.19.final_layer_norm.bias', 'plm.layers.20.self_attn.k_proj.weight', 'plm.layers.20.self_attn.k_proj.bias', 'plm.layers.20.self_attn.v_proj.weight', 'plm.layers.20.self_attn.v_proj.bias', 'plm.layers.20.self_attn.q_proj.weight', 'plm.layers.20.self_attn.q_proj.bias', 'plm.layers.20.self_attn.out_proj.weight', 'plm.layers.20.self_attn.out_proj.bias', 'plm.layers.20.self_attn.rot_emb.inv_freq', 'plm.layers.20.self_attn_layer_norm.weight', 'plm.layers.20.self_attn_layer_norm.bias', 'plm.layers.20.fc1.weight', 'plm.layers.20.fc1.bias', 'plm.layers.20.fc2.weight', 'plm.layers.20.fc2.bias', 'plm.layers.20.final_layer_norm.weight', 'plm.layers.20.final_layer_norm.bias', 'plm.layers.21.self_attn.k_proj.weight', 'plm.layers.21.self_attn.k_proj.bias', 'plm.layers.21.self_attn.v_proj.weight', 'plm.layers.21.self_attn.v_proj.bias', 'plm.layers.21.self_attn.q_proj.weight', 'plm.layers.21.self_attn.q_proj.bias', 'plm.layers.21.self_attn.out_proj.weight', 
'plm.layers.21.self_attn.out_proj.bias', 'plm.layers.21.self_attn.rot_emb.inv_freq', 'plm.layers.21.self_attn_layer_norm.weight', 'plm.layers.21.self_attn_layer_norm.bias', 'plm.layers.21.fc1.weight', 'plm.layers.21.fc1.bias', 'plm.layers.21.fc2.weight', 'plm.layers.21.fc2.bias', 'plm.layers.21.final_layer_norm.weight', 'plm.layers.21.final_layer_norm.bias', 'plm.layers.22.self_attn.k_proj.weight', 'plm.layers.22.self_attn.k_proj.bias', 'plm.layers.22.self_attn.v_proj.weight', 'plm.layers.22.self_attn.v_proj.bias', 'plm.layers.22.self_attn.q_proj.weight', 'plm.layers.22.self_attn.q_proj.bias', 'plm.layers.22.self_attn.out_proj.weight', 'plm.layers.22.self_attn.out_proj.bias', 'plm.layers.22.self_attn.rot_emb.inv_freq', 'plm.layers.22.self_attn_layer_norm.weight', 'plm.layers.22.self_attn_layer_norm.bias', 'plm.layers.22.fc1.weight', 'plm.layers.22.fc1.bias', 'plm.layers.22.fc2.weight', 'plm.layers.22.fc2.bias', 'plm.layers.22.final_layer_norm.weight', 'plm.layers.22.final_layer_norm.bias', 'plm.layers.23.self_attn.k_proj.weight', 'plm.layers.23.self_attn.k_proj.bias', 'plm.layers.23.self_attn.v_proj.weight', 'plm.layers.23.self_attn.v_proj.bias', 'plm.layers.23.self_attn.q_proj.weight', 'plm.layers.23.self_attn.q_proj.bias', 'plm.layers.23.self_attn.out_proj.weight', 'plm.layers.23.self_attn.out_proj.bias', 'plm.layers.23.self_attn.rot_emb.inv_freq', 'plm.layers.23.self_attn_layer_norm.weight', 'plm.layers.23.self_attn_layer_norm.bias', 'plm.layers.23.fc1.weight', 'plm.layers.23.fc1.bias', 'plm.layers.23.fc2.weight', 'plm.layers.23.fc2.bias', 'plm.layers.23.final_layer_norm.weight', 'plm.layers.23.final_layer_norm.bias', 'plm.layers.24.self_attn.k_proj.weight', 'plm.layers.24.self_attn.k_proj.bias', 'plm.layers.24.self_attn.v_proj.weight', 'plm.layers.24.self_attn.v_proj.bias', 'plm.layers.24.self_attn.q_proj.weight', 'plm.layers.24.self_attn.q_proj.bias', 'plm.layers.24.self_attn.out_proj.weight', 'plm.layers.24.self_attn.out_proj.bias', 
'plm.layers.24.self_attn.rot_emb.inv_freq', 'plm.layers.24.self_attn_layer_norm.weight', 'plm.layers.24.self_attn_layer_norm.bias', 'plm.layers.24.fc1.weight', 'plm.layers.24.fc1.bias', 'plm.layers.24.fc2.weight', 'plm.layers.24.fc2.bias', 'plm.layers.24.final_layer_norm.weight', 'plm.layers.24.final_layer_norm.bias', 'plm.layers.25.self_attn.k_proj.weight', 'plm.layers.25.self_attn.k_proj.bias', 'plm.layers.25.self_attn.v_proj.weight', 'plm.layers.25.self_attn.v_proj.bias', 'plm.layers.25.self_attn.q_proj.weight', 'plm.layers.25.self_attn.q_proj.bias', 'plm.layers.25.self_attn.out_proj.weight', 'plm.layers.25.self_attn.out_proj.bias', 'plm.layers.25.self_attn.rot_emb.inv_freq', 'plm.layers.25.self_attn_layer_norm.weight', 'plm.layers.25.self_attn_layer_norm.bias', 'plm.layers.25.fc1.weight', 'plm.layers.25.fc1.bias', 'plm.layers.25.fc2.weight', 'plm.layers.25.fc2.bias', 'plm.layers.25.final_layer_norm.weight', 'plm.layers.25.final_layer_norm.bias', 'plm.layers.26.self_attn.k_proj.weight', 'plm.layers.26.self_attn.k_proj.bias', 'plm.layers.26.self_attn.v_proj.weight', 'plm.layers.26.self_attn.v_proj.bias', 'plm.layers.26.self_attn.q_proj.weight', 'plm.layers.26.self_attn.q_proj.bias', 'plm.layers.26.self_attn.out_proj.weight', 'plm.layers.26.self_attn.out_proj.bias', 'plm.layers.26.self_attn.rot_emb.inv_freq', 'plm.layers.26.self_attn_layer_norm.weight', 'plm.layers.26.self_attn_layer_norm.bias', 'plm.layers.26.fc1.weight', 'plm.layers.26.fc1.bias', 'plm.layers.26.fc2.weight', 'plm.layers.26.fc2.bias', 'plm.layers.26.final_layer_norm.weight', 'plm.layers.26.final_layer_norm.bias', 'plm.layers.27.self_attn.k_proj.weight', 'plm.layers.27.self_attn.k_proj.bias', 'plm.layers.27.self_attn.v_proj.weight', 'plm.layers.27.self_attn.v_proj.bias', 'plm.layers.27.self_attn.q_proj.weight', 'plm.layers.27.self_attn.q_proj.bias', 'plm.layers.27.self_attn.out_proj.weight', 'plm.layers.27.self_attn.out_proj.bias', 'plm.layers.27.self_attn.rot_emb.inv_freq', 
'plm.layers.27.self_attn_layer_norm.weight', 'plm.layers.27.self_attn_layer_norm.bias', 'plm.layers.27.fc1.weight', 'plm.layers.27.fc1.bias', 'plm.layers.27.fc2.weight', 'plm.layers.27.fc2.bias', 'plm.layers.27.final_layer_norm.weight', 'plm.layers.27.final_layer_norm.bias', 'plm.layers.28.self_attn.k_proj.weight', 'plm.layers.28.self_attn.k_proj.bias', 'plm.layers.28.self_attn.v_proj.weight', 'plm.layers.28.self_attn.v_proj.bias', 'plm.layers.28.self_attn.q_proj.weight', 'plm.layers.28.self_attn.q_proj.bias', 'plm.layers.28.self_attn.out_proj.weight', 'plm.layers.28.self_attn.out_proj.bias', 'plm.layers.28.self_attn.rot_emb.inv_freq', 'plm.layers.28.self_attn_layer_norm.weight', 'plm.layers.28.self_attn_layer_norm.bias', 'plm.layers.28.fc1.weight', 'plm.layers.28.fc1.bias', 'plm.layers.28.fc2.weight', 'plm.layers.28.fc2.bias', 'plm.layers.28.final_layer_norm.weight', 'plm.layers.28.final_layer_norm.bias', 'plm.layers.29.self_attn.k_proj.weight', 'plm.layers.29.self_attn.k_proj.bias', 'plm.layers.29.self_attn.v_proj.weight', 'plm.layers.29.self_attn.v_proj.bias', 'plm.layers.29.self_attn.q_proj.weight', 'plm.layers.29.self_attn.q_proj.bias', 'plm.layers.29.self_attn.out_proj.weight', 'plm.layers.29.self_attn.out_proj.bias', 'plm.layers.29.self_attn.rot_emb.inv_freq', 'plm.layers.29.self_attn_layer_norm.weight', 'plm.layers.29.self_attn_layer_norm.bias', 'plm.layers.29.fc1.weight', 'plm.layers.29.fc1.bias', 'plm.layers.29.fc2.weight', 'plm.layers.29.fc2.bias', 'plm.layers.29.final_layer_norm.weight', 'plm.layers.29.final_layer_norm.bias', 'plm.layers.30.self_attn.k_proj.weight', 'plm.layers.30.self_attn.k_proj.bias', 'plm.layers.30.self_attn.v_proj.weight', 'plm.layers.30.self_attn.v_proj.bias', 'plm.layers.30.self_attn.q_proj.weight', 'plm.layers.30.self_attn.q_proj.bias', 'plm.layers.30.self_attn.out_proj.weight', 'plm.layers.30.self_attn.out_proj.bias', 'plm.layers.30.self_attn.rot_emb.inv_freq', 'plm.layers.30.self_attn_layer_norm.weight', 
'plm.layers.30.self_attn_layer_norm.bias', 'plm.layers.30.fc1.weight', 'plm.layers.30.fc1.bias', 'plm.layers.30.fc2.weight', 'plm.layers.30.fc2.bias', 'plm.layers.30.final_layer_norm.weight', 'plm.layers.30.final_layer_norm.bias', 'plm.layers.31.self_attn.k_proj.weight', 'plm.layers.31.self_attn.k_proj.bias', 'plm.layers.31.self_attn.v_proj.weight', 'plm.layers.31.self_attn.v_proj.bias', 'plm.layers.31.self_attn.q_proj.weight', 'plm.layers.31.self_attn.q_proj.bias', 'plm.layers.31.self_attn.out_proj.weight', 'plm.layers.31.self_attn.out_proj.bias', 'plm.layers.31.self_attn.rot_emb.inv_freq', 'plm.layers.31.self_attn_layer_norm.weight', 'plm.layers.31.self_attn_layer_norm.bias', 'plm.layers.31.fc1.weight', 'plm.layers.31.fc1.bias', 'plm.layers.31.fc2.weight', 'plm.layers.31.fc2.bias', 'plm.layers.31.final_layer_norm.weight', 'plm.layers.31.final_layer_norm.bias', 'plm.layers.32.self_attn.k_proj.weight', 'plm.layers.32.self_attn.k_proj.bias', 'plm.layers.32.self_attn.v_proj.weight', 'plm.layers.32.self_attn.v_proj.bias', 'plm.layers.32.self_attn.q_proj.weight', 'plm.layers.32.self_attn.q_proj.bias', 'plm.layers.32.self_attn.out_proj.weight', 'plm.layers.32.self_attn.out_proj.bias', 'plm.layers.32.self_attn.rot_emb.inv_freq', 'plm.layers.32.self_attn_layer_norm.weight', 'plm.layers.32.self_attn_layer_norm.bias', 'plm.layers.32.fc1.weight', 'plm.layers.32.fc1.bias', 'plm.layers.32.fc2.weight', 'plm.layers.32.fc2.bias', 'plm.layers.32.final_layer_norm.weight', 'plm.layers.32.final_layer_norm.bias', 'plm.contact_head.regression.weight', 'plm.contact_head.regression.bias', 'plm.emb_layer_norm_after.weight', 'plm.emb_layer_norm_after.bias', 'plm.lm_head.weight', 'plm.lm_head.bias', 'plm.lm_head.dense.weight', 'plm.lm_head.dense.bias', 'plm.lm_head.layer_norm.weight', 'plm.lm_head.layer_norm.bias', 'protein_encoder.template_binding_site_enc.weight', 'pl_contact_stack.template_binding_site_enc.weight']
  rank_zero_warn(
Lightning automatically upgraded your loaded checkpoint from v1.7.0 to v1.9.5. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint --file complex_structure_prediction.ckpt`
Could not load pretrained MHT weights, skipping
/home/venkat/miniconda3/envs/neuralplexer_dev/lib/python3.10/site-packages/pytorch_lightning/core/saving.py:251: UserWarning: Found keys that are in the model state dict but not in the checkpoint: ['atnum2vdw_uff', 'plm.embed_tokens.weight', ...]
  rank_zero_warn(
/data/venkat/NeuralPLexer/neuralplexer/inference.py:667: UserWarning: Assuming the provided receptor input is a protein sequence
  warnings.warn("Assuming the provided receptor input is a protein sequence")
/data/venkat/NeuralPLexer/neuralplexer/data/molops.py:148: RuntimeWarning: invalid value encountered in divide
  in_vec = in_bond_vec / np.linalg.norm(in_bond_vec)
/data/venkat/NeuralPLexer/neuralplexer/data/molops.py:152: RuntimeWarning: invalid value encountered in divide
  out_vec = out_bond_vec / np.linalg.norm(out_bond_vec)
/data/venkat/NeuralPLexer/neuralplexer/data/molops.py:172: RuntimeWarning: invalid value encountered in divide
  query_bond_vec = query_bond_vec / np.linalg.norm(query_bond_vec)
/data/venkat/NeuralPLexer/neuralplexer/data/molops.py:205: RuntimeWarning: invalid value encountered in divide
  query_bonded_vec = query_bonded_vec / np.linalg.norm(query_bonded_vec)
Block contact sampling:  47%|█████████████████████████████████████████████████████████████████████▊                                                                               | 15/32 [00:04<00:04,  3.43it/s]

------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
neuralplexer-inference 33 <module>
sys.exit(load_entry_point('neuralplexer', 'console_scripts', 'neuralplexer-inference')())

inference.py 672 main
multi_pose_sampling(

inference.py 182 multi_pose_sampling
output_struct = model.sample_pl_complex_structures(

wrappers.py 3104 sample_pl_complex_structures
batch = self.forward(

wrappers.py 204 forward
self._infer_geometry_prior(batch, **kwargs)

wrappers.py 688 _infer_geometry_prior
self._run_contact_map_stack(

wrappers.py 515 _run_contact_map_stack
batch = self.pl_contact_stack(

module.py 1194 _call_impl
return forward_call(*input, **kwargs)

cpm.py 649 forward
_, merged_grid_rep = self.triangle_stacks[block_id](

module.py 1194 _call_impl
return forward_call(*input, **kwargs)

modules.py 417 forward
x1_out = self.mha_1(q_x=x1, kv_x=x2, biases=biases_J2I, use_lma=use_lma)

module.py 1194 _call_impl
return forward_call(*input, **kwargs)

primitives.py 504 forward
o = _attention(q, k, v, biases)

primitives.py 253 _attention
a = softmax_no_cast(a, -1)

primitives.py 237 softmax_no_cast
s = torch.nn.functional.softmax(t, dim=dim)

functional.py 1841 softmax
ret = input.softmax(dim)

torch.cuda.OutOfMemoryError:
CUDA out of memory. Tried to allocate 256.00 MiB (GPU 0; 5.79 GiB total capacity; 4.58 GiB already allocated; 162.75 MiB free; 5.28 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
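The OOM message above points to `PYTORCH_CUDA_ALLOC_CONF`. Below is a minimal sketch of relaunching the same command with `max_split_size_mb` set; `build_run` is a hypothetical helper and 128 MB is an arbitrary example value. In this message about 0.7 GiB is reserved but not allocated (5.28 vs 4.58 GiB). That is more than the 256 MiB requested, so fragmentation may contribute here, though the setting cannot create memory the 6 GB card does not have.

```python
import os
import shlex


def build_run(args: list[str], max_split_size_mb: int = 128) -> tuple[list[str], dict[str, str]]:
    """Return (command, environment) for a neuralplexer-inference run in which
    the PyTorch CUDA caching allocator will not split cached blocks larger
    than max_split_size_mb."""
    env = dict(os.environ)
    # PyTorch reads this when it initializes CUDA, so it has to be in the
    # child process environment before that process starts.
    env["PYTORCH_CUDA_ALLOC_CONF"] = f"max_split_size_mb:{max_split_size_mb}"
    return ["neuralplexer-inference", *args], env


if __name__ == "__main__":
    cmd, env = build_run([
        "--task=batched_structure_sampling",
        "--input-receptor", "RECEPTOR_SEQUENCE",  # hypothetical placeholder
        "--input-ligand", "41X_model_ed.sdf",
        "--out-path", "/data/venkat/test_4xnx_41x",
        "--model-checkpoint", "./complex_structure_prediction.ckpt",
        "--n-samples", "16", "--chunk-size", "4", "--num-steps=40",
        "--sampler=langevin_simulated_annealing", "--cuda",
    ])
    # Print the equivalent shell command rather than running it.
    print("PYTORCH_CUDA_ALLOC_CONF=" + env["PYTORCH_CUDA_ALLOC_CONF"], shlex.join(cmd))
```

To actually launch it, pass the result to `subprocess.run(cmd, env=env, check=True)`.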

The command for modelling the bigger protein:

neuralplexer-inference --task=batched_structure_sampling --input-receptor ./4dx5_ed.pdb --input-ligand MIY_model_ed.sdf --out-path /data/venkat/test_4dx5_MIY --model-checkpoint ./complex_structure_prediction.ckpt --n-samples 1 --chunk-size 1 --num-steps=1 --sampler=langevin_simulated_annealing --cuda

The relevant input files are MIY_model_ed.sdf.txt and 4dx5_ed.pdb.txt.

The output I got:

Lightning automatically upgraded your loaded checkpoint from v1.7.0 to v1.9.5. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint --file complex_structure_prediction.ckpt`
Could not load pretrained MHT weights, skipping
/home/venkat/miniconda3/envs/neuralplexer_dev/lib/python3.10/site-packages/pytorch_lightning/core/saving.py:251: UserWarning: Found keys that are in the model state dict but not in the checkpoint: ['atnum2vdw_uff', 'plm.embed_tokens.weight', ...]
  rank_zero_warn(
Lightning automatically upgraded your loaded checkpoint from v1.7.0 to v1.9.5. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint --file complex_structure_prediction.ckpt`
Could not load pretrained MHT weights, skipping
/home/venkat/miniconda3/envs/neuralplexer_dev/lib/python3.10/site-packages/pytorch_lightning/core/saving.py:251: UserWarning: Found keys that are in the model state dict but not in the checkpoint: ['atnum2vdw_uff', 'plm.embed_tokens.weight', ...]
  rank_zero_warn(
/data/venkat/NeuralPLexer/neuralplexer/data/molops.py:148: RuntimeWarning: invalid value encountered in divide
  in_vec = in_bond_vec / np.linalg.norm(in_bond_vec)
/data/venkat/NeuralPLexer/neuralplexer/data/molops.py:152: RuntimeWarning: invalid value encountered in divide
  out_vec = out_bond_vec / np.linalg.norm(out_bond_vec)
/data/venkat/NeuralPLexer/neuralplexer/data/molops.py:172: RuntimeWarning: invalid value encountered in divide
  query_bond_vec = query_bond_vec / np.linalg.norm(query_bond_vec)
/data/venkat/NeuralPLexer/neuralplexer/data/molops.py:205: RuntimeWarning: invalid value encountered in divide
  query_bonded_vec = query_bonded_vec / np.linalg.norm(query_bonded_vec)

------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
neuralplexer-inference 33 <module>
sys.exit(load_entry_point('neuralplexer', 'console_scripts', 'neuralplexer-inference')())

inference.py 672 main
multi_pose_sampling(

inference.py 182 multi_pose_sampling
output_struct = model.sample_pl_complex_structures(

wrappers.py 3104 sample_pl_complex_structures
batch = self.forward(

wrappers.py 187 forward
batch = self._run_encoder_stack(

wrappers.py 457 _run_encoder_stack
self._initialize_protatm_indexer_and_embeddings(batch)

wrappers.py 1657 _initialize_protatm_indexer_and_embeddings
.contiguous()[same_residue_mask]

torch.cuda.OutOfMemoryError:
CUDA out of memory. Tried to allocate 4.17 GiB (GPU 0; 5.79 GiB total capacity; 3.84 GiB already allocated; 1.51 GiB free; 3.93 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
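The two OOM messages in this issue can be read quantitatively. The sketch below parses the message format shown above. Its regex is written against these two messages only and is an assumption for other PyTorch versions, and `diagnose` is a hypothetical helper. It reports two things: how much memory PyTorch has reserved but not allocated, and whether the failing request could fit alongside the tensors already on the GPU.

```python
import re

_UNITS = {"KiB": 2**10, "MiB": 2**20, "GiB": 2**30}
_SIZE = r"[\d.]+ [KMG]iB"
_PATTERN = re.compile(
    rf"Tried to allocate (?P<request>{_SIZE}) \(GPU \d+; "
    rf"(?P<total>{_SIZE}) total capacity; "
    rf"(?P<allocated>{_SIZE}) already allocated; "
    rf"(?P<free>{_SIZE}) free; "
    rf"(?P<reserved>{_SIZE}) reserved in total by PyTorch\)"
)


def to_bytes(size: str) -> float:
    """Convert e.g. '256.00 MiB' to bytes."""
    value, unit = size.split()
    return float(value) * _UNITS[unit]


def diagnose(message: str) -> dict[str, object]:
    match = _PATTERN.search(message)
    if match is None:
        raise ValueError("not a recognised CUDA out-of-memory message")
    b = {name: to_bytes(size) for name, size in match.groupdict().items()}
    return {
        # Cached by PyTorch but not holding tensors; a large gap hints at fragmentation.
        "slack_gib": (b["reserved"] - b["allocated"]) / 2**30,
        # Could this single request fit next to what is already allocated?
        "fits_alongside_allocated": b["request"] + b["allocated"] <= b["total"],
    }


if __name__ == "__main__":
    small = ("Tried to allocate 256.00 MiB (GPU 0; 5.79 GiB total capacity; "
             "4.58 GiB already allocated; 162.75 MiB free; "
             "5.28 GiB reserved in total by PyTorch)")
    large = ("Tried to allocate 4.17 GiB (GPU 0; 5.79 GiB total capacity; "
             "3.84 GiB already allocated; 1.51 GiB free; "
             "3.93 GiB reserved in total by PyTorch)")
    for label, msg in [("536 residues", small), ("~3000 residues", large)]:
        print(label, diagnose(msg))
```

For the 536-residue run, 0.25 GiB + 4.58 GiB is within 5.79 GiB, so that failure is marginal. For the ~3000-residue run, 4.17 GiB + 3.84 GiB exceeds the card's 5.79 GiB. While the already-allocated tensors stay on the GPU, no allocator setting can make that step fit.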
venkata-ramaswamy-cresset commented 2 months ago

Do you have any recommendations for the minimum resources (CPU, RAM, GPU) NeuralPLexer requires, depending on the size of the input protein?

zrqiao commented 2 months ago

5.8 GB total capacity is certainly below the GPU memory we have tested on. As a minimum, we recommend testing on a system with at least 16 GB of CPU and GPU memory.
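The maintainer's guidance can be turned into a quick pre-flight check. This sketch makes several assumptions. It expects an NVIDIA driver with `nvidia-smi` on the PATH and reads "16GB" as 16 GiB (16384 MiB). The helper names are hypothetical and not part of NeuralPLexer.

```python
import shutil
import subprocess

RECOMMENDED_MIB = 16 * 1024  # "at least 16GB" from this thread, read as GiB


def parse_memory_total_mib(csv_text: str) -> list[int]:
    """Parse output of: nvidia-smi --query-gpu=memory.total --format=csv,noheader,nounits
    (one integer MiB value per GPU, one GPU per line)."""
    return [int(line.strip()) for line in csv_text.splitlines() if line.strip()]


def meets_recommendation(total_mib: int) -> bool:
    return total_mib >= RECOMMENDED_MIB


if __name__ == "__main__":
    if shutil.which("nvidia-smi") is None:
        print("nvidia-smi not found")
    else:
        result = subprocess.run(
            ["nvidia-smi", "--query-gpu=memory.total", "--format=csv,noheader,nounits"],
            capture_output=True,
            text=True,
        )
        if result.returncode != 0:
            print("nvidia-smi failed:", result.stderr.strip())
        else:
            for index, mib in enumerate(parse_memory_total_mib(result.stdout)):
                verdict = "meets" if meets_recommendation(mib) else "is below"
                print(f"GPU {index}: {mib} MiB {verdict} the 16 GiB recommendation")
```

A 6144 MiB card like the one described in this issue would be flagged as below the recommendation.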

venkata-ramaswamy-cresset commented 2 months ago

Thank you @zrqiao. However, even with 32 GB of RAM and running on CPU only, I could not generate the models for this example. The run failed with a FileNotFoundError for prot_all.pdb, a file that NeuralPLexer is supposed to generate itself.

Command: neuralplexer-inference --task=batched_structure_sampling --input-receptor ./4dx5_ed.pdb --input-ligand MIY_model_ed.sdf --out-path ./testing_4dx5_MIY --model-checkpoint ./complex_structure_prediction.ckpt --n-samples 1 --chunk-size 1 --num-steps=1 --sampler=langevin_simulated_annealing

Output:

Block contact sampling: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 31/31 [01:06<00:00, 2.14s/it]
Structure generation using langevin_simulated_annealing: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:21<00:00, 21.96s/it]
/home/venkata/miniconda3/envs/neuralplexer_dev/lib/python3.10/site-packages/pytorch3d/ops/points_alignment.py:340: UserWarning: Excessively low rank of cross-correlation between aligned point clouds. corresponding_points_alignment cannot return a unique rotation.
  warnings.warn(

neuralplexer-inference 33
sys.exit(load_entry_point('neuralplexer', 'console_scripts', 'neuralplexer-inference')())

inference.py 672 main
multi_pose_sampling(

inference.py 257 multi_pose_sampling
write_pdb_models(

pipeline.py 971 write_pdb_models
with open(out_path, "w") as of:

FileNotFoundError: 2 No such file or directory testing_4dx5_MIY/prot_all.pdb

Any pointers on why I get this FileNotFoundError for a file that NeuralPLexer is supposed to generate? It worked for smaller proteins but not for this one.
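According to the traceback, `write_pdb_models` writes `prot_all.pdb` with `open(out_path, "w")`. In Python that call raises `FileNotFoundError` when the parent directory does not exist. One cheap thing to rule out, then, is whether the `--out-path` directory exists at that point. The thread does not confirm this is the cause. The sketch below reproduces the error and shows that creating the directory first avoids it; `ensure_out_dir` is a hypothetical helper, not part of NeuralPLexer.

```python
import os
import tempfile


def ensure_out_dir(out_path: str) -> str:
    """Hypothetical helper: create the --out-path directory (and any missing
    parents) before launching neuralplexer-inference; return its absolute path."""
    abs_path = os.path.abspath(out_path)
    os.makedirs(abs_path, exist_ok=True)  # no-op if it already exists
    return abs_path


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        out_dir = os.path.join(tmp, "testing_4dx5_MIY")
        target = os.path.join(out_dir, "prot_all.pdb")
        try:
            open(target, "w").close()  # parent directory does not exist yet
        except FileNotFoundError as exc:
            print("before:", exc.__class__.__name__)
        ensure_out_dir(out_dir)
        with open(target, "w") as fh:  # parent now exists, so this succeeds
            fh.write("")
        print("after:", os.path.exists(target))
```

If the directory already exists and the error persists, the cause lies elsewhere in the pipeline.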

These are the input files, in case they are needed to reproduce the error (I added a .txt extension so I could upload them here): 4dx5_ed.pdb.txt MIY_model_ed.sdf.txt