Modified version of AlphaFold that separates the CPU part (MSA and template searching) from the GPU part. This can accelerate AlphaFold when predicting multiple structures.
Hi,
After installation, the "CPU part" (jackhmmer and hhblits) works well.
But when I start the GPU part, I get this error message:
TypeError: take requires ndarray or scalar arguments, got <class 'list'> at position 0.
Full error message:
File "/softwares/alphafold/run_alphafold.py", line 316, in <module>
app.run(main)
File "/softwares/alphafold/envs/alphafold/lib/python3.8/site-packages/absl/app.py", line 312, in run
_run_main(main, args)
File "/softwares/alphafold/envs/alphafold/lib/python3.8/site-packages/absl/app.py", line 258, in _run_main
sys.exit(main(argv))
File "/softwares/alphafold/run_alphafold.py", line 289, in main
predict_structure(
File "/softwares/alphafold/run_alphafold.py", line 188, in predict_structure
relaxed_pdbstr, _, _ = amber_relaxer.process(prot=unrelaxed_protein)
File "/softwares/alphafold/alphafold/relax/relax.py", line 58, in process
out = amber_minimize.run_pipeline(
File "/softwares/alphafold/alphafold/relax/amber_minimize.py", line 482, in run_pipeline
ret.update(get_violation_metrics(prot))
File "/softwares/alphafold/alphafold/relax/amber_minimize.py", line 356, in get_violation_metrics
structural_violations, struct_metrics = find_violations(prot)
File "/softwares/alphafold/alphafold/relax/amber_minimize.py", line 338, in find_violations
violations = folding.find_structural_violations(
File "/softwares/alphafold/alphafold/model/folding.py", line 757, in find_structural_violations
atom14_atom_radius = batch['atom14_atom_exists'] * utils.batched_gather(
File "/softwares/alphafold/alphafold/model/utils.py", line 39, in batched_gather
return take_fn(params, indices)
File "/softwares/alphafold/alphafold/model/utils.py", line 36, in <lambda>
take_fn = lambda p, i: jnp.take(p, i, axis=axis)
File "/softwares/alphafold/envs/alphafold/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 5383, in take
return _take(a, indices, None if axis is None else operator.index(axis), out,
File "/softwares/alphafold/envs/alphafold/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/softwares/alphafold/envs/alphafold/lib/python3.8/site-packages/jax/_src/api.py", line 411, in cache_miss
out_flat = xla.xla_call(
File "/softwares/alphafold/envs/alphafold/lib/python3.8/site-packages/jax/core.py", line 1618, in bind
return call_bind(self, fun, *args, **params)
File "/softwares/alphafold/envs/alphafold/lib/python3.8/site-packages/jax/core.py", line 1609, in call_bind
outs = primitive.process(top_trace, fun, tracers, params)
File "/softwares/alphafold/envs/alphafold/lib/python3.8/site-packages/jax/core.py", line 1621, in process
return trace.process_call(self, fun, tracers, params)
File "/softwares/alphafold/envs/alphafold/lib/python3.8/site-packages/jax/core.py", line 615, in process_call
return primitive.impl(f, *tracers, **params)
File "/softwares/alphafold/envs/alphafold/lib/python3.8/site-packages/jax/interpreters/xla.py", line 622, in _xla_call_impl
compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,
File "/softwares/alphafold/envs/alphafold/lib/python3.8/site-packages/jax/linear_util.py", line 262, in memoized_fun
ans = call(fun, *args)
File "/softwares/alphafold/envs/alphafold/lib/python3.8/site-packages/jax/interpreters/xla.py", line 694, in _xla_callable
return lower_xla_callable(fun, device, backend, name, donated_invars, *arg_specs).compile().unsafe_call
File "/softwares/alphafold/envs/alphafold/lib/python3.8/site-packages/jax/interpreters/xla.py", line 702, in lower_xla_callable
jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(
File "/softwares/alphafold/envs/alphafold/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1522, in trace_to_jaxpr_final
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
File "/softwares/alphafold/envs/alphafold/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1500, in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers)
File "/softwares/alphafold/envs/alphafold/lib/python3.8/site-packages/jax/linear_util.py", line 166, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/softwares/alphafold/envs/alphafold/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 5390, in _take
_check_arraylike("take", a)
File "/softwares/alphafold/envs/alphafold/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 559, in _check_arraylike
raise TypeError(msg.format(fun_name, type(arg), pos))
jax._src.traceback_util.UnfilteredStackTrace: TypeError: take requires ndarray or scalar arguments, got <class 'list'> at position 0.
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/softwares/alphafold/run_alphafold.py", line 316, in <module>
app.run(main)
File "/softwares/alphafold/envs/alphafold/lib/python3.8/site-packages/absl/app.py", line 312, in run
_run_main(main, args)
File "/softwares/alphafold/envs/alphafold/lib/python3.8/site-packages/absl/app.py", line 258, in _run_main
sys.exit(main(argv))
File "/softwares/alphafold/run_alphafold.py", line 289, in main
predict_structure(
File "/softwares/alphafold/run_alphafold.py", line 188, in predict_structure
relaxed_pdbstr, _, _ = amber_relaxer.process(prot=unrelaxed_protein)
File "/softwares/alphafold/alphafold/relax/relax.py", line 58, in process
out = amber_minimize.run_pipeline(
File "/softwares/alphafold/alphafold/relax/amber_minimize.py", line 482, in run_pipeline
ret.update(get_violation_metrics(prot))
File "/softwares/alphafold/alphafold/relax/amber_minimize.py", line 356, in get_violation_metrics
structural_violations, struct_metrics = find_violations(prot)
File "/softwares/alphafold/alphafold/relax/amber_minimize.py", line 338, in find_violations
violations = folding.find_structural_violations(
File "/softwares/alphafold/alphafold/model/folding.py", line 757, in find_structural_violations
atom14_atom_radius = batch['atom14_atom_exists'] * utils.batched_gather(
File "/softwares/alphafold/alphafold/model/utils.py", line 39, in batched_gather
return take_fn(params, indices)
File "/softwares/alphafold/alphafold/model/utils.py", line 36, in <lambda>
take_fn = lambda p, i: jnp.take(p, i, axis=axis)
File "/softwares/alphafold/envs/alphafold/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 5383, in take
return _take(a, indices, None if axis is None else operator.index(axis), out,
File "/softwares/alphafold/envs/alphafold/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 5390, in _take
_check_arraylike("take", a)
File "/softwares/alphafold/envs/alphafold/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 559, in _check_arraylike
raise TypeError(msg.format(fun_name, type(arg), pos))
TypeError: take requires ndarray or scalar arguments, got <class 'list'> at position 0.
Hi, after installation, the "CPU part" (jackhmmer and hhblits) works well. But when I start the GPU part, I get this error message: TypeError: take requires ndarray or scalar arguments, got <class 'list'> at position 0.
1st part: ./run_feature.sh -d data -o ./tmp -m model_1,model_2,model_3,model_4,model_5 -f ./query/1crn.fasta -t 2021-07-27
2nd part: ./run_alphafold.sh -d data -o ./tmp -m model_1,model_2,model_3,model_4,model_5 -f ./query/1crn.fasta -t 2021-07-27
Full error message:
File "/softwares/alphafold/run_alphafold.py", line 316, in <module>
app.run(main)
File "/softwares/alphafold/envs/alphafold/lib/python3.8/site-packages/absl/app.py", line 312, in run
_run_main(main, args)
File "/softwares/alphafold/envs/alphafold/lib/python3.8/site-packages/absl/app.py", line 258, in _run_main
sys.exit(main(argv))
File "/softwares/alphafold/run_alphafold.py", line 289, in main
predict_structure(
File "/softwares/alphafold/run_alphafold.py", line 188, in predict_structure
relaxed_pdbstr, _, _ = amber_relaxer.process(prot=unrelaxed_protein)
File "/softwares/alphafold/alphafold/relax/relax.py", line 58, in process
out = amber_minimize.run_pipeline(
File "/softwares/alphafold/alphafold/relax/amber_minimize.py", line 482, in run_pipeline
ret.update(get_violation_metrics(prot))
File "/softwares/alphafold/alphafold/relax/amber_minimize.py", line 356, in get_violation_metrics
structural_violations, struct_metrics = find_violations(prot)
File "/softwares/alphafold/alphafold/relax/amber_minimize.py", line 338, in find_violations
violations = folding.find_structural_violations(
File "/softwares/alphafold/alphafold/model/folding.py", line 757, in find_structural_violations
atom14_atom_radius = batch['atom14_atom_exists'] * utils.batched_gather(
File "/softwares/alphafold/alphafold/model/utils.py", line 39, in batched_gather
return take_fn(params, indices)
File "/softwares/alphafold/alphafold/model/utils.py", line 36, in <lambda>
take_fn = lambda p, i: jnp.take(p, i, axis=axis)
File "/softwares/alphafold/envs/alphafold/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 5383, in take
return _take(a, indices, None if axis is None else operator.index(axis), out,
File "/softwares/alphafold/envs/alphafold/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/softwares/alphafold/envs/alphafold/lib/python3.8/site-packages/jax/_src/api.py", line 411, in cache_miss
out_flat = xla.xla_call(
File "/softwares/alphafold/envs/alphafold/lib/python3.8/site-packages/jax/core.py", line 1618, in bind
return call_bind(self, fun, *args, **params)
File "/softwares/alphafold/envs/alphafold/lib/python3.8/site-packages/jax/core.py", line 1609, in call_bind
outs = primitive.process(top_trace, fun, tracers, params)
File "/softwares/alphafold/envs/alphafold/lib/python3.8/site-packages/jax/core.py", line 1621, in process
return trace.process_call(self, fun, tracers, params)
File "/softwares/alphafold/envs/alphafold/lib/python3.8/site-packages/jax/core.py", line 615, in process_call
return primitive.impl(f, *tracers, **params)
File "/softwares/alphafold/envs/alphafold/lib/python3.8/site-packages/jax/interpreters/xla.py", line 622, in _xla_call_impl
compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,
File "/softwares/alphafold/envs/alphafold/lib/python3.8/site-packages/jax/linear_util.py", line 262, in memoized_fun
ans = call(fun, *args)
File "/softwares/alphafold/envs/alphafold/lib/python3.8/site-packages/jax/interpreters/xla.py", line 694, in _xla_callable
return lower_xla_callable(fun, device, backend, name, donated_invars, *arg_specs).compile().unsafe_call
File "/softwares/alphafold/envs/alphafold/lib/python3.8/site-packages/jax/interpreters/xla.py", line 702, in lower_xla_callable
jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(
File "/softwares/alphafold/envs/alphafold/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1522, in trace_to_jaxpr_final
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
File "/softwares/alphafold/envs/alphafold/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1500, in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers)
File "/softwares/alphafold/envs/alphafold/lib/python3.8/site-packages/jax/linear_util.py", line 166, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/softwares/alphafold/envs/alphafold/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 5390, in _take
_check_arraylike("take", a)
File "/softwares/alphafold/envs/alphafold/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 559, in _check_arraylike
raise TypeError(msg.format(fun_name, type(arg), pos))
jax._src.traceback_util.UnfilteredStackTrace: TypeError: take requires ndarray or scalar arguments, got <class 'list'> at position 0.
The stack trace below excludes JAX-internal frames. The preceding is the original exception that occurred, unmodified.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/softwares/alphafold/run_alphafold.py", line 316, in <module>
app.run(main)
File "/softwares/alphafold/envs/alphafold/lib/python3.8/site-packages/absl/app.py", line 312, in run
_run_main(main, args)
File "/softwares/alphafold/envs/alphafold/lib/python3.8/site-packages/absl/app.py", line 258, in _run_main
sys.exit(main(argv))
File "/softwares/alphafold/run_alphafold.py", line 289, in main
predict_structure(
File "/softwares/alphafold/run_alphafold.py", line 188, in predict_structure
relaxed_pdbstr, _, _ = amber_relaxer.process(prot=unrelaxed_protein)
File "/softwares/alphafold/alphafold/relax/relax.py", line 58, in process
out = amber_minimize.run_pipeline(
File "/softwares/alphafold/alphafold/relax/amber_minimize.py", line 482, in run_pipeline
ret.update(get_violation_metrics(prot))
File "/softwares/alphafold/alphafold/relax/amber_minimize.py", line 356, in get_violation_metrics
structural_violations, struct_metrics = find_violations(prot)
File "/softwares/alphafold/alphafold/relax/amber_minimize.py", line 338, in find_violations
violations = folding.find_structural_violations(
File "/softwares/alphafold/alphafold/model/folding.py", line 757, in find_structural_violations
atom14_atom_radius = batch['atom14_atom_exists'] * utils.batched_gather(
File "/softwares/alphafold/alphafold/model/utils.py", line 39, in batched_gather
return take_fn(params, indices)
File "/softwares/alphafold/alphafold/model/utils.py", line 36, in <lambda>
take_fn = lambda p, i: jnp.take(p, i, axis=axis)
File "/softwares/alphafold/envs/alphafold/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 5383, in take
return _take(a, indices, None if axis is None else operator.index(axis), out,
File "/softwares/alphafold/envs/alphafold/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 5390, in _take
_check_arraylike("take", a)
File "/softwares/alphafold/envs/alphafold/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 559, in _check_arraylike
raise TypeError(msg.format(fun_name, type(arg), pos))
TypeError: take requires ndarray or scalar arguments, got <class 'list'> at position 0.