FreshAirTonight / af2complex

Predicting direct protein-protein interactions with AlphaFold deep learning neural network models.
146 stars 19 forks

the runtime of the procedure #7

Closed Qmi3 closed 2 years ago

Qmi3 commented 2 years ago

When I input a pair of proteins with a total length of ~1200 residues and run the monomer_ptm model, it takes almost 4 to 5 hours to get the result. Is that normal? I ran this on an NVIDIA A100, and the test example you provide also took 50 minutes.

FreshAirTonight commented 2 years ago

If you ran all five monomer_ptm models, each with up to 20 recycles, on your 1,200 AA target, it could well take 4 to 5 hours. If it was just a single model with a few recycles, then the run time seems too long for an A100.

Qmi3 commented 2 years ago

Maybe there is some problem with my device qaq. This is the output of my attempt:

# Info: input feature directory is af2c_fea
Info: result output directory is af2c_mod
Info: model preset is monomer_ptm
/lustre/grp/gyqlab/zhangcw/miniconda3/envs/pytorch_1/lib/python3.8/site-packages/absl/flags/_validators.py:229: UserWarning: Flag --model_preset has a non-None default value; therefore, mark_flag_as_required will pass even if flag is not specified in the command line!
mark_flag_as_required(flag_name, flag_values)
Info: using preset economy
Info: set num_ensemble = 1
Info: set max_recyles = 3
Info: set recycle_tol = 0.1
I0809 14:30:32.669277 140223994627904 xla_bridge.py:333] Unable to initialize backend 'tpu_driver': NOT_FOUND: Unable to find driver in registry given worker:
I0809 14:30:33.038053 140223994627904 xla_bridge.py:333] Unable to initialize backend 'tpu': INVALID_ARGUMENT: TpuPlatform is not available.
I0809 14:30:34.307948 140223994627904 run_af2c_mod.py:481] Have 1 models: ['model_1_ptm']
Info: working on target test
I0809 14:30:37.311321 140223994627904 run_af2c_mod.py:495] Using random seed 1806477629473041046 for the data pipeline
Info: test found monomer Q86WB0 msa_depth = 3444, seq_len = 502, num_templ = 3
Info: test found monomer Q9HB96 msa_depth = 1236, seq_len = 536, num_templ = 11
Info: Q9HB96 reducing the number of structural templates to 4
Info: 2 chain(s) to model {'A': 'Q86WB0_1', 'B': 'Q9HB96_1'}
Info: modeling test with msa_depth = 4680, seq_len = 1038, num_templ = 7
I0809 14:30:37.574126 140223994627904 run_af2c_mod.py:243] Running model model_1_ptm_20220809_041046
I0809 14:30:51.070628 140223994627904 model.py:188] Running predict with shape(feat) = {'aatype': (4, 1038), 'seq_length': (4,), 'template_aatype': (4, 4, 1038), 'template_all_atom_masks': (4, 4, 1038, 37), 'template_all_atom_positions': (4, 4, 1038, 37, 3), 'template_sum_probs': (4, 4, 1), 'residue_index': (4, 1038), 'is_distillation': (4,), 'seq_mask': (4, 1038), 'msa_mask': (4, 252, 1038), 'msa_row_mask': (4, 252), 'random_crop_to_size_seed': (4, 2), 'template_mask': (4, 4), 'template_pseudo_beta': (4, 4, 1038, 3), 'template_pseudo_beta_mask': (4, 4, 1038), 'atom14_atom_exists': (4, 1038, 14), 'residx_atom14_to_atom37': (4, 1038, 14), 'residx_atom37_to_atom14': (4, 1038, 37), 'atom37_atom_exists': (4, 1038, 37), 'extra_msa': (4, 512, 1038), 'extra_msa_mask': (4, 512, 1038), 'extra_msa_row_mask': (4, 512), 'bert_mask': (4, 252, 1038), 'true_msa': (4, 252, 1038), 'extra_has_deletion': (4, 512, 1038), 'extra_deletion_value': (4, 512, 1038), 'msa_feat': (4, 252, 1038, 49), 'target_feat': (4, 1038, 22)}
/lustre/grp/gyqlab/zhangcw/miniconda3/envs/pytorch_1/lib/python3.8/site-packages/jax/_src/tree_util.py:200: FutureWarning: jax.tree_util.tree_multimap() is deprecated. Please use jax.tree_util.tree_map() instead as a drop-in replacement.
warnings.warn('jax.tree_util.tree_multimap() is deprecated. Please use jax.tree_util.tree_map() '
I0809 19:19:02.602297 140223994627904 run_af2c_mod.py:295] Total JAX model model_1_ptm_20220809_041046 predict time (includes compilation time): 17291.5 seconds
Info: test model_1_ptm_20220809_041046_recycled_00, tol = inf, pLDDT = 65.87, pTM-score = 0.3851, piTM-score = 0.3480, interface-score = 0.0989, iRes = 309 iCnt = 691
Info: test model_1_ptm_20220809_041046_recycled_01, tol = 52.18, pLDDT = 65.80, pTM-score = 0.3983, piTM-score = 0.4903, interface-score = 0.0917, iRes = 242 iCnt = 597
Info: test model_1_ptm_20220809_041046_recycled_02, tol = 16.17, pLDDT = 67.11, pTM-score = 0.4014, piTM-score = 0.1953, interface-score = 0.0152, iRes = 35 iCnt = 34
Info: test model_1_ptm_20220809_041046 performed 3 recycles, final tol = 11.02, pLDDT = 67.11, pTM-score = 0.4079, piTM-score = 0.1942, interface-score = 0.0172, iRes = 33 iCnt = 31
I0809 19:19:49.212634 140223994627904 run_af2c_mod.py:411] Final timings for test: {'features': 0.26, 'process_features_model_1_ptm_20220809_041046': 13.5, 'predict_and_compile_model_1_ptm_20220809_041046': 17291.5}

real 291m27.524s
user 334m57.277s
sys 91m36.683s

FreshAirTonight commented 2 years ago

It appears that you were running one model with only three recycles on an A100 workstation. If this is the case, you should get your results very quickly, in less than 15 minutes, for a target with ~1000 AAs. Three hours on an A100 is enough to model a 2,700 AA target for 6 recycles.
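For a sense of scale, here is a small Python sketch that pulls the prediction time out of the log line posted above and compares it with the 15-minute upper bound mentioned in this reply. The helper name `predict_seconds` and the regular expression are illustrative assumptions, not part of af2complex.

```python
import re

# Log line copied verbatim from the output posted in this thread.
LOG_LINE = (
    "I0809 19:19:02.602297 140223994627904 run_af2c_mod.py:295] "
    "Total JAX model model_1_ptm_20220809_041046 predict time "
    "(includes compilation time): 17291.5 seconds"
)

# Upper bound quoted in the reply for a ~1000 AA target (15 minutes).
EXPECTED_SECONDS = 15 * 60


def predict_seconds(line):
    """Return the predict time in seconds from a log line, or None if absent.

    Hypothetical helper: the pattern simply matches the wording seen above.
    """
    match = re.search(
        r"predict time \(includes compilation time\): ([\d.]+) seconds", line
    )
    return float(match.group(1)) if match else None


seconds = predict_seconds(LOG_LINE)
hours = seconds / 3600
slowdown = seconds / EXPECTED_SECONDS
print(f"predict time: {hours:.1f} h, at least {slowdown:.0f}x the expected bound")
```

A gap of this size is far more than ordinary run-to-run variation would explain.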

You may want to first check whether the run is actually on the GPU. A command such as nvidia-smi is handy for a quick check. GPU utilization during the run should be at or close to 100%.
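If you would rather script the check than watch nvidia-smi by hand, something along these lines might work. It is only a sketch: the function names and the 50% threshold are assumptions made for illustration. The query flags are standard nvidia-smi options that print one utilization percentage per GPU.

```python
import subprocess

# nvidia-smi query that prints one bare utilization percentage per GPU.
QUERY = [
    "nvidia-smi",
    "--query-gpu=utilization.gpu",
    "--format=csv,noheader,nounits",
]


def parse_utilization(text):
    """Convert the query output (one integer per line) into a list of ints.

    Lines that are not plain integers (for example "[N/A]") are skipped.
    """
    return [int(line.strip()) for line in text.splitlines() if line.strip().isdigit()]


def gpu_utilization():
    """Return per-GPU utilization in percent, or None if nvidia-smi is unusable."""
    try:
        result = subprocess.run(QUERY, capture_output=True, text=True, check=True)
    except (OSError, subprocess.CalledProcessError):
        return None
    return parse_utilization(result.stdout)


def looks_idle(samples, threshold=50):
    """True if every GPU is below the threshold (the threshold is an arbitrary choice)."""
    return bool(samples) and max(samples) < threshold


if __name__ == "__main__":
    samples = gpu_utilization()
    if samples is None:
        print("nvidia-smi not available on this machine")
    else:
        print("GPU utilization (%):", samples, "| idle:", looks_idle(samples))
```

Run it from a second terminal while the prediction is in progress. A single sample can be misleading, so it is worth repeating the call a few times.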

Qmi3 commented 2 years ago

Yeah, I found that the GPU utilization was too low, and I solved the problem by downgrading jax and jaxlib, inspired by the earlier issue "Very low gpu utilization when running examples" #3. Thank you for your reply!
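For anyone landing here with the same symptom, a quick way to see which jax and jaxlib versions are installed and which backend JAX actually picked is sketched below. The helper names are made up for this example. It relies only on `importlib.metadata`, `jax.default_backend()` and `jax.devices()`, and it does not say which versions to install (see issue #3 for that).

```python
from importlib import metadata


def installed_version(package):
    """Return the installed version string of a package, or None if it is missing."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def jax_report():
    """Collect jax/jaxlib versions plus the backend and devices JAX selected.

    Hypothetical helper for illustration; backend stays None if jax cannot be imported.
    """
    report = {
        "jax": installed_version("jax"),
        "jaxlib": installed_version("jaxlib"),
        "backend": None,
        "devices": [],
    }
    try:
        import jax
    except ImportError:
        return report
    report["backend"] = jax.default_backend()  # e.g. "cpu" or "gpu"
    report["devices"] = [str(device) for device in jax.devices()]
    return report


report = jax_report()
for key, value in report.items():
    print(f"{key}: {value}")
```

If the reported backend is "cpu" on a machine that has a GPU, the installed jaxlib is probably not a CUDA-enabled build or does not match the local CUDA setup.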