ernst-schmid opened 1 year ago
Can you post the flags you used to run the job? (I just want to make sure it isn't a bug in how we parse the flags.)
Here they are:
--num-recycle 3 --templates --use-dropout --model-order 1,2,3,5 --max-seq 512 --max-extra-seq 1024
This looks like an AlphaFold bug. For multimer models, max-seq cannot be larger than 508 unless the MSA contains more than 512 sequences.
This is because the input MSA is padded to 512 sequences: the first 508 go to the MSA input and the remaining 4 go to the extra MSA input. If max-seq is set at or above the number of sequences, no sequences are left for the extra MSA input.
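To make the arithmetic concrete, here is a simplified pure-Python model of that split. It assumes, as described above, that the multimer MSA is padded to 512 rows and that the first max-seq rows feed the MSA input while the rest feed the extra MSA input. `split_msa_rows` is a hypothetical name for illustration, not a ColabFold or AlphaFold function.

```python
from typing import Tuple


def split_msa_rows(padded_depth: int, max_seq: int) -> Tuple[int, int]:
    """Return (main_msa_rows, extra_msa_rows) for a padded MSA.

    Simplified model (hypothetical, not ColabFold/AlphaFold code): the first
    `max_seq` rows feed the main MSA input and any remaining rows feed the
    extra MSA input.
    """
    main_rows = min(max_seq, padded_depth)
    extra_rows = padded_depth - main_rows
    return main_rows, extra_rows


PADDED_DEPTH = 512  # multimer MSA padded to 512 rows, per the explanation above

for max_seq in (508, 512):
    main_rows, extra_rows = split_msa_rows(PADDED_DEPTH, max_seq)
    print(f"max_seq={max_seq}: main={main_rows}, extra={extra_rows}")
# max_seq=508: main=508, extra=4
# max_seq=512: main=512, extra=0
```

The second case matches the zero-row operand, shape (0, 228, 64), in the traceback further down.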
I'll see if I can fix this so users don't have to remember the internal cutoffs 😅
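One possible shape for such a fix, sketched under assumptions: clamp max-seq so the extra MSA input always keeps a few rows. `clamp_max_seq` and its `min_extra_rows` parameter are hypothetical, not existing ColabFold options; the default of 4 simply reproduces the 508 cutoff for the 512-row MSA described above.

```python
import warnings


def clamp_max_seq(max_seq: int, msa_depth: int, min_extra_rows: int = 4) -> int:
    """Hypothetical guard (not part of ColabFold).

    Clamp `max_seq` so that at least `min_extra_rows` rows of the
    (padded) MSA of depth `msa_depth` are left for the extra MSA input.
    """
    limit = msa_depth - min_extra_rows
    if max_seq > limit:
        warnings.warn(
            f"max_seq={max_seq} leaves fewer than {min_extra_rows} rows for the "
            f"extra MSA (depth {msa_depth}); clamping to {limit}."
        )
        return limit
    return max_seq
```

For example, `clamp_max_seq(512, 512)` warns and returns 508, while `clamp_max_seq(256, 512)` returns 256 unchanged. Warning and clamping, rather than raising, would let a long batch run continue past such queries.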
Ah, OK, that makes sense. Thanks for your help with this!
Expected Behavior
colabfold_batch runs through all the queries in the CSV file.
Current Behavior
colabfold_batch stops at certain queries that trigger an internal error.
Input that caused the failure:
MSPVPSVYDLLLRLVGHQEVLKDPAQMGKAIETLGKQIEEYKPGSFSSLISGKINSFQFTIPSFVNLLVYFLIFYVSLLVVNKTTRVMLTLLKSLATVAILLLIVCMGVLLLLR:MSPVPSVYDLLLRLVGHQEVLKDPAQMGKAIETLGKQIEEYKPGSFSSLISGKINSFQFTIPSFVNLLVYFLIFYVSLLVVNKTTRVMLTLLKSLATVAILLLIVCMGVLLLLR
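ColabFold uses `:` to separate the chains of a complex, so the input above is a homodimer of two identical chains. A plain-Python check (nothing ColabFold-specific) confirms that the two chains add up to the `length 228` reported in the log:

```python
# Query string from the failing CSV row; chains are separated by ":"
query = (
    "MSPVPSVYDLLLRLVGHQEVLKDPAQMGKAIETLGKQIEEYKPGSFSSLISGKINSFQFTIPSFVNLLVYFLIFYVSLLVVNKTTRVMLTLLKSLATVAILLLIVCMGVLLLLR"
    ":"
    "MSPVPSVYDLLLRLVGHQEVLKDPAQMGKAIETLGKQIEEYKPGSFSSLISGKINSFQFTIPSFVNLLVYFLIFYVSLLVVNKTTRVMLTLLKSLATVAILLLIVCMGVLLLLR"
)

chains = query.split(":")
chain_lengths = [len(chain) for chain in chains]
is_homodimer = len(chains) == 2 and chains[0] == chains[1]

print(chain_lengths, sum(chain_lengths), is_homodimer)
# [114, 114] 228 True
```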
Steps to Reproduce (for bugs)
CONFIG:
{
  "num_queries": 124,
  "use_templates": true,
  "num_relax": 0,
  "msa_mode": "mmseqs2_uniref_env",
  "model_type": "alphafold2_multimer_v3",
  "num_models": 5,
  "num_recycles": 3,
  "recycle_early_stop_tolerance": null,
  "num_ensemble": 1,
  "model_order": [1, 2, 3, 5],
  "keep_existing_results": true,
  "rank_by": "multimer",
  "max_seq": 512,
  "max_extra_seq": 1024,
  "pair_mode": "unpaired_paired",
  "host_url": "https://api.colabfold.com",
  "stop_at_score": 100,
  "random_seed": 0,
  "num_seeds": 1,
  "recompile_padding": 10,
  "commit": "b4c1bc7cf89bc0bd577c5a9d3c1f7bedc1f74152",
  "use_dropout": true,
  "use_cluster_profile": true,
  "use_fuse": true,
  "use_bfloat16": true,
  "version": "1.5.1"
}
ColabFold Output (for bugs)
2023-02-07 15:41:10,513 Query 1/124: B6K772_SCHJYB6K772_SCHJY228aa (length 228)
2023-02-07 15:41:10,514 Sequence 0 found no templates
2023-02-07 15:41:10,518 Padding length to 238
Traceback (most recent call last):
  File "/home/ubuntu/colabfold_batch/colabfold-conda/bin/colabfold_batch", line 8, in <module>
    sys.exit(main())
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/colabfold/batch.py", line 1778, in main
    run(
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/colabfold/batch.py", line 1435, in run
    results = predict_structure(
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/colabfold/batch.py", line 447, in predict_structure
    model_runner.predict(input_features, random_seed=seed, prediction_callback=callback)
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/alphafold/model/model.py", line 218, in predict
    result = self.apply(self.params, key, sub_feat)
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/jax/_src/api.py", line 622, in cache_miss
    execute = dispatch._xla_call_impl_lazy(fun_, *tracers, **params)
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/jax/_src/dispatch.py", line 236, in _xla_call_impl_lazy
    return xla_callable(fun, device, backend, name, donated_invars, keep_unused,
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/jax/linear_util.py", line 303, in memoized_fun
    ans = call(fun, *args)
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/jax/_src/dispatch.py", line 359, in _xla_callable_uncached
    return lower_xla_callable(fun, device, backend, name, donated_invars, False,
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/jax/_src/dispatch.py", line 445, in lower_xla_callable
    jaxpr, out_type, consts = pe.trace_to_jaxpr_final2(
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 2077, in trace_to_jaxpr_final2
    jaxpr, out_type, consts = trace_to_subjaxpr_dynamic2(fun, main, debug_info)
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 2027, in trace_to_subjaxpr_dynamic2
    ans = fun.call_wrapped(*in_tracers_)
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/jax/linear_util.py", line 167, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/haiku/_src/transform.py", line 128, in apply_fn
    out, state = f.apply(params, {}, *args, **kwargs)
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/haiku/_src/transform.py", line 357, in apply_fn
    out = f(*args, **kwargs)
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/alphafold/model/model.py", line 98, in _forward_fn
    return model(batch, is_training=is_training)
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/haiku/_src/module.py", line 426, in wrapped
    out = f(*args, **kwargs)
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/contextlib.py", line 75, in inner
    return func(*args, **kwds)
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/haiku/_src/module.py", line 272, in run_interceptors
    return bound_method(*args, **kwargs)
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/alphafold/model/modules_multimer.py", line 456, in __call__
    ret = apply_network(prev=batch.pop("prev"), safe_key=safe_key)
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/alphafold/model/modules_multimer.py", line 441, in apply_network
    return impl(
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/haiku/_src/module.py", line 426, in wrapped
    out = f(*args, **kwargs)
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/contextlib.py", line 75, in inner
    return func(*args, **kwds)
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/haiku/_src/module.py", line 272, in run_interceptors
    return bound_method(*args, **kwargs)
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/alphafold/model/modules_multimer.py", line 314, in __call__
    repr_shape = hk.eval_shape(
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/haiku/_src/stateful.py", line 951, in eval_shape
    out_shape = jax.eval_shape(stateless_fun, internal_state(), *args, **kwargs)
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/jax/_src/api.py", line 3201, in eval_shape
    out = pe.abstract_eval_fun(wrapped_fun.call_wrapped,
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 660, in abstract_eval_fun
    _, avals_out, _ = trace_to_jaxpr_dynamic(
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1981, in trace_to_jaxpr_dynamic
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1998, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers_)
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/jax/linear_util.py", line 167, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/jax/linear_util.py", line 167, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/haiku/_src/stateful.py", line 947, in stateless_fun
    out = fun(*args, **kwargs)
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/alphafold/model/modules_multimer.py", line 315, in <lambda>
    lambda: embedding_module(batch, is_training))
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/haiku/_src/module.py", line 426, in wrapped
    out = f(*args, **kwargs)
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/contextlib.py", line 75, in inner
    return func(*args, **kwds)
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/haiku/_src/module.py", line 272, in run_interceptors
    return bound_method(*args, **kwargs)
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/alphafold/model/modules_multimer.py", line 679, in __call__
    extra_evoformer_output, safe_key = extra_evoformer_stack(
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/alphafold/model/layer_stack.py", line 265, in wrapped
    ret = _LayerStackNoState(
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/haiku/_src/module.py", line 426, in wrapped
    out = f(*args, **kwargs)
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/contextlib.py", line 75, in inner
    return func(*args, **kwds)
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/haiku/_src/module.py", line 272, in run_interceptors
    return bound_method(*args, **kwargs)
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/alphafold/model/layer_stack.py", line 156, in __call__
    carry, zs = hk.scan(
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/haiku/_src/stateful.py", line 640, in scan
    (carry, state), ys = jax.lax.scan(
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/jax/_src/lax/control_flow/loops.py", line 259, in scan
    init_flat, carry_avals, carry_avals_out, init_tree, *rest = _create_jaxpr(init)
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/jax/_src/lax/control_flow/loops.py", line 245, in _create_jaxpr
    jaxpr, consts, out_tree = _initial_style_jaxpr(
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/jax/_src/lax/control_flow/common.py", line 60, in _initial_style_jaxpr
    jaxpr, consts, out_tree = _initial_style_open_jaxpr(
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/jax/_src/lax/control_flow/common.py", line 54, in _initial_style_open_jaxpr
    jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals, debug)
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1981, in trace_to_jaxpr_dynamic
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1998, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers_)
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/jax/linear_util.py", line 167, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/haiku/_src/stateful.py", line 623, in stateful_fun
    carry, out = f(carry, x)
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/alphafold/model/layer_stack.py", line 149, in layer
    out_x, z = self._call_wrapped(carry.x, scanned.args_ys)
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/alphafold/model/layer_stack.py", line 182, in _call_wrapped
    ret = self._f(*args)
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/alphafold/model/modules_multimer.py", line 665, in extra_evoformer_fn
    extra_evoformer_output = extra_evoformer_iteration(
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/haiku/_src/module.py", line 426, in wrapped
    out = f(*args, **kwargs)
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/contextlib.py", line 75, in inner
    return func(*args, **kwds)
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/haiku/_src/module.py", line 272, in run_interceptors
    return bound_method(*args, **kwargs)
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/alphafold/model/modules.py", line 1734, in __call__
    msa_act = dropout_wrapper_fn(
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/alphafold/model/modules.py", line 76, in dropout_wrapper
    residual = module(input_act, mask, is_training=is_training, **kwargs)
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/haiku/_src/module.py", line 426, in wrapped
    out = f(*args, **kwargs)
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/contextlib.py", line 75, in inner
    return func(*args, **kwds)
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/haiku/_src/module.py", line 272, in run_interceptors
    return bound_method(*args, **kwargs)
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/alphafold/model/modules.py", line 820, in __call__
    msa_act = mapping.inference_subbatch(
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/alphafold/model/mapping.py", line 223, in inference_subbatch
    return sharded_module(*batched_args)
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/alphafold/model/mapping.py", line 145, in mapped_fn
    remainder_shape_dtype = hk.eval_shape(
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/haiku/_src/stateful.py", line 951, in eval_shape
    out_shape = jax.eval_shape(stateless_fun, internal_state(), *args, **kwargs)
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/jax/_src/api.py", line 3201, in eval_shape
    out = pe.abstract_eval_fun(wrapped_fun.call_wrapped,
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 660, in abstract_eval_fun
    _, avals_out, _ = trace_to_jaxpr_dynamic(
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1981, in trace_to_jaxpr_dynamic
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1998, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers_)
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/jax/linear_util.py", line 167, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/jax/linear_util.py", line 167, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/haiku/_src/stateful.py", line 947, in stateless_fun
    out = fun(*args, **kwargs)
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/alphafold/model/mapping.py", line 140, in apply_fun_to_slice
    input_slice = jax.tree_map(
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/jax/_src/tree_util.py", line 207, in tree_map
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/jax/_src/tree_util.py", line 207, in <genexpr>
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/alphafold/model/mapping.py", line 141, in <lambda>
    lambda array, axis: _maybe_slice(array, slice_start, slice_size, axis
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/alphafold/model/mapping.py", line 38, in _maybe_slice
    return jax.lax.dynamic_slice_in_dim(
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/jax/_src/lax/slicing.py", line 695, in dynamic_slice_in_dim
    return dynamic_slice(operand, start_indices, slice_sizes)
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/jax/_src/lax/slicing.py", line 109, in dynamic_slice
    return dynamic_slice_p.bind(operand, *start_indices, *dynamic_sizes,
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/jax/core.py", line 329, in bind
    return self.bind_with_trace(find_top_trace(args), args, params)
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/jax/core.py", line 332, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1712, in process_primitive
    return custom_staging_rules[primitive](self, *tracers, **params)
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/jax/_src/lax/slicing.py", line 911, in _dynamic_slice_staging_rule
    return trace.default_process_primitive(dynamic_slice_p, (x, *start_indices),
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1717, in default_process_primitive
    out_avals, effects = primitive.abstract_eval(*avals, **params)
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/jax/core.py", line 365, in abstract_eval_
    return abstract_eval(*args, **kwargs), no_effects
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/jax/_src/lax/utils.py", line 66, in standard_abstract_eval
    return core.ShapedArray(shape_rule(*avals, **kwargs),
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/jax/_src/lax/slicing.py", line 840, in _dynamic_slice_shape_rule
    raise TypeError(msg.format(slice_sizes, operand.shape))
jax._src.traceback_util.UnfilteredStackTrace: TypeError: slice slice_sizes must be less than or equal to operand shape, got slice_sizes (4, 228, 64) for operand shape (0, 228, 64).
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
  File "/home/ubuntu/colabfold_batch/colabfold-conda/bin/colabfold_batch", line 8, in <module>
    sys.exit(main())
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/colabfold/batch.py", line 1778, in main
    run(
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/colabfold/batch.py", line 1435, in run
    results = predict_structure(
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/colabfold/batch.py", line 447, in predict_structure
    model_runner.predict(input_features, random_seed=seed, prediction_callback=callback)
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/alphafold/model/model.py", line 218, in predict
    result = self.apply(self.params, key, sub_feat)
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/haiku/_src/transform.py", line 128, in apply_fn
    out, state = f.apply(params, {}, *args, **kwargs)
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/haiku/_src/transform.py", line 357, in apply_fn
    out = f(*args, **kwargs)
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/alphafold/model/model.py", line 98, in _forward_fn
    return model(batch, is_training=is_training)
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/haiku/_src/module.py", line 426, in wrapped
    out = f(*args, **kwargs)
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/contextlib.py", line 75, in inner
    return func(*args, **kwds)
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/haiku/_src/module.py", line 272, in run_interceptors
    return bound_method(*args, **kwargs)
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/alphafold/model/modules_multimer.py", line 456, in __call__
    ret = apply_network(prev=batch.pop("prev"), safe_key=safe_key)
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/alphafold/model/modules_multimer.py", line 441, in apply_network
    return impl(
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/haiku/_src/module.py", line 426, in wrapped
    out = f(*args, **kwargs)
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/contextlib.py", line 75, in inner
    return func(*args, **kwds)
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/haiku/_src/module.py", line 272, in run_interceptors
    return bound_method(*args, **kwargs)
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/alphafold/model/modules_multimer.py", line 314, in __call__
    repr_shape = hk.eval_shape(
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/haiku/_src/stateful.py", line 951, in eval_shape
    out_shape = jax.eval_shape(stateless_fun, internal_state(), *args, **kwargs)
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/haiku/_src/stateful.py", line 947, in stateless_fun
    out = fun(*args, **kwargs)
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/alphafold/model/modules_multimer.py", line 315, in <lambda>
    lambda: embedding_module(batch, is_training))
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/haiku/_src/module.py", line 426, in wrapped
    out = f(*args, **kwargs)
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/contextlib.py", line 75, in inner
    return func(*args, **kwds)
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/haiku/_src/module.py", line 272, in run_interceptors
    return bound_method(*args, **kwargs)
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/alphafold/model/modules_multimer.py", line 679, in __call__
    extra_evoformer_output, safe_key = extra_evoformer_stack(
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/alphafold/model/layer_stack.py", line 265, in wrapped
    ret = _LayerStackNoState(
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/haiku/_src/module.py", line 426, in wrapped
    out = f(*args, **kwargs)
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/contextlib.py", line 75, in inner
    return func(*args, **kwds)
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/haiku/_src/module.py", line 272, in run_interceptors
    return bound_method(*args, **kwargs)
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/alphafold/model/layer_stack.py", line 156, in __call__
    carry, zs = hk.scan(
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/haiku/_src/stateful.py", line 640, in scan
    (carry, state), ys = jax.lax.scan(
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/haiku/_src/stateful.py", line 623, in stateful_fun
    carry, out = f(carry, x)
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/alphafold/model/layer_stack.py", line 149, in layer
    out_x, z = self._call_wrapped(carry.x, scanned.args_ys)
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/alphafold/model/layer_stack.py", line 182, in _call_wrapped
    ret = self._f(*args)
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/alphafold/model/modules_multimer.py", line 665, in extra_evoformer_fn
    extra_evoformer_output = extra_evoformer_iteration(
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/haiku/_src/module.py", line 426, in wrapped
    out = f(*args, **kwargs)
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/contextlib.py", line 75, in inner
    return func(*args, **kwds)
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/haiku/_src/module.py", line 272, in run_interceptors
    return bound_method(*args, **kwargs)
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/alphafold/model/modules.py", line 1734, in __call__
    msa_act = dropout_wrapper_fn(
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/alphafold/model/modules.py", line 76, in dropout_wrapper
    residual = module(input_act, mask, is_training=is_training, **kwargs)
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/haiku/_src/module.py", line 426, in wrapped
    out = f(*args, **kwargs)
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/contextlib.py", line 75, in inner
    return func(*args, **kwds)
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/haiku/_src/module.py", line 272, in run_interceptors
    return bound_method(*args, **kwargs)
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/alphafold/model/modules.py", line 820, in __call__
    msa_act = mapping.inference_subbatch(
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/alphafold/model/mapping.py", line 223, in inference_subbatch
    return sharded_module(*batched_args)
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/alphafold/model/mapping.py", line 145, in mapped_fn
    remainder_shape_dtype = hk.eval_shape(
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/haiku/_src/stateful.py", line 951, in eval_shape
    out_shape = jax.eval_shape(stateless_fun, internal_state(), *args, **kwargs)
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/haiku/_src/stateful.py", line 947, in stateless_fun
    out = fun(*args, **kwargs)
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/alphafold/model/mapping.py", line 140, in apply_fun_to_slice
    input_slice = jax.tree_map(
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/alphafold/model/mapping.py", line 141, in <lambda>
    lambda array, axis: _maybe_slice(array, slice_start, slice_size, axis
  File "/home/ubuntu/colabfold_batch/colabfold-conda/lib/python3.8/site-packages/alphafold/model/mapping.py", line 38, in _maybe_slice
    return jax.lax.dynamic_slice_in_dim(
TypeError: slice slice_sizes must be less than or equal to operand shape, got slice_sizes (4, 228, 64) for operand shape (0, 228, 64).
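The leading `4` in `slice_sizes` most likely comes from AlphaFold's sub-batching in `mapping.py` (the `mapped_fn` and `apply_fun_to_slice` frames above) rather than from the MSA. The sketch below is a pure-Python model of that arithmetic under two assumptions: a sub-batch size of 4, and a remainder slice that falls back to a full sub-batch whenever the axis length divides evenly, which includes a length of 0. `remainder_slice_size` and `check_slice` are hypothetical names; `check_slice` only mimics the shape check that raised the error.

```python
from typing import Tuple


def remainder_slice_size(in_size: int, shard_size: int) -> int:
    """Hypothetical model of the remainder slice a sharded apply would trace.

    If in_size is an exact multiple of shard_size (including 0), the
    remainder falls back to a full shard.
    """
    last = in_size % shard_size
    return shard_size if last == 0 else last


def check_slice(operand_shape: Tuple[int, ...], slice_sizes: Tuple[int, ...]) -> None:
    """Mimic the shape check that raised the TypeError in the traceback."""
    if any(s > d for s, d in zip(slice_sizes, operand_shape)):
        raise TypeError(
            "slice slice_sizes must be less than or equal to operand shape, "
            f"got slice_sizes {slice_sizes} for operand shape {operand_shape}."
        )


SHARD_SIZE = 4              # assumed sub-batch size (matches slice_sizes[0])
num_res, channels = 228, 64  # taken from the operand shape in the error

for extra_rows in (4, 0):
    operand_shape = (extra_rows, num_res, channels)
    size = remainder_slice_size(extra_rows, SHARD_SIZE)
    try:
        check_slice(operand_shape, (size, num_res, channels))
        print(f"extra_rows={extra_rows}: ok")
    except TypeError as err:
        print(f"extra_rows={extra_rows}: {err}")
# extra_rows=4: ok
# extra_rows=0: slice slice_sizes must be less than or equal to operand shape, got slice_sizes (4, 228, 64) for operand shape (0, 228, 64).
```

With zero extra-MSA rows, the model asks for a 4-row slice of a 0-row array, which reproduces the message above.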
CONTEXT
Running a series of jobs with a CSV of IDs and sequences as input, on an Ubuntu 20.04.5 LTS server with an A100 40 GB GPU.