JeffSHF / ColabDock

Code for ColabDock paper

'Config' object has no attribute 'jax_experimental_name_stack', and should I install AlphaFold2? #3

Closed kehan777 closed 1 year ago

kehan777 commented 1 year ago

Hi, I ran a script copied from the Google Colab notebook (python colabdocktest.py). Here is the config message: {'template': {'pdb_path': './4HFF/PDB/4HFF.pdb', 'chains': 'A,B'}, 'native': {'pdb_path': './4HFF/PDB/4HFF.pdb', 'chains': 'A,B'}, 'fixed_chains': None, 'res_thres': 8.0, 'non_thres': 12.0, 'rest_1v1': array([[ 78, 198]]), 'rest_1vN': None, 'rest_MvN': None, 'rest_non': None, 'save_path': './results', 'crop_len': None, 'rounds': 2, 'steps': 50, 'save_every_n_step': 1, 'bfloat': True, 'data_dir': '/home/adsb/ColabDock/params'}

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)

Traceback (most recent call last): File "/home/adsb/ColabDock/colabdocktest.py", line 312, in dock_model.dock_rank() File "/home/adsb/miniconda3/envs/colabdock/lib/python3.10/site-packages/colabdock/model.py", line 108, in dock_rank self.optimize(ith) File "/home/adsb/miniconda3/envs/colabdock/lib/python3.10/site-packages/colabdock/docking.py", line 51, in optimize af_model.design(self.step_num, soft=0, temp=1, hard=0, File "/home/adsb/miniconda3/envs/colabdock/lib/python3.10/site-packages/colabdesign/af/design.py", line 367, in design self.step(lr_scale=lr_scale, backprop=backprop, crop=crop, File "/home/adsb/miniconda3/envs/colabdock/lib/python3.10/site-packages/colabdesign/af/design.py", line 264, in step self.run(model=model, backprop=backprop, crop=crop, callback=callback) File "/home/adsb/miniconda3/envs/colabdock/lib/python3.10/site-packages/colabdesign/af/design.py", line 175, in run outs.append(self._recycle(p, backprop=backprop)) File "/home/adsb/miniconda3/envs/colabdock/lib/python3.10/site-packages/colabdesign/af/design.py", line 230, in _recycle out = self._single(model_params, backprop) File "/home/adsb/miniconda3/envs/colabdock/lib/python3.10/site-packages/colabdesign/af/design.py", line 218, in _single (loss, aux), grad = self._grad_fn(flags) File "/home/adsb/miniconda3/envs/colabdock/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback return fun(args, kwargs) File "/home/adsb/miniconda3/envs/colabdock/lib/python3.10/site-packages/jax/_src/pjit.py", line 250, in cache_miss outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper( File "/home/adsb/miniconda3/envs/colabdock/lib/python3.10/site-packages/jax/_src/pjit.py", line 158, in _python_pjit_helper argsflat, , params, in_tree, outtree, = infer_params_fn( File "/home/adsb/miniconda3/envs/colabdock/lib/python3.10/site-packages/jax/_src/api.py", line 306, in infer_params return pjit.common_infer_params(pjit_info_args, *args, kwargs) File 
"/home/adsb/miniconda3/envs/colabdock/lib/python3.10/site-packages/jax/_src/pjit.py", line 505, in common_infer_params jaxpr, consts, canonicalized_out_shardings_flat = _pjit_jaxpr( File "/home/adsb/miniconda3/envs/colabdock/lib/python3.10/site-packages/jax/_src/pjit.py", line 971, in _pjit_jaxpr jaxpr, final_consts, out_type = _create_pjit_jaxpr( File "/home/adsb/miniconda3/envs/colabdock/lib/python3.10/site-packages/jax/_src/linear_util.py", line 345, in memoized_fun ans = call(fun, args) File "/home/adsb/miniconda3/envs/colabdock/lib/python3.10/site-packages/jax/_src/pjit.py", line 924, in _create_pjit_jaxpr jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic( File "/home/adsb/miniconda3/envs/colabdock/lib/python3.10/site-packages/jax/_src/profiler.py", line 314, in wrapper return func(args, kwargs) File "/home/adsb/miniconda3/envs/colabdock/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py", line 2155, in trace_to_jaxpr_dynamic jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic( File "/home/adsb/miniconda3/envs/colabdock/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py", line 2177, in trace_to_subjaxpr_dynamic ans = fun.call_wrapped(intracers) File "/home/adsb/miniconda3/envs/colabdock/lib/python3.10/site-packages/jax/_src/linear_util.py", line 188, in call_wrapped ans = self.f(args, dict(self.params, kwargs)) File "/home/adsb/miniconda3/envs/colabdock/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback return fun(*args, kwargs) File "/home/adsb/miniconda3/envs/colabdock/lib/python3.10/site-packages/jax/_src/api.py", line 725, in value_and_grad_f ans, vjp_py, aux = _vjp( File "/home/adsb/miniconda3/envs/colabdock/lib/python3.10/site-packages/jax/_src/api.py", line 2213, in _vjp out_primal, out_vjp, aux = ad.vjp( File "/home/adsb/miniconda3/envs/colabdock/lib/python3.10/site-packages/jax/_src/interpreters/ad.py", line 141, in vjp out_primals, pvals, jaxpr, consts, aux 
= linearize(traceable, primals, has_aux=True) File "/home/adsb/miniconda3/envs/colabdock/lib/python3.10/site-packages/jax/_src/interpreters/ad.py", line 128, in linearize jaxpr, out_pvals, consts = pe.trace_to_jaxpr_nounits(jvpfun_flat, in_pvals) File "/home/adsb/miniconda3/envs/colabdock/lib/python3.10/site-packages/jax/_src/profiler.py", line 314, in wrapper return func(args, kwargs) File "/home/adsb/miniconda3/envs/colabdock/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py", line 777, in trace_to_jaxpr_nounits jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals) File "/home/adsb/miniconda3/envs/colabdock/lib/python3.10/site-packages/jax/_src/linear_util.py", line 188, in call_wrapped ans = self.f(args, dict(self.params, kwargs)) File "/home/adsb/miniconda3/envs/colabdock/lib/python3.10/site-packages/colabdesign/af/model.py", line 218, in _model outputs = self._runner.apply(model_params, key(), inputs) File "/home/adsb/miniconda3/envs/colabdock/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback return fun(args, kwargs) File "/home/adsb/miniconda3/envs/colabdock/lib/python3.10/site-packages/jax/_src/pjit.py", line 250, in cache_miss outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper( File "/home/adsb/miniconda3/envs/colabdock/lib/python3.10/site-packages/jax/_src/pjit.py", line 158, in _python_pjit_helper argsflat, , params, in_tree, outtree, = infer_params_fn( File "/home/adsb/miniconda3/envs/colabdock/lib/python3.10/site-packages/jax/_src/api.py", line 306, in infer_params return pjit.common_infer_params(pjit_info_args, *args, kwargs) File "/home/adsb/miniconda3/envs/colabdock/lib/python3.10/site-packages/jax/_src/pjit.py", line 505, in common_infer_params jaxpr, consts, canonicalized_out_shardings_flat = _pjit_jaxpr( File "/home/adsb/miniconda3/envs/colabdock/lib/python3.10/site-packages/jax/_src/pjit.py", line 971, in _pjit_jaxpr jaxpr, final_consts, out_type = 
_create_pjit_jaxpr( File "/home/adsb/miniconda3/envs/colabdock/lib/python3.10/site-packages/jax/_src/linear_util.py", line 345, in memoized_fun ans = call(fun, args) File "/home/adsb/miniconda3/envs/colabdock/lib/python3.10/site-packages/jax/_src/pjit.py", line 924, in _create_pjit_jaxpr jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic( File "/home/adsb/miniconda3/envs/colabdock/lib/python3.10/site-packages/jax/_src/profiler.py", line 314, in wrapper return func(args, kwargs) File "/home/adsb/miniconda3/envs/colabdock/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py", line 2155, in trace_to_jaxpr_dynamic jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic( File "/home/adsb/miniconda3/envs/colabdock/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py", line 2177, in trace_to_subjaxpr_dynamic ans = fun.call_wrapped(intracers) File "/home/adsb/miniconda3/envs/colabdock/lib/python3.10/site-packages/jax/_src/linear_util.py", line 188, in call_wrapped ans = self.f(args, dict(self.params, kwargs)) File "/home/adsb/miniconda3/envs/colabdock/lib/python3.10/site-packages/colabdesign/af/alphafold/model/model.py", line 96, in apply results = self.apply_fn(params, key, feat) File "/home/adsb/miniconda3/envs/colabdock/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback return fun(*args, kwargs) File "/home/adsb/miniconda3/envs/colabdock/lib/python3.10/site-packages/jax/_src/pjit.py", line 250, in cache_miss outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper( File "/home/adsb/miniconda3/envs/colabdock/lib/python3.10/site-packages/jax/_src/pjit.py", line 158, in _python_pjit_helper argsflat, , params, in_tree, outtree, = infer_params_fn( File "/home/adsb/miniconda3/envs/colabdock/lib/python3.10/site-packages/jax/_src/api.py", line 306, in infer_params return pjit.common_infer_params(pjit_info_args, *args, kwargs) File 
"/home/adsb/miniconda3/envs/colabdock/lib/python3.10/site-packages/jax/_src/pjit.py", line 505, in common_infer_params jaxpr, consts, canonicalized_out_shardings_flat = _pjit_jaxpr( File "/home/adsb/miniconda3/envs/colabdock/lib/python3.10/site-packages/jax/_src/pjit.py", line 971, in _pjit_jaxpr jaxpr, final_consts, out_type = _create_pjit_jaxpr( File "/home/adsb/miniconda3/envs/colabdock/lib/python3.10/site-packages/jax/_src/linear_util.py", line 345, in memoized_fun ans = call(fun, args) File "/home/adsb/miniconda3/envs/colabdock/lib/python3.10/site-packages/jax/_src/pjit.py", line 924, in _create_pjit_jaxpr jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic( File "/home/adsb/miniconda3/envs/colabdock/lib/python3.10/site-packages/jax/_src/profiler.py", line 314, in wrapper return func(args, kwargs) File "/home/adsb/miniconda3/envs/colabdock/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py", line 2155, in trace_to_jaxpr_dynamic jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic( File "/home/adsb/miniconda3/envs/colabdock/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py", line 2177, in trace_to_subjaxpr_dynamic ans = fun.call_wrapped(intracers) File "/home/adsb/miniconda3/envs/colabdock/lib/python3.10/site-packages/jax/_src/linear_util.py", line 188, in call_wrapped ans = self.f(args, dict(self.params, kwargs)) File "/home/adsb/miniconda3/envs/colabdock/lib/python3.10/site-packages/haiku/_src/transform.py", line 128, in apply_fn out, state = f.apply(params, {}, *args, kwargs) File "/home/adsb/miniconda3/envs/colabdock/lib/python3.10/site-packages/haiku/_src/transform.py", line 357, in apply_fn out = f(*args, *kwargs) File "/home/adsb/miniconda3/envs/colabdock/lib/python3.10/site-packages/colabdesign/af/alphafold/model/model.py", line 51, in _forward_fn model = modules.AlphaFold(self.config.model) File "/home/adsb/miniconda3/envs/colabdock/lib/python3.10/site-packages/haiku/_src/module.py", line 123, in call 
init(module, args, **kwargs) File "/home/adsb/miniconda3/envs/colabdock/lib/python3.10/site-packages/haiku/_src/module.py", line 421, in wrapped if jax.config.jax_experimental_name_stack and module_name: jax._src.traceback_util.UnfilteredStackTrace: AttributeError: 'Config' object has no attribute 'jax_experimental_name_stack'

The stack trace below excludes JAX-internal frames. The preceding is the original exception that occurred, unmodified.


The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/adsb/ColabDock/colabdocktest.py", line 312, in <module>
    dock_model.dock_rank()
  File "/home/adsb/miniconda3/envs/colabdock/lib/python3.10/site-packages/colabdock/model.py", line 108, in dock_rank
    self.optimize(ith)
  File "/home/adsb/miniconda3/envs/colabdock/lib/python3.10/site-packages/colabdock/docking.py", line 51, in optimize
    af_model.design(self.step_num, soft=0, temp=1, hard=0,
  File "/home/adsb/miniconda3/envs/colabdock/lib/python3.10/site-packages/colabdesign/af/design.py", line 367, in design
    self.step(lr_scale=lr_scale, backprop=backprop, crop=crop,
  File "/home/adsb/miniconda3/envs/colabdock/lib/python3.10/site-packages/colabdesign/af/design.py", line 264, in step
    self.run(model=model, backprop=backprop, crop=crop, callback=callback)
  File "/home/adsb/miniconda3/envs/colabdock/lib/python3.10/site-packages/colabdesign/af/design.py", line 175, in run
    outs.append(self._recycle(p, backprop=backprop))
  File "/home/adsb/miniconda3/envs/colabdock/lib/python3.10/site-packages/colabdesign/af/design.py", line 230, in _recycle
    out = self._single(model_params, backprop)
  File "/home/adsb/miniconda3/envs/colabdock/lib/python3.10/site-packages/colabdesign/af/design.py", line 218, in _single
    (loss, aux), grad = self._grad_fn(flags)
  File "/home/adsb/miniconda3/envs/colabdock/lib/python3.10/site-packages/colabdesign/af/model.py", line 218, in _model
    outputs = self._runner.apply(model_params, key(), inputs)
  File "/home/adsb/miniconda3/envs/colabdock/lib/python3.10/site-packages/colabdesign/af/alphafold/model/model.py", line 96, in apply
    results = self.apply_fn(params, key, feat)
  File "/home/adsb/miniconda3/envs/colabdock/lib/python3.10/site-packages/haiku/_src/transform.py", line 128, in apply_fn
    out, state = f.apply(params, {}, *args, **kwargs)
  File "/home/adsb/miniconda3/envs/colabdock/lib/python3.10/site-packages/haiku/_src/transform.py", line 357, in apply_fn
    out = f(*args, **kwargs)
  File "/home/adsb/miniconda3/envs/colabdock/lib/python3.10/site-packages/colabdesign/af/alphafold/model/model.py", line 51, in _forward_fn
    model = modules.AlphaFold(self.config.model)
  File "/home/adsb/miniconda3/envs/colabdock/lib/python3.10/site-packages/haiku/_src/module.py", line 123, in __call__
    init(module, *args, **kwargs)
  File "/home/adsb/miniconda3/envs/colabdock/lib/python3.10/site-packages/haiku/_src/module.py", line 421, in wrapped
    if jax.config.jax_experimental_name_stack and module_name:
AttributeError: 'Config' object has no attribute 'jax_experimental_name_stack'

JeffSHF commented 1 year ago

Hi, there is no need to install AF2; only the AF2 params are needed. Could you manually install the dependencies listed in the requirements.txt file? I haven't tested whether the installation code in the Colab notebook works locally.

Also, a GPU is recommended; otherwise the run will take a very long time.
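The AttributeError in the traceback is raised inside haiku while it reads an attribute of jax.config, which suggests (though the thread does not confirm it) that the installed dm-haiku and jax versions do not match. As a minimal sketch, the installed versions can be listed and compared by hand against the pins in requirements.txt. The helper name installed_versions is hypothetical, and the package list is an assumption about which distributions matter here.

```python
from importlib import metadata


def installed_versions(packages):
    """Map each distribution name to its installed version, or None if absent."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    # Assumed list of relevant packages; compare the output with requirements.txt.
    for name, version in installed_versions(["jax", "jaxlib", "dm-haiku"]).items():
        print(f"{name}: {version or 'not installed'}")
```

Run this inside the colabdock conda environment so that it reports the packages the script actually imports.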

kehan777 commented 1 year ago

Could you provide an example script for setting constraints, for instance for 1AHW antigen-antibody docking?
1. Restrict the CDRs to possible contact residues.
2. Block amino acids that lie more than 5 Angstrom away from the CDRs.
3. Is ColabDock rigid docking in principle, or can some regions, such as the antibody HCDR3, be set to be flexible?
Thank you very much for your help~

JeffSHF commented 1 year ago

For the script, I suggest using the prep_pdb function (https://github.com/JeffSHF/ColabDock/blob/2ed118c15f4eda16b259dad2ae23349594af3523/colabdesign/af/prep.py#L226) provided in colabdesign to get the coordinates of the beta-carbon (Cβ) atoms. You can then calculate the distance matrix of the protein and retrieve the restraints you want. As for the third question, the structure generated by ColabDock is not necessarily the same as the provided template: the program may modify the template to satisfy the given restraints.
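A minimal sketch of the distance-matrix step described above, assuming the Cβ coordinates have already been extracted (for example with prep_pdb) into a plain list of (x, y, z) tuples with one chain label per residue. The function names, the 0-based indexing, and the use of 8.0 as a cutoff (borrowed from res_thres in the config above) are assumptions for illustration, not ColabDock's own API. For a full-size structure NumPy would be the natural choice; the standard library is used here only to keep the example dependency-free.

```python
import math


def cb_distance_matrix(cb_coords):
    """Pairwise C-beta distances (in angstroms) as a list of lists."""
    return [[math.dist(a, b) for b in cb_coords] for a in cb_coords]


def interchain_pairs(cb_coords, chain_ids, cutoff=8.0):
    """Return 0-based (i, j) pairs, i < j, on different chains closer than cutoff."""
    dist = cb_distance_matrix(cb_coords)
    pairs = []
    for i in range(len(cb_coords)):
        for j in range(i + 1, len(cb_coords)):
            if chain_ids[i] != chain_ids[j] and dist[i][j] < cutoff:
                pairs.append((i, j))
    return pairs


# Toy data: two residues on chain A, two on chain B, all on the x axis.
toy_coords = [(0.0, 0.0, 0.0), (3.0, 0.0, 0.0), (5.0, 0.0, 0.0), (30.0, 0.0, 0.0)]
toy_chains = ["A", "A", "B", "B"]
toy_pairs = interchain_pairs(toy_coords, toy_chains)
```

Each returned pair is a candidate 1v1 restraint. Check ColabDock's residue numbering convention before copying indices into the config, since the thread does not state whether restraints are 0-based or 1-based.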

kehan777 commented 1 year ago

https://www.biorxiv.org/content/10.1101/2023.07.04.547599v1.full For 1AHW ... randomly sample 5 interface residues (His91, Gly92 for the light chain, Asp31, Tyr32, Asn100 for the heavy chain) from the antibody and 7 interface residues (Lys165, Thr167, Val192, Thr197, Val198, Asn199, Asp204) from the TF. These sampled residues in the antibody are mainly distributed on L1 CDR, H1 CDR, and H3 CDR loops. The sampled residues in TF are very close to important residues (Tyr156/157, Lys165/166, Lys169, Arg200, and Lys201) revealed by the previous mutation assay;

How should the MvN restraint be set in this case? I used the following setting, but it failed (RMSD = 20.4): rest_MvN = "[[91,(370-424)],[92,(370-424)],[139,(370-424)],[140,(370-424)],[208,(370-424)],2]"

JeffSHF commented 1 year ago

Basically, your setting is right. But note that in "[[91,(370-424)],[92,(370-424)],[139,(370-424)],[140,(370-424)],[208,(370-424)],2]", the 2 means that ColabDock only optimizes two of the five provided 1vN restraints. If you are sure that all five 1vN restraints are right, your rest_MvN should be "[[91,(370-424)],[92,(370-424)],[139,(370-424)],[140,(370-424)],[208,(370-424)],5]".
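To make the role of the trailing number explicit, here is a small sketch that assembles a rest_MvN string in the format shown above from a list of residues, one shared target range, and the number of 1vN restraints to enforce. The helper build_rest_mvn and its parameter names are hypothetical; only the string layout is taken from this thread.

```python
def build_rest_mvn(residues, target_start, target_end, n_enforced):
    """Build a rest_MvN string: one 1vN entry per residue, then the count to enforce."""
    if not 1 <= n_enforced <= len(residues):
        raise ValueError("n_enforced must be between 1 and the number of residues")
    entries = [f"[{res},({target_start}-{target_end})]" for res in residues]
    return "[" + ",".join(entries) + f",{n_enforced}]"


# Enforce all five 1vN restraints, as suggested in the reply above.
rest_MvN = build_rest_mvn([91, 92, 139, 140, 208], 370, 424, 5)
```

Passing 2 instead of 5 as the last argument reproduces the original setting, in which only two of the five restraints are optimized.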

kehan777 commented 1 year ago

Thank you for your patient answers~ 👍