Zuricho / ParallelFold

A modified version of AlphaFold that separates the CPU part (MSA and template searching) from the GPU part. This can accelerate AlphaFold when predicting multiple structures.
https://parafold.sjtu.edu.cn

Multimer prediction error: "ValueError: Unable to retrieve parameter 'scale' for module......." #38

Open tclin422 opened 11 months ago

tclin422 commented 11 months ago

Hi,

I've finished the feature step and got the output feature.pkl for the multimer prediction. When I then ran the prediction step, I got this error:

(parafold) [linx@localhost ParallelFold-main]$
I0803 08:18:26.620421 139735959648064 templates.py:857] Using precomputed obsolete pdbs /data/linx/02_Database/01_AF2/pdb_mmcif/obsolete.dat.
I0803 08:18:26.864619 139735959648064 xla_bridge.py:603] Unable to initialize backend 'rocm': NOT_FOUND: Could not find registered platform with name: "rocm". Available platform names are: Interpreter CUDA Host
I0803 08:18:26.865048 139735959648064 xla_bridge.py:603] Unable to initialize backend 'tpu': module 'jaxlib.xla_extension' has no attribute 'get_tpu_client'
I0803 08:18:30.334805 139735959648064 run_alphafold.py:445] Have 25 models: ['model_1_multimer_pred_0', 'model_1_multimer_pred_1', 'model_1_multimer_pred_2', 'model_1_multimer_pred_3', 'model_1_multimer_pred_4', 'model_2_multimer_pred_0', 'model_2_multimer_pred_1', 'model_2_multimer_pred_2', 'model_2_multimer_pred_3', 'model_2_multimer_pred_4', 'model_3_multimer_pred_0', 'model_3_multimer_pred_1', 'model_3_multimer_pred_2', 'model_3_multimer_pred_3', 'model_3_multimer_pred_4', 'model_4_multimer_pred_0', 'model_4_multimer_pred_1', 'model_4_multimer_pred_2', 'model_4_multimer_pred_3', 'model_4_multimer_pred_4', 'model_5_multimer_pred_0', 'model_5_multimer_pred_1', 'model_5_multimer_pred_2', 'model_5_multimer_pred_3', 'model_5_multimer_pred_4']
I0803 08:18:30.335292 139735959648064 run_alphafold.py:459] Using random seed 338821470640361521 for the data pipeline
I0803 08:18:30.335713 139735959648064 run_alphafold.py:189] Predicting R04373_impa
I0803 08:18:30.491963 139735959648064 run_alphafold.py:231] Running model model_1_multimer_pred_0 on R04373_impa
I0803 08:18:30.492612 139735959648064 model.py:165] Running predict with shape(feat) = {'aatype': (1787,), 'residue_index': (1787,), 'seq_length': (), 'msa': (3072, 1787), 'num_alignments': (), 'template_aatype': (4, 1787), 'template_all_atom_mask': (4, 1787, 37), 'template_all_atom_positions': (4, 1787, 37, 3), 'asym_id': (1787,), 'sym_id': (1787,), 'entity_id': (1787,), 'deletion_matrix': (3072, 1787), 'deletion_mean': (1787,), 'all_atom_mask': (1787, 37), 'all_atom_positions': (1787, 37, 3), 'assembly_num_chains': (), 'entity_mask': (1787,), 'num_templates': (), 'cluster_bias_mask': (3072,), 'bert_mask': (3072, 1787), 'seq_mask': (1787,), 'msa_mask': (3072, 1787)}
Traceback (most recent call last):
  File "/home/linx/software/02_parafold/ParallelFold-main/run_alphafold.py", line 491, in <module>
    app.run(main)
  File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/home/linx/software/02_parafold/ParallelFold-main/run_alphafold.py", line 464, in main
    predict_structure(
  File "/home/linx/software/02_parafold/ParallelFold-main/run_alphafold.py", line 239, in predict_structure
    prediction_result = model_runner.predict(processed_feature_dict,
  File "/home/linx/software/02_parafold/ParallelFold-main/alphafold/model/model.py", line 167, in predict
    result = self.apply(self.params, jax.random.PRNGKey(random_seed), feat)
  File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/jax/_src/pjit.py", line 250, in cache_miss
    outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper(
  File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/jax/_src/pjit.py", line 158, in _python_pjit_helper
    args_flat, _, params, in_tree, out_tree, _ = infer_params_fn(
  File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/jax/_src/api.py", line 306, in infer_params
    return pjit.common_infer_params(pjit_info_args, *args, **kwargs)
  File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/jax/_src/pjit.py", line 505, in common_infer_params
    jaxpr, consts, canonicalized_out_shardings_flat = _pjit_jaxpr(
  File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/jax/_src/pjit.py", line 971, in _pjit_jaxpr
    jaxpr, final_consts, out_type = _create_pjit_jaxpr(
  File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/jax/_src/linear_util.py", line 345, in memoized_fun
    ans = call(fun, *args)
  File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/jax/_src/pjit.py", line 924, in _create_pjit_jaxpr
    jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(
  File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/jax/_src/interpreters/partial_eval.py", line 2155, in trace_to_jaxpr_dynamic
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
  File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/jax/_src/interpreters/partial_eval.py", line 2177, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers)
  File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/jax/_src/linear_util.py", line 188, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/haiku/_src/transform.py", line 128, in apply_fn
    out, state = f.apply(params, {}, *args, **kwargs)
  File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/haiku/_src/transform.py", line 357, in apply_fn
    out = f(*args, **kwargs)
  File "/home/linx/software/02_parafold/ParallelFold-main/alphafold/model/model.py", line 77, in _forward_fn
    return model(
  File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/haiku/_src/module.py", line 426, in wrapped
    out = f(*args, **kwargs)
  File "/home/linx/miniconda3/envs/parafold/lib/python3.8/contextlib.py", line 75, in inner
    return func(*args, **kwds)
  File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/haiku/_src/module.py", line 272, in run_interceptors
    return bound_method(*args, **kwargs)
  File "/home/linx/software/02_parafold/ParallelFold-main/alphafold/model/modules_multimer.py", line 508, in __call__
    num_recycles, _, prev, safe_key = hk.while_loop(
  File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/haiku/_src/stateful.py", line 898, in while_loop
    val, state = jax.lax.while_loop(pure_cond_fun, pure_body_fun, init_val)
  File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/jax/_src/lax/control_flow/loops.py", line 1150, in while_loop
    init_vals, init_avals, body_jaxpr, in_tree, *rest = _create_jaxpr(init_val)
  File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/jax/_src/lax/control_flow/loops.py", line 1133, in _create_jaxpr
    body_jaxpr, body_consts, body_tree = _initial_style_jaxpr(
  File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/jax/_src/lax/control_flow/common.py", line 64, in _initial_style_jaxpr
    jaxpr, consts, out_tree = _initial_style_open_jaxpr(
  File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/jax/_src/lax/control_flow/common.py", line 58, in _initial_style_open_jaxpr
    jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals, debug)
  File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/jax/_src/interpreters/partial_eval.py", line 2155, in trace_to_jaxpr_dynamic
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
  File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/jax/_src/interpreters/partial_eval.py", line 2177, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers)
  File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/jax/_src/linear_util.py", line 188, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/haiku/_src/stateful.py", line 891, in pure_body_fun
    val = body_fun(val)
  File "/home/linx/software/02_parafold/ParallelFold-main/alphafold/model/modules_multimer.py", line 486, in recycle_body
    ret = apply_network(prev=prev, safe_key=safe_key2)
  File "/home/linx/software/02_parafold/ParallelFold-main/alphafold/model/modules_multimer.py", line 449, in apply_network
    return impl(
  File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/haiku/_src/module.py", line 426, in wrapped
    out = f(*args, **kwargs)
  File "/home/linx/miniconda3/envs/parafold/lib/python3.8/contextlib.py", line 75, in inner
    return func(*args, **kwds)
  File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/haiku/_src/module.py", line 272, in run_interceptors
    return bound_method(*args, **kwargs)
  File "/home/linx/software/02_parafold/ParallelFold-main/alphafold/model/modules_multimer.py", line 321, in __call__
    repr_shape = hk.eval_shape(
  File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/haiku/_src/stateful.py", line 951, in eval_shape
    out_shape = jax.eval_shape(stateless_fun, internal_state(), *args, **kwargs)
  File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/jax/_src/api.py", line 2783, in eval_shape
    out = pe.abstract_eval_fun(wrapped_fun.call_wrapped,
  File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/jax/_src/interpreters/partial_eval.py", line 670, in abstract_eval_fun
    _, avals_out, _ = trace_to_jaxpr_dynamic(
  File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/jax/_src/interpreters/partial_eval.py", line 2155, in trace_to_jaxpr_dynamic
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
  File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/jax/_src/interpreters/partial_eval.py", line 2177, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers)
  File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/jax/_src/linear_util.py", line 188, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/jax/_src/linear_util.py", line 188, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/haiku/_src/stateful.py", line 947, in stateless_fun
    out = fun(*args, **kwargs)
  File "/home/linx/software/02_parafold/ParallelFold-main/alphafold/model/modules_multimer.py", line 322, in <lambda>
    lambda: embedding_module(batch, is_training))
  File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/haiku/_src/module.py", line 426, in wrapped
    out = f(*args, **kwargs)
  File "/home/linx/miniconda3/envs/parafold/lib/python3.8/contextlib.py", line 75, in inner
    return func(*args, **kwds)
  File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/haiku/_src/module.py", line 272, in run_interceptors
    return bound_method(*args, **kwargs)
  File "/home/linx/software/02_parafold/ParallelFold-main/alphafold/model/modules_multimer.py", line 696, in __call__
    template_act = template_module(
  File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/haiku/_src/module.py", line 426, in wrapped
    out = f(*args, **kwargs)
  File "/home/linx/miniconda3/envs/parafold/lib/python3.8/contextlib.py", line 75, in inner
    return func(*args, **kwds)
  File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/haiku/_src/module.py", line 272, in run_interceptors
    return bound_method(*args, **kwargs)
  File "/home/linx/software/02_parafold/ParallelFold-main/alphafold/model/modules_multimer.py", line 883, in __call__
    summed_template_embeddings, _ = hk.scan(
  File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/haiku/_src/stateful.py", line 640, in scan
    (carry, state), ys = jax.lax.scan(
  File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/jax/_src/lax/control_flow/loops.py", line 250, in scan
    init_flat, carry_avals, carry_avals_out, init_tree, *rest = _create_jaxpr(init)
  File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/jax/_src/lax/control_flow/loops.py", line 236, in _create_jaxpr
    jaxpr, consts, out_tree = _initial_style_jaxpr(
  File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/jax/_src/lax/control_flow/common.py", line 64, in _initial_style_jaxpr
    jaxpr, consts, out_tree = _initial_style_open_jaxpr(
  File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/jax/_src/lax/control_flow/common.py", line 58, in _initial_style_open_jaxpr
    jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals, debug)
  File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/jax/_src/interpreters/partial_eval.py", line 2155, in trace_to_jaxpr_dynamic
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
  File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/jax/_src/interpreters/partial_eval.py", line 2177, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers)
  File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/jax/_src/linear_util.py", line 188, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/haiku/_src/stateful.py", line 623, in stateful_fun
    carry, out = f(carry, x)
  File "/home/linx/software/02_parafold/ParallelFold-main/alphafold/model/modules_multimer.py", line 879, in scan_fn
    return carry + partial_template_embedder(x), None
  File "/home/linx/software/02_parafold/ParallelFold-main/alphafold/model/modules_multimer.py", line 866, in partial_template_embedder
    return template_embedder(query_embedding,
  File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/haiku/_src/module.py", line 426, in wrapped
    out = f(*args, **kwargs)
  File "/home/linx/miniconda3/envs/parafold/lib/python3.8/contextlib.py", line 75, in inner
    return func(*args, **kwds)
  File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/haiku/_src/module.py", line 272, in run_interceptors
    return bound_method(*args, **kwargs)
  File "/home/linx/software/02_parafold/ParallelFold-main/alphafold/model/modules_multimer.py", line 1036, in __call__
    act, safe_key = template_stack((act, safe_subkey))
  File "/home/linx/software/02_parafold/ParallelFold-main/alphafold/model/layer_stack.py", line 265, in wrapped
    ret = _LayerStackNoState(
  File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/haiku/_src/module.py", line 426, in wrapped
    out = f(*args, **kwargs)
  File "/home/linx/miniconda3/envs/parafold/lib/python3.8/contextlib.py", line 75, in inner
    return func(*args, **kwds)
  File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/haiku/_src/module.py", line 272, in run_interceptors
    return bound_method(*args, **kwargs)
  File "/home/linx/software/02_parafold/ParallelFold-main/alphafold/model/layer_stack.py", line 156, in __call__
    carry, zs = hk.scan(
  File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/haiku/_src/stateful.py", line 640, in scan
    (carry, state), ys = jax.lax.scan(
  File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/jax/_src/lax/control_flow/loops.py", line 250, in scan
    init_flat, carry_avals, carry_avals_out, init_tree, *rest = _create_jaxpr(init)
  File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/jax/_src/lax/control_flow/loops.py", line 236, in _create_jaxpr
    jaxpr, consts, out_tree = _initial_style_jaxpr(
  File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/jax/_src/lax/control_flow/common.py", line 64, in _initial_style_jaxpr
    jaxpr, consts, out_tree = _initial_style_open_jaxpr(
  File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/jax/_src/lax/control_flow/common.py", line 58, in _initial_style_open_jaxpr
    jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals, debug)
  File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/jax/_src/interpreters/partial_eval.py", line 2155, in trace_to_jaxpr_dynamic
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
  File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/jax/_src/interpreters/partial_eval.py", line 2177, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers)
  File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/jax/_src/linear_util.py", line 188, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/haiku/_src/stateful.py", line 623, in stateful_fun
    carry, out = f(carry, x)
  File "/home/linx/software/02_parafold/ParallelFold-main/alphafold/model/layer_stack.py", line 149, in layer
    out_x, z = self._call_wrapped(carry.x, scanned.args_ys)
  File "/home/linx/software/02_parafold/ParallelFold-main/alphafold/model/layer_stack.py", line 182, in _call_wrapped
    ret = self._f(*args)
  File "/home/linx/software/02_parafold/ParallelFold-main/alphafold/model/modules_multimer.py", line 1022, in template_iteration_fn
    act = template_iteration(
  File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/haiku/_src/module.py", line 426, in wrapped
    out = f(*args, **kwargs)
  File "/home/linx/miniconda3/envs/parafold/lib/python3.8/contextlib.py", line 75, in inner
    return func(*args, **kwds)
  File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/haiku/_src/module.py", line 272, in run_interceptors
    return bound_method(*args, **kwargs)
  File "/home/linx/software/02_parafold/ParallelFold-main/alphafold/model/modules_multimer.py", line 1084, in __call__
    act = dropout_wrapper_fn(
  File "/home/linx/software/02_parafold/ParallelFold-main/alphafold/model/modules.py", line 76, in dropout_wrapper
    residual = module(input_act, mask, is_training=is_training, **kwargs)
  File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/haiku/_src/module.py", line 426, in wrapped
    out = f(*args, **kwargs)
  File "/home/linx/miniconda3/envs/parafold/lib/python3.8/contextlib.py", line 75, in inner
    return func(*args, **kwds)
  File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/haiku/_src/module.py", line 272, in run_interceptors
    return bound_method(*args, **kwargs)
  File "/home/linx/software/02_parafold/ParallelFold-main/alphafold/model/modules.py", line 1316, in __call__
    return self._triangle_multiplication(left_act, left_mask)
  File "/home/linx/software/02_parafold/ParallelFold-main/alphafold/model/modules.py", line 1326, in _triangle_multiplication
    act = common_modules.LayerNorm(axis=[-1], create_scale=True, create_offset=True,
  File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/haiku/_src/module.py", line 426, in wrapped
    out = f(*args, **kwargs)
  File "/home/linx/miniconda3/envs/parafold/lib/python3.8/contextlib.py", line 75, in inner
    return func(*args, **kwds)
  File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/haiku/_src/module.py", line 272, in run_interceptors
    return bound_method(*args, **kwargs)
  File "/home/linx/software/02_parafold/ParallelFold-main/alphafold/model/common_modules.py", line 176, in __call__
    scale = hk.get_parameter(
  File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/haiku/_src/base.py", line 448, in wrapped
    return wrapped._current(*args, **kwargs)
  File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/haiku/_src/base.py", line 524, in get_parameter
    raise ValueError(
jax._src.traceback_util.UnfilteredStackTrace: ValueError: Unable to retrieve parameter 'scale' for module 'alphafold/alphafold_iteration/evoformer/template_embedding/single_template_embedding/template_embedding_iteration/triangle_multiplication_outgoing/layer_norm_input' All parameters must be created as part of init.

The stack trace below excludes JAX-internal frames. The preceding is the original exception that occurred, unmodified.


The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/linx/software/02_parafold/ParallelFold-main/run_alphafold.py", line 491, in <module>
    app.run(main)
  File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/home/linx/software/02_parafold/ParallelFold-main/run_alphafold.py", line 464, in main
    predict_structure(
  File "/home/linx/software/02_parafold/ParallelFold-main/run_alphafold.py", line 239, in predict_structure
    prediction_result = model_runner.predict(processed_feature_dict,
  File "/home/linx/software/02_parafold/ParallelFold-main/alphafold/model/model.py", line 167, in predict
    result = self.apply(self.params, jax.random.PRNGKey(random_seed), feat)
  File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/haiku/_src/transform.py", line 128, in apply_fn
    out, state = f.apply(params, {}, *args, **kwargs)
  File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/haiku/_src/transform.py", line 357, in apply_fn
    out = f(*args, **kwargs)
  File "/home/linx/software/02_parafold/ParallelFold-main/alphafold/model/model.py", line 77, in _forward_fn
    return model(
  File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/haiku/_src/module.py", line 426, in wrapped
    out = f(*args, **kwargs)
  File "/home/linx/miniconda3/envs/parafold/lib/python3.8/contextlib.py", line 75, in inner
    return func(*args, **kwds)
  File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/haiku/_src/module.py", line 272, in run_interceptors
    return bound_method(*args, **kwargs)
  File "/home/linx/software/02_parafold/ParallelFold-main/alphafold/model/modules_multimer.py", line 508, in __call__
    num_recycles, _, prev, safe_key = hk.while_loop(
  File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/haiku/_src/stateful.py", line 898, in while_loop
    val, state = jax.lax.while_loop(pure_cond_fun, pure_body_fun, init_val)
  File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/haiku/_src/stateful.py", line 891, in pure_body_fun
    val = body_fun(val)
  File "/home/linx/software/02_parafold/ParallelFold-main/alphafold/model/modules_multimer.py", line 486, in recycle_body
    ret = apply_network(prev=prev, safe_key=safe_key2)
  File "/home/linx/software/02_parafold/ParallelFold-main/alphafold/model/modules_multimer.py", line 449, in apply_network
    return impl(
  File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/haiku/_src/module.py", line 426, in wrapped
    out = f(*args, **kwargs)
  File "/home/linx/miniconda3/envs/parafold/lib/python3.8/contextlib.py", line 75, in inner
    return func(*args, **kwds)
  File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/haiku/_src/module.py", line 272, in run_interceptors
    return bound_method(*args, **kwargs)
  File "/home/linx/software/02_parafold/ParallelFold-main/alphafold/model/modules_multimer.py", line 321, in __call__
    repr_shape = hk.eval_shape(
  File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/haiku/_src/stateful.py", line 951, in eval_shape
    out_shape = jax.eval_shape(stateless_fun, internal_state(), *args, **kwargs)
  File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/haiku/_src/stateful.py", line 947, in stateless_fun
    out = fun(*args, **kwargs)
  File "/home/linx/software/02_parafold/ParallelFold-main/alphafold/model/modules_multimer.py", line 322, in <lambda>
    lambda: embedding_module(batch, is_training))
  File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/haiku/_src/module.py", line 426, in wrapped
    out = f(*args, **kwargs)
  File "/home/linx/miniconda3/envs/parafold/lib/python3.8/contextlib.py", line 75, in inner
    return func(*args, **kwds)
  File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/haiku/_src/module.py", line 272, in run_interceptors
    return bound_method(*args, **kwargs)
  File "/home/linx/software/02_parafold/ParallelFold-main/alphafold/model/modules_multimer.py", line 696, in __call__
    template_act = template_module(
  File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/haiku/_src/module.py", line 426, in wrapped
    out = f(*args, **kwargs)
  File "/home/linx/miniconda3/envs/parafold/lib/python3.8/contextlib.py", line 75, in inner
    return func(*args, **kwds)
  File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/haiku/_src/module.py", line 272, in run_interceptors
    return bound_method(*args, **kwargs)
  File "/home/linx/software/02_parafold/ParallelFold-main/alphafold/model/modules_multimer.py", line 883, in __call__
    summed_template_embeddings, _ = hk.scan(
  File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/haiku/_src/stateful.py", line 640, in scan
    (carry, state), ys = jax.lax.scan(
  File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/haiku/_src/stateful.py", line 623, in stateful_fun
    carry, out = f(carry, x)
  File "/home/linx/software/02_parafold/ParallelFold-main/alphafold/model/modules_multimer.py", line 879, in scan_fn
    return carry + partial_template_embedder(x), None
  File "/home/linx/software/02_parafold/ParallelFold-main/alphafold/model/modules_multimer.py", line 866, in partial_template_embedder
    return template_embedder(query_embedding,
  File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/haiku/_src/module.py", line 426, in wrapped
    out = f(*args, **kwargs)
  File "/home/linx/miniconda3/envs/parafold/lib/python3.8/contextlib.py", line 75, in inner
    return func(*args, **kwds)
  File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/haiku/_src/module.py", line 272, in run_interceptors
    return bound_method(*args, **kwargs)
  File "/home/linx/software/02_parafold/ParallelFold-main/alphafold/model/modules_multimer.py", line 1036, in __call__
    act, safe_key = template_stack((act, safe_subkey))
  File "/home/linx/software/02_parafold/ParallelFold-main/alphafold/model/layer_stack.py", line 265, in wrapped
    ret = _LayerStackNoState(
  File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/haiku/_src/module.py", line 426, in wrapped
    out = f(*args, **kwargs)
  File "/home/linx/miniconda3/envs/parafold/lib/python3.8/contextlib.py", line 75, in inner
    return func(*args, **kwds)
  File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/haiku/_src/module.py", line 272, in run_interceptors
    return bound_method(*args, **kwargs)
  File "/home/linx/software/02_parafold/ParallelFold-main/alphafold/model/layer_stack.py", line 156, in __call__
    carry, zs = hk.scan(
  File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/haiku/_src/stateful.py", line 640, in scan
    (carry, state), ys = jax.lax.scan(
  File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/haiku/_src/stateful.py", line 623, in stateful_fun
    carry, out = f(carry, x)
  File "/home/linx/software/02_parafold/ParallelFold-main/alphafold/model/layer_stack.py", line 149, in layer
    out_x, z = self._call_wrapped(carry.x, scanned.args_ys)
  File "/home/linx/software/02_parafold/ParallelFold-main/alphafold/model/layer_stack.py", line 182, in _call_wrapped
    ret = self._f(*args)
  File "/home/linx/software/02_parafold/ParallelFold-main/alphafold/model/modules_multimer.py", line 1022, in template_iteration_fn
    act = template_iteration(
  File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/haiku/_src/module.py", line 426, in wrapped
    out = f(*args, **kwargs)
  File "/home/linx/miniconda3/envs/parafold/lib/python3.8/contextlib.py", line 75, in inner
    return func(*args, **kwds)
  File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/haiku/_src/module.py", line 272, in run_interceptors
    return bound_method(*args, **kwargs)
  File "/home/linx/software/02_parafold/ParallelFold-main/alphafold/model/modules_multimer.py", line 1084, in __call__
    act = dropout_wrapper_fn(
  File "/home/linx/software/02_parafold/ParallelFold-main/alphafold/model/modules.py", line 76, in dropout_wrapper
    residual = module(input_act, mask, is_training=is_training, **kwargs)
  File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/haiku/_src/module.py", line 426, in wrapped
    out = f(*args, **kwargs)
  File "/home/linx/miniconda3/envs/parafold/lib/python3.8/contextlib.py", line 75, in inner
    return func(*args, **kwds)
  File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/haiku/_src/module.py", line 272, in run_interceptors
    return bound_method(*args, **kwargs)
  File "/home/linx/software/02_parafold/ParallelFold-main/alphafold/model/modules.py", line 1316, in __call__
    return self._triangle_multiplication(left_act, left_mask)
  File "/home/linx/software/02_parafold/ParallelFold-main/alphafold/model/modules.py", line 1326, in _triangle_multiplication
    act = common_modules.LayerNorm(axis=[-1], create_scale=True, create_offset=True,
  File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/haiku/_src/module.py", line 426, in wrapped
    out = f(*args, **kwargs)
  File "/home/linx/miniconda3/envs/parafold/lib/python3.8/contextlib.py", line 75, in inner
    return func(*args, **kwds)
  File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/haiku/_src/module.py", line 272, in run_interceptors
    return bound_method(*args, **kwargs)
  File "/home/linx/software/02_parafold/ParallelFold-main/alphafold/model/common_modules.py", line 176, in __call__
    scale = hk.get_parameter(
  File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/haiku/_src/base.py", line 448, in wrapped
    return wrapped._current(*args, **kwargs)
  File "/home/linx/miniconda3/envs/parafold/lib/python3.8/site-packages/haiku/_src/base.py", line 524, in get_parameter
    raise ValueError(
ValueError: Unable to retrieve parameter 'scale' for module 'alphafold/alphafold_iteration/evoformer/template_embedding/single_template_embedding/template_embedding_iteration/triangle_multiplication_outgoing/layer_norm_input' All parameters must be created as part of init.
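The module named in the error, .../triangle_multiplication_outgoing/layer_norm_input, is the LayerNorm used by what AlphaFold 2.3's modules.py calls the unfused triangle multiplication (_triangle_multiplication). As far as I can tell from that source, the fused variant names its input LayerNorm left_norm_input instead, so a params file written for one layout has no layer_norm_input scope for the other to find. One way to check which layout a params .npz actually carries is to list its keys. This is a sketch under two assumptions: keys follow the 'module/path//param_name' convention that AlphaFold's flat_params_to_haiku splits on '//', and the helpers triangle_mult_layout and load_param_keys are hypothetical, not part of ParallelFold or AlphaFold.

```python
"""Report which TriangleMultiplication layout an AlphaFold params file uses.

Hypothetical helpers. Assumes keys look like 'module/path//param_name'.
"""
from typing import Iterable, List


def triangle_mult_layout(keys: Iterable[str]) -> str:
    """Return 'fused', 'unfused', 'mixed' or 'unknown' for the given param keys.

    Unfused triangle-multiplication layers own a LayerNorm scope ending in
    layer_norm_input (the scope the error above could not find); fused layers
    are assumed to use left_norm_input instead.
    """
    found = set()
    for key in keys:
        scope, _, _ = key.partition("//")  # drop the parameter name ('scale', 'offset', ...)
        if "triangle_multiplication" not in scope:
            continue
        last = scope.rsplit("/", 1)[-1]
        if last == "layer_norm_input":
            found.add("unfused")
        elif last == "left_norm_input":
            found.add("fused")
    if len(found) == 2:
        return "mixed"
    return found.pop() if found else "unknown"


def load_param_keys(npz_path: str) -> List[str]:
    """List the array names stored in an .npz file (requires NumPy)."""
    import numpy as np  # imported lazily so triangle_mult_layout works without NumPy

    with np.load(npz_path, allow_pickle=False) as data:
        return list(data.files)
```

For example, triangle_mult_layout(load_param_keys("params_model_1_multimer_v3.npz")) reports the layout stored in that file. If it says fused while the run loads that file under a model name whose config expects layer_norm_input, a weights/config mismatch like the one in this thread is a likely cause.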

Zuricho commented 11 months ago

How did you run the feature step? Could you provide the run_alphafold.sh command you used for the feature step?

tclin422 commented 11 months ago

[image not preserved]

gilspeyer commented 8 months ago

I am also getting this error. tclin422, did you find out what caused it?

von-elfen commented 8 months ago

I also got this error. Have you solved it? This error is so weird; I tried many solutions but none of them worked :(

luwei0917 commented 1 month ago

I got the same error and solved it. It happened because I was passing '-m model_1_multimer,model_2_multimer,model_3_multimer,model_4_multimer,model_5_multimer', when I should have used '-m model_1_multimer_v3,model_2_multimer_v3,model_3_multimer_v3,model_4_multimer_v3,model_5_multimer_v3'. I had also created a soft link from the model params_model_1_multimer_v3.npz to params_model_1_multimer.npz (which I shouldn't have done).