sokrypton / ColabFold

Making Protein folding accessible to all!
MIT License

Folding jobs fail repeatedly (one example of many with the same error) #378

Open ernst-schmid opened 1 year ago

ernst-schmid commented 1 year ago

Expected Behavior

Batch is able to run through all the queries in a CSV file

Current Behavior

The batch stops at certain sequences that trigger an internal error.

Input that caused the failure:

MSPVPSVYDLLLRLVGHQEVLKDPAQMGKAIETLGKQIEEYKPGSFSSLISGKINSFQFTIPSFVNLLVYFLIFYVSLLVVNKTTRVMLTLLKSLATVAILLLIVCMGVLLLLR:MSPVPSVYDLLLRLVGHQEVLKDPAQMGKAIETLGKQIEEYKPGSFSSLISGKINSFQFTIPSFVNLLVYFLIFYVSLLVVNKTTRVMLTLLKSLATVAILLLIVCMGVLLLLR

Steps to Reproduce (for bugs)

CONFIG:
{
  "num_queries": 124,
  "use_templates": true,
  "num_relax": 0,
  "msa_mode": "mmseqs2_uniref_env",
  "model_type": "alphafold2_multimer_v3",
  "num_models": 5,
  "num_recycles": 3,
  "recycle_early_stop_tolerance": null,
  "num_ensemble": 1,
  "model_order": [ 1, 2, 3, 5 ],
  "keep_existing_results": true,
  "rank_by": "multimer",
  "max_seq": 512,
  "max_extra_seq": 1024,
  "pair_mode": "unpaired_paired",
  "host_url": "https://api.colabfold.com",
  "stop_at_score": 100,
  "random_seed": 0,
  "num_seeds": 1,
  "recompile_padding": 10,
  "commit": "b4c1bc7cf89bc0bd577c5a9d3c1f7bedc1f74152",
  "use_dropout": true,
  "use_cluster_profile": true,
  "use_fuse": true,
  "use_bfloat16": true,
  "version": "1.5.1"
}

ColabFold Output (for bugs)

2023-02-07 15:41:10,513 Query 1/124: B6K772_SCHJYB6K772_SCHJY228aa (length 228) 2023-02-07 15:41:10,514 Sequence 0 found no templates 2023-02-07 15:41:10,518 Padding length to 238 Traceback (most recent call last): File "/home/ubuntu/colabfold_batch/colabfold-conda/bin/colabfold_batch", line 8, in sys.exit(main()) File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/colabfold/batch.py", line 1778, in main run( File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/colabfold/batch.py", line 1435, in run results = predict_structure( File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/colabfold/batch.py", line 447, in predict_structure model_runner.predict(input_features, random_seed=seed, prediction_callback=callback) File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/alphafold/model/model.py", line 218, in predict result = self.apply(self.params, key, sub_feat) File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback return fun(*args, kwargs) File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/jax/_src/api.py", line 622, in cache_miss execute = dispatch._xla_call_impllazy(fun, *tracers, params) File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/jax/_src/dispatch.py", line 236, in _xla_call_impl_lazy return xla_callable(fun, device, backend, name, donated_invars, keep_unused, File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/jax/linear_util.py", line 303, in memoized_fun ans = call(fun, args) File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/jax/_src/dispatch.py", line 359, in _xla_callable_uncached return lower_xla_callable(fun, device, backend, name, donated_invars, False, File 
"/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/jax/_src/profiler.py", line 314, in wrapper return func(args, kwargs) File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/jax/_src/dispatch.py", line 445, in lower_xla_callable jaxpr, out_type, consts = pe.trace_to_jaxpr_final2( File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/jax/_src/profiler.py", line 314, in wrapper return func(*args, kwargs) File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 2077, in trace_to_jaxpr_final2 jaxpr, out_type, consts = trace_to_subjaxpr_dynamic2(fun, main, debug_info) File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 2027, in trace_to_subjaxpr_dynamic2 ans = fun.call_wrapped(intracers) File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/jax/linear_util.py", line 167, in call_wrapped ans = self.f(args, dict(self.params, kwargs)) File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/haiku/_src/transform.py", line 128, in apply_fn out, state = f.apply(params, {}, *args, kwargs) File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/haiku/_src/transform.py", line 357, in apply_fn out = f(*args, *kwargs) File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/alphafold/model/model.py", line 98, in _forward_fn return model(batch, is_training=is_training) File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/haiku/_src/module.py", line 426, in wrapped out = f(args, kwargs) File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/contextlib.py", line 75, in inner return func(*args, *kwds) File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/haiku/_src/module.py", line 272, in run_interceptors return bound_method(args, 
**kwargs) File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/alphafold/model/modules_multimer.py", line 456, in call__ ret = apply_network(prev=batch.pop("prev"), safe_key=safe_key) File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/alphafold/model/modules_multimer.py", line 441, in apply_network return impl( File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/haiku/_src/module.py", line 426, in wrapped out = f(*args, kwargs) File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/contextlib.py", line 75, in inner return func(*args, *kwds) File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/haiku/_src/module.py", line 272, in run_interceptors return bound_method(args, kwargs) File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/alphafold/model/modules_multimer.py", line 314, in call repr_shape = hk.eval_shape( File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/haiku/_src/stateful.py", line 951, in eval_shape out_shape = jax.eval_shape(stateless_fun, internal_state(), *args, kwargs) File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/jax/_src/api.py", line 3201, in eval_shape out = pe.abstract_eval_fun(wrapped_fun.call_wrapped, File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 660, in abstract_evalfun , avalsout, = trace_to_jaxpr_dynamic( File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/jax/_src/profiler.py", line 314, in wrapper return func(*args, kwargs) File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1981, in trace_to_jaxpr_dynamic jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic( File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", 
line 1998, in trace_to_subjaxpr_dynamic ans = fun.call_wrapped(intracers) File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/jax/linear_util.py", line 167, in call_wrapped ans = self.f(args, dict(self.params, *kwargs)) File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/jax/linear_util.py", line 167, in call_wrapped ans = self.f(args, dict(self.params, kwargs)) File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/haiku/_src/stateful.py", line 947, in stateless_fun out = fun(*args, *kwargs) File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/alphafold/model/modules_multimer.py", line 315, in lambda: embedding_module(batch, is_training)) File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/haiku/_src/module.py", line 426, in wrapped out = f(args, kwargs) File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/contextlib.py", line 75, in inner return func(*args, kwds) File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/haiku/_src/module.py", line 272, in run_interceptors return bound_method(*args, *kwargs) File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/alphafold/model/modules_multimer.py", line 679, in call extra_evoformer_output, safe_key = extra_evoformer_stack( File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/alphafold/model/layer_stack.py", line 265, in wrapped ret = _LayerStackNoState( File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/haiku/_src/module.py", line 426, in wrapped out = f(args, kwargs) File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/contextlib.py", line 75, in inner return func(*args, *kwds) File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/haiku/_src/module.py", line 272, in run_interceptors return bound_method(args, **kwargs) File 
"/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/alphafold/model/layer_stack.py", line 156, in call__ carry, zs = hk.scan( File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/haiku/_src/stateful.py", line 640, in scan (carry, state), ys = jax.lax.scan( File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback return fun(*args, kwargs) File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/jax/_src/lax/control_flow/loops.py", line 259, in scan init_flat, carry_avals, carry_avals_out, init_tree, rest = _create_jaxpr(init) File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/jax/_src/lax/control_flow/loops.py", line 245, in _create_jaxpr jaxpr, consts, out_tree = _initial_style_jaxpr( File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/jax/_src/lax/control_flow/common.py", line 60, in _initial_style_jaxpr jaxpr, consts, out_tree = _initial_style_open_jaxpr( File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/jax/_src/lax/control_flow/common.py", line 54, in _initial_style_openjaxpr jaxpr, , consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals, debug) File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/jax/_src/profiler.py", line 314, in wrapper return func(args, kwargs) File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1981, in trace_to_jaxpr_dynamic jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic( File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1998, in trace_to_subjaxpr_dynamic ans = fun.call_wrapped(intracers) File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/jax/linear_util.py", line 167, in call_wrapped ans = 
self.f(args, dict(self.params, kwargs)) File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/haiku/_src/stateful.py", line 623, in stateful_fun carry, out = f(carry, x) File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/alphafold/model/layer_stack.py", line 149, in layer out_x, z = self._call_wrapped(carry.x, scanned.args_ys) File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/alphafold/model/layer_stack.py", line 182, in _call_wrapped ret = self._f(args) File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/alphafold/model/modules_multimer.py", line 665, in extra_evoformer_fn extra_evoformer_output = extra_evoformer_iteration( File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/haiku/_src/module.py", line 426, in wrapped out = f(*args, kwargs) File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/contextlib.py", line 75, in inner return func(*args, *kwds) File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/haiku/_src/module.py", line 272, in run_interceptors return bound_method(args, kwargs) File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/alphafold/model/modules.py", line 1734, in call msa_act = dropout_wrapper_fn( File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/alphafold/model/modules.py", line 76, in dropout_wrapper residual = module(input_act, mask, is_training=is_training, kwargs) File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/haiku/_src/module.py", line 426, in wrapped out = f(*args, *kwargs) File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/contextlib.py", line 75, in inner return func(args, kwds) File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/haiku/_src/module.py", line 272, in run_interceptors return bound_method(*args, kwargs) File 
"/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/alphafold/model/modules.py", line 820, in call msa_act = mapping.inference_subbatch( File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/alphafold/model/mapping.py", line 223, in inference_subbatch return sharded_module(batched_args) File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/alphafold/model/mapping.py", line 145, in mapped_fn remainder_shape_dtype = hk.eval_shape( File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/haiku/_src/stateful.py", line 951, in eval_shape out_shape = jax.eval_shape(stateless_fun, internal_state(), args, kwargs) File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/jax/_src/api.py", line 3201, in eval_shape out = pe.abstract_eval_fun(wrapped_fun.call_wrapped, File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 660, in abstract_evalfun , avalsout, = trace_to_jaxpr_dynamic( File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/jax/_src/profiler.py", line 314, in wrapper return func(*args, kwargs) File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1981, in trace_to_jaxpr_dynamic jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic( File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1998, in trace_to_subjaxpr_dynamic ans = fun.call_wrapped(intracers) File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/jax/linear_util.py", line 167, in call_wrapped ans = self.f(args, dict(self.params, kwargs)) File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/jax/linear_util.py", line 167, in call_wrapped ans = self.f(args, dict(self.params, kwargs)) File 
"/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/haiku/_src/stateful.py", line 947, in stateless_fun out = fun(args, kwargs) File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/alphafold/model/mapping.py", line 140, in apply_fun_to_slice input_slice = jax.tree_map( File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/jax/_src/tree_util.py", line 207, in tree_map return treedef.unflatten(f(xs) for xs in zip(all_leaves)) File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/jax/_src/tree_util.py", line 207, in return treedef.unflatten(f(xs) for xs in zip(all_leaves)) File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/alphafold/model/mapping.py", line 141, in lambda array, axis: _maybe_slice(array, slice_start, slice_size, axis File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/alphafold/model/mapping.py", line 38, in _maybe_slice return jax.lax.dynamic_slice_in_dim( File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/jax/_src/lax/slicing.py", line 695, in dynamic_slice_in_dim return dynamic_slice(operand, start_indices, slice_sizes) File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/jax/_src/lax/slicing.py", line 109, in dynamic_slice return dynamic_slice_p.bind(operand, start_indices, dynamic_sizes, File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/jax/core.py", line 329, in bind return self.bind_with_trace(find_top_trace(args), args, params) File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/jax/core.py", line 332, in bind_with_trace out = trace.process_primitive(self, map(trace.full_raise, args), params) File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1712, in process_primitive return custom_staging_rules[primitive](self, 
*tracers, params) File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/jax/_src/lax/slicing.py", line 911, in _dynamic_slice_staging_rule return trace.default_process_primitive(dynamic_slice_p, (x, start_indices), File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1717, in default_process_primitive out_avals, effects = primitive.abstract_eval(avals, params) File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/jax/core.py", line 365, in abstracteval return abstract_eval(*args, *kwargs), no_effects File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/jax/_src/lax/utils.py", line 66, in standard_abstract_eval return core.ShapedArray(shape_rule(avals, **kwargs), File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/jax/_src/lax/slicing.py", line 840, in _dynamic_slice_shape_rule raise TypeError(msg.format(slice_sizes, operand.shape)) jax._src.traceback_util.UnfilteredStackTrace: TypeError: slice slice_sizes must be less than or equal to operand shape, got slice_sizes (4, 228, 64) for operand shape (0, 228, 64).

The stack trace below excludes JAX-internal frames. The preceding is the original exception that occurred, unmodified.


The above exception was the direct cause of the following exception:

Traceback (most recent call last): File "/home/ubuntu/colabfold_batch/colabfold-conda/bin/colabfold_batch", line 8, in sys.exit(main()) File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/colabfold/batch.py", line 1778, in main run( File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/colabfold/batch.py", line 1435, in run results = predict_structure( File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/colabfold/batch.py", line 447, in predict_structure model_runner.predict(input_features, random_seed=seed, prediction_callback=callback) File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/alphafold/model/model.py", line 218, in predict result = self.apply(self.params, key, sub_feat) File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/haiku/_src/transform.py", line 128, in apply_fn out, state = f.apply(params, {}, *args, kwargs) File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/haiku/_src/transform.py", line 357, in apply_fn out = f(*args, *kwargs) File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/alphafold/model/model.py", line 98, in _forward_fn return model(batch, is_training=is_training) File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/haiku/_src/module.py", line 426, in wrapped out = f(args, kwargs) File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/contextlib.py", line 75, in inner return func(*args, kwds) File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/haiku/_src/module.py", line 272, in run_interceptors return bound_method(*args, *kwargs) File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/alphafold/model/modules_multimer.py", line 456, in call ret = apply_network(prev=batch.pop("prev"), safe_key=safe_key) File 
"/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/alphafold/model/modules_multimer.py", line 441, in apply_network return impl( File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/haiku/_src/module.py", line 426, in wrapped out = f(args, kwargs) File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/contextlib.py", line 75, in inner return func(*args, kwds) File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/haiku/_src/module.py", line 272, in run_interceptors return bound_method(*args, *kwargs) File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/alphafold/model/modules_multimer.py", line 314, in call repr_shape = hk.eval_shape( File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/haiku/_src/stateful.py", line 951, in eval_shape out_shape = jax.eval_shape(stateless_fun, internal_state(), args, kwargs) File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/haiku/_src/stateful.py", line 947, in stateless_fun out = fun(*args, kwargs) File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/alphafold/model/modules_multimer.py", line 315, in lambda: embedding_module(batch, is_training)) File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/haiku/_src/module.py", line 426, in wrapped out = f(*args, *kwargs) File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/contextlib.py", line 75, in inner return func(args, kwds) File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/haiku/_src/module.py", line 272, in run_interceptors return bound_method(*args, kwargs) File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/alphafold/model/modules_multimer.py", line 679, in call extra_evoformer_output, safe_key = extra_evoformer_stack( File 
"/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/alphafold/model/layer_stack.py", line 265, in wrapped ret = _LayerStackNoState( File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/haiku/_src/module.py", line 426, in wrapped out = f(*args, *kwargs) File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/contextlib.py", line 75, in inner return func(args, kwds) File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/haiku/_src/module.py", line 272, in run_interceptors return bound_method(*args, kwargs) File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/alphafold/model/layer_stack.py", line 156, in call carry, zs = hk.scan( File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/haiku/_src/stateful.py", line 640, in scan (carry, state), ys = jax.lax.scan( File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/haiku/_src/stateful.py", line 623, in stateful_fun carry, out = f(carry, x) File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/alphafold/model/layer_stack.py", line 149, in layer out_x, z = self._call_wrapped(carry.x, scanned.args_ys) File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/alphafold/model/layer_stack.py", line 182, in _call_wrapped ret = self._f(args) File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/alphafold/model/modules_multimer.py", line 665, in extra_evoformer_fn extra_evoformer_output = extra_evoformer_iteration( File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/haiku/_src/module.py", line 426, in wrapped out = f(*args, *kwargs) File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/contextlib.py", line 75, in inner return func(args, kwds) File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/haiku/_src/module.py", line 272, in 
run_interceptors return bound_method(args, kwargs) File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/alphafold/model/modules.py", line 1734, in call msa_act = dropout_wrapper_fn( File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/alphafold/model/modules.py", line 76, in dropout_wrapper residual = module(input_act, mask, is_training=is_training, kwargs) File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/haiku/_src/module.py", line 426, in wrapped out = f(args, kwargs) File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/contextlib.py", line 75, in inner return func(*args, *kwds) File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/haiku/_src/module.py", line 272, in run_interceptors return bound_method(args, kwargs) File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/alphafold/model/modules.py", line 820, in call msa_act = mapping.inference_subbatch( File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/alphafold/model/mapping.py", line 223, in inference_subbatch return sharded_module(batched_args) File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/alphafold/model/mapping.py", line 145, in mapped_fn remainder_shape_dtype = hk.eval_shape( File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/haiku/_src/stateful.py", line 951, in eval_shape out_shape = jax.eval_shape(stateless_fun, internal_state(), args, *kwargs) File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/haiku/_src/stateful.py", line 947, in stateless_fun out = fun(args, **kwargs) File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/alphafold/model/mapping.py", line 140, in apply_fun_to_slice input_slice = jax.tree_map( File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/alphafold/model/mapping.py", line 
141, in lambda array, axis: _maybe_slice(array, slice_start, slice_size, axis File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/alphafold/model/mapping.py", line 38, in _maybe_slice return jax.lax.dynamic_slice_in_dim( TypeError: slice slice_sizes must be less than or equal to operand shape, got slice_sizes (4, 228, 64) for operand shape (0, 228, 64).

CONTEXT

Running a series of jobs using a CSV of IDs and sequences as input, on an Ubuntu 20.04.5 LTS server with an A100 40GB GPU.

sokrypton commented 1 year ago

Can you post the flags you used to run the job? (just want to make sure it wasn't a bug in how we parse the flags).

ernst-schmid commented 1 year ago

Here they are:

--num-recycle 3 --templates --use-dropout --model-order 1,2,3,5 --max-seq 512 --max-extra-seq 1024

sokrypton commented 1 year ago

This looks like an AlphaFold bug. For multimer, max-seq cannot be larger than 508 unless the MSA is larger than 512.

This is because the input MSA is padded to 512 sequences: the first 508 go to the MSA input and the rest go to the extra MSA input. If you set max-seq larger than the number of sequences available, there are no sequences left for the extra MSA input.
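A minimal sketch of the split described above, for illustration only. The 512 and 508 figures are the ones quoted in this comment; `split_msa_rows` and `PADDED_MSA_ROWS` are hypothetical names, not AlphaFold or ColabFold code, and the real feature pipeline is more involved.

```python
from typing import Tuple

# Padded MSA depth quoted in the comment above (an assumption taken from this thread).
PADDED_MSA_ROWS = 512


def split_msa_rows(num_rows: int, max_seq: int) -> Tuple[int, int]:
    """Hypothetical helper: the first `max_seq` rows go to the MSA input,
    whatever remains goes to the extra MSA input."""
    main_rows = min(num_rows, max_seq)
    extra_rows = num_rows - main_rows
    return main_rows, extra_rows


for max_seq in (508, 512):
    main_rows, extra_rows = split_msa_rows(PADDED_MSA_ROWS, max_seq)
    status = "ok" if extra_rows > 0 else "extra MSA is empty"
    print(f"max_seq={max_seq}: main={main_rows}, extra={extra_rows} ({status})")
```

With `--max-seq 512` the sketch leaves zero rows for the extra MSA, which is consistent with the leading 0 in the operand shape (0, 228, 64) reported in the traceback.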

sokrypton commented 1 year ago

I'll see if I can fix this... So users don't have to worry about remembering the internal cutoffs used 😅

ernst-schmid commented 1 year ago

Ah ok, that makes sense. Thanks for your help with this!