google-deepmind / alphafold

Open source code for AlphaFold 2.
Apache License 2.0

Weird Inference Behavior with GPU: Sometimes it works or Segmentation Fault/Filesystem Space Error #997

Open juliocesar-io opened 3 months ago

juliocesar-io commented 3 months ago

Hello all, I am encountering inconsistent behavior during GPU inference. Sometimes inference runs successfully, but other times it fails with either a segmentation fault or a fatal abort asking me to verify that sufficient filesystem space is available.

Environment:

Steps to Reproduce:

Flags used (environment variables that docker/run_docker.py passes to the container):

'TF_FORCE_UNIFIED_MEMORY': '1',
'XLA_PYTHON_CLIENT_MEM_FRACTION': '4.0'
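For reference, the JAX documentation on GPU memory allocation says JAX preallocates 75% of total GPU memory when the first JAX operation runs, that XLA_PYTHON_CLIENT_MEM_FRACTION overrides that fraction, and that XLA_PYTHON_CLIENT_PREALLOCATE=false switches to on-demand allocation. A fraction of 4.0 therefore asks for four times the card's memory, which only works because TF_FORCE_UNIFIED_MEMORY=1 lets allocations spill into host RAM. The sketch below approximates that arithmetic; the function names and the 24 GiB card are hypothetical, not taken from this report's environment.

```python
import os
from typing import Mapping, Optional

# JAX's documented default preallocation fraction.
DEFAULT_FRACTION = 0.75


def planned_pool_bytes(total_gpu_bytes: int,
                       env: Mapping[str, str] = os.environ) -> Optional[int]:
    """Approximate the pool JAX tries to reserve at startup.

    This mirrors the documented behaviour, not JAX's own code. Returns None
    when preallocation is disabled, since JAX then allocates on demand.
    """
    if env.get("XLA_PYTHON_CLIENT_PREALLOCATE", "true").lower() in ("0", "false"):
        return None
    fraction = float(env.get("XLA_PYTHON_CLIENT_MEM_FRACTION", DEFAULT_FRACTION))
    return int(total_gpu_bytes * fraction)


def needs_unified_memory(env: Mapping[str, str] = os.environ) -> bool:
    """A fraction above 1.0 asks for more memory than the card physically has."""
    fraction = float(env.get("XLA_PYTHON_CLIENT_MEM_FRACTION", DEFAULT_FRACTION))
    return fraction > 1.0


GiB = 1024 ** 3
run_docker_env = {"TF_FORCE_UNIFIED_MEMORY": "1",
                  "XLA_PYTHON_CLIENT_MEM_FRACTION": "4.0"}

# Hypothetical 24 GiB card: the pool target is 96 GiB, far beyond the device.
print(planned_pool_bytes(24 * GiB, run_docker_env) // GiB)  # 96
print(needs_unified_memory(run_docker_env))  # True
```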
  1. Build the Docker image: docker build -f docker/Dockerfile -t alphafold .

  2. Run the Docker container:

python3 docker/run_docker.py \
  --fasta_paths=/home/juliocesar/DataZ/Models/WP_277336079.1.fasta \
  --max_template_date=2020-05-14 \
  --model_preset=monomer \
  --data_dir=/home/juliocesar/DataZ/Models/alphafold_all_data \
  --output_dir=/home/juliocesar/DataZ/Models/output \
  --enable_gpu_relax=false
  3. The run may succeed, using only about 2 GB of VRAM, and the results look fine.
  4. If I run another inference, it fails with one of the following errors:

Segmentation Fault:

   I0802 00:55:45.666194 127875129517888 run_docker.py:262] Fatal Python error: Segmentation fault
   I0802 00:55:45.666548 127875129517888 run_docker.py:262]
   I0802 00:55:45.666624 127875129517888 run_docker.py:262] Thread 0x000070bc94f6b280 (most recent call first):
   I0802 00:55:45.666688 127875129517888 run_docker.py:262] File "/opt/conda/lib/python3.11/site-packages/jax/_src/compiler.py", line 238 in backend_compile
   I0802 00:55:45.666738 127875129517888 run_docker.py:262] File "/opt/conda/lib/python3.11/site-packages/jax/_src/profiler.py", line 335 in wrapper
   I0802 00:55:45.667106 127875129517888 run_docker.py:262] File "/opt/conda/lib/python3.11/site-packages/jax/_src/compiler.py", line 500 in _compile_and_write_cache
   I0802 00:55:45.667136 127875129517888 run_docker.py:262] File "/opt/conda/lib/python3.11/site-packages/jax/_src/compiler.py", line 333 in compile_or_get_cached
   I0802 00:55:45.667161 127875129517888 run_docker.py:262] File "/opt/conda/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py", line 2718 in _cached_compilation
   I0802 00:55:45.667187 127875129517888 run_docker.py:262] File "/opt/conda/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py", line 2908 in from_hlo
   I0802 00:55:45.667212 127875129517888 run_docker.py:262] File "/opt/conda/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py", line 2369 in compile
   I0802 00:55:45.667233 127875129517888 run_docker.py:262] File "/opt/conda/lib/python3.11/site-packages/jax/_src/pjit.py", line 1406 in _pjit_call_impl_python
   I0802 00:55:45.667258 127875129517888 run_docker.py:262] File "/opt/conda/lib/python3.11/site-packages/jax/_src/pjit.py", line 1471 in call_impl_cache_miss
   I0802 00:55:45.667283 127875129517888 run_docker.py:262] File "/opt/conda/lib/python3.11/site-packages/jax/_src/pjit.py", line 1488 in _pjit_call_impl
   I0802 00:55:45.667304 127875129517888 run_docker.py:262] File "/opt/conda/lib/python3.11/site-packages/jax/_src/core.py", line 913 in process_primitive
   I0802 00:55:45.667324 127875129517888 run_docker.py:262] File "/opt/conda/lib/python3.11/site-packages/jax/_src/core.py", line 425 in bind_with_trace
   I0802 00:55:45.667344 127875129517888 run_docker.py:262] File "/opt/conda/lib/python3.11/site-packages/jax/_src/core.py", line 2788 in bind
   I0802 00:55:45.667364 127875129517888 run_docker.py:262] File "/opt/conda/lib/python3.11/site-packages/jax/_src/pjit.py", line 176 in _python_pjit_helper
   I0802 00:55:45.667383 127875129517888 run_docker.py:262] File "/opt/conda/lib/python3.11/site-packages/jax/_src/pjit.py", line 298 in cache_miss
   I0802 00:55:45.667402 127875129517888 run_docker.py:262] File "/opt/conda/lib/python3.11/site-packages/jax/_src/traceback_util.py", line 179 in reraise_with_filtered_traceback
   I0802 00:55:45.667421 127875129517888 run_docker.py:262] File "/app/alphafold/alphafold/model/model.py", line 167 in predict
   I0802 00:55:45.667440 127875129517888 run_docker.py:262] File "/app/alphafold/run_alphafold.py", line 284 in predict_structure
   I0802 00:55:45.667459 127875129517888 run_docker.py:262] File "/app/alphafold/run_alphafold.py", line 543 in main
   I0802 00:55:45.667478 127875129517888 run_docker.py:262] File "/opt/conda/lib/python3.11/site-packages/absl/app.py", line 258 in _run_main
   I0802 00:55:45.667497 127875129517888 run_docker.py:262] File "/opt/conda/lib/python3.11/site-packages/absl/app.py", line 312 in run
   I0802 00:55:45.667516 127875129517888 run_docker.py:262] File "/app/alphafold/run_alphafold.py", line 570 in <module>
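The trace above is in Python's faulthandler format (most recent call first). Read bottom to top: run_alphafold.py calls predict in model.py, which hits a pjit cache miss, and the innermost frame is backend_compile in jax/_src/compiler.py. That suggests the segfault happens while XLA compiles the model rather than while it executes. A small sketch for extracting frames from such a log, e.g. to compare traces across runs; the regex, `Frame`, and `parse_frames` are illustrative names, not part of AlphaFold or JAX.

```python
import re
from typing import List, NamedTuple


class Frame(NamedTuple):
    path: str
    line: int
    function: str


# Matches faulthandler-style lines such as:
#   File "/app/alphafold/run_alphafold.py", line 284 in predict_structure
FRAME_RE = re.compile(
    r'File "(?P<path>[^"]+)", line (?P<line>\d+) in (?P<func>\S+)')


def parse_frames(log_text: str) -> List[Frame]:
    """Return frames in the order printed (innermost call first)."""
    return [Frame(m["path"], int(m["line"]), m["func"])
            for m in FRAME_RE.finditer(log_text)]


trace = '''
File "/opt/conda/lib/python3.11/site-packages/jax/_src/compiler.py", line 238 in backend_compile
File "/app/alphafold/alphafold/model/model.py", line 167 in predict
File "/app/alphafold/run_alphafold.py", line 570 in <module>
'''
frames = parse_frames(trace)
print(frames[0].function)  # backend_compile
```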

Filesystem Space Error (Fatal Python error: Aborted, asking to verify that sufficient filesystem space is provided):

I0802 01:34:50.745555 132592778102592 run_docker.py:263] 2024-08-02 01:34:50.745063: F external/xla/xla/service/gpu/gemm_fusion_autotuner.cc:780] Non-OK-status: has_executable.status() status: INTERNAL: ptxas exited with non-zero error code 139, output: : If the error message indicates that a file could not be written, please verify that sufficient filesystem space is provided.Failure occured when compiling fusion gemm_fusion_dot.52354 with config '{block_m:16,block_n:16,block_k:256,split_k:1,num_stages:1,num_warps:4,num_ctas:1}'
I0802 01:34:50.745778 132592778102592 run_docker.py:263] Fused HLO computation:
I0802 01:34:50.745835 132592778102592 run_docker.py:263] %gemm_fusion_dot.52354_computation (parameter_0.92: f32[17,384], parameter_1.92: f32[384], parameter_2.28: f32[384,384]) -> f32[17,384] {
I0802 01:34:50.745885 132592778102592 run_docker.py:263] %parameter_0.92 = f32[17,384]{1,0} parameter(0)
I0802 01:34:50.745933 132592778102592 run_docker.py:263] %parameter_1.92 = f32[384]{0} parameter(1)
I0802 01:34:50.745979 132592778102592 run_docker.py:263] %broadcast.15023 = f32[17,384]{1,0} broadcast(f32[384]{0} %parameter_1.92), dimensions={1}, metadata={op_name="jit(apply_fn)/jit(main)/alphafold/alphafold_iteration/structure_module/single_layer_norm/single_layer_norm/add" source_file="/app/alphafold/alphafold/model/common_modules.py" source_line=185}
I0802 01:34:50.746032 132592778102592 run_docker.py:263] %add.12065 = f32[17,384]{1,0} add(f32[17,384]{1,0} %parameter_0.92, f32[17,384]{1,0} %broadcast.15023), metadata={op_name="jit(apply_fn)/jit(main)/alphafold/alphafold_iteration/structure_module/single_layer_norm/single_layer_norm/add" source_file="/app/alphafold/alphafold/model/common_modules.py" source_line=185}
I0802 01:34:50.746080 132592778102592 run_docker.py:263] %parameter_2.28 = f32[384,384]{1,0} parameter(2)
I0802 01:34:50.746122 132592778102592 run_docker.py:263] ROOT %dot.3542 = f32[17,384]{1,0} dot(f32[17,384]{1,0} %add.12065, f32[384,384]{1,0} %parameter_2.28), lhs_contracting_dims={1}, rhs_contracting_dims={0}, metadata={op_name="jit(apply_fn)/jit(main)/alphafold/alphafold_iteration/structure_module/initial_projection/...a, ah->...h/dot_general[dimension_numbers=(((1,), (0,)), ((), ())) precision=None preferred_element_type=float32]" source_file="/app/alphafold/alphafold/model/common_modules.py" source_line=122}
I0802 01:34:50.746166 132592778102592 run_docker.py:263] }
I0802 01:34:50.746207 132592778102592 run_docker.py:263] Fatal Python error: Aborted
I0802 01:34:50.746250 132592778102592 run_docker.py:263]
I0802 01:34:50.746290 132592778102592 run_docker.py:263] Thread 0x00007874dcb35280 (most recent call first):
I0802 01:34:50.746330 132592778102592 run_docker.py:263] File "/opt/conda/lib/python3.11/site-packages/jax/_src/compiler.py", line 238 in backend_compile
I0802 01:34:50.746370 132592778102592 run_docker.py:263] File "/opt/conda/lib/python3.11/site-packages/jax/_src/profiler.py", line 335 in wrapper
I0802 01:34:50.746411 132592778102592 run_docker.py:263] File "/opt/conda/lib/python3.11/site-packages/jax/_src/compiler.py", line 500 in _compile_and_write_cache
I0802 01:34:50.746460 132592778102592 run_docker.py:263] File "/opt/conda/lib/python3.11/site-packages/jax/_src/compiler.py", line 333 in compile_or_get_cached
I0802 01:34:50.746501 132592778102592 run_docker.py:263] File "/opt/conda/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py", line 2718 in _cached_compilation
I0802 01:34:50.746541 132592778102592 run_docker.py:263] File "/opt/conda/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py", line 2908 in from_hlo
I0802 01:34:50.746581 132592778102592 run_docker.py:263] File "/opt/conda/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py", line 2369 in compile
I0802 01:34:50.746620 132592778102592 run_docker.py:263] File "/opt/conda/lib/python3.11/site-packages/jax/_src/pjit.py", line 1406 in _pjit_call_impl_python
I0802 01:34:50.746675 132592778102592 run_docker.py:263] File "/opt/conda/lib/python3.11/site-packages/jax/_src/pjit.py", line 1471 in call_impl_cache_miss
I0802 01:34:50.746716 132592778102592 run_docker.py:263] File "/opt/conda/lib/python3.11/site-packages/jax/_src/pjit.py", line 1488 in _pjit_call_impl
I0802 01:34:50.746762 132592778102592 run_docker.py:263] File "/opt/conda/lib/python3.11/site-packages/jax/_src/core.py", line 913 in process_primitive
I0802 01:34:50.747026 132592778102592 run_docker.py:263] File "/opt/conda/lib/python3.11/site-packages/jax/_src/core.py", line 425 in bind_with_trace
I0802 01:34:50.747194 132592778102592 run_docker.py:263] File "/opt/conda/lib/python3.11/site-packages/jax/_src/core.py", line 2788 in bind
I0802 01:34:50.747266 132592778102592 run_docker.py:263] File "/opt/conda/lib/python3.11/site-packages/jax/_src/pjit.py", line 176 in _python_pjit_helper
I0802 01:34:50.747329 132592778102592 run_docker.py:263] File "/opt/conda/lib/python3.11/site-packages/jax/_src/pjit.py", line 298 in cache_miss
I0802 01:34:50.747386 132592778102592 run_docker.py:263] File "/opt/conda/lib/python3.11/site-packages/jax/_src/traceback_util.py", line 179 in reraise_with_filtered_traceback
I0802 01:34:50.747431 132592778102592 run_docker.py:263] File "/app/alphafold/alphafold/model/model.py", line 167 in predict
I0802 01:34:50.747478 132592778102592 run_docker.py:263] File "/app/alphafold/run_alphafold.py", line 284 in predict_structure
I0802 01:34:50.747540 132592778102592 run_docker.py:263] File "/app/alphafold/run_alphafold.py", line 543 in main
I0802 01:34:50.747584 132592778102592 run_docker.py:263] File "/opt/conda/lib/python3.11/site-packages/absl/app.py", line 258 in _run_main
I0802 01:34:50.747641 132592778102592 run_docker.py:263] File "/opt/conda/lib/python3.11/site-packages/absl/app.py", line 312 in run
I0802 01:34:50.747692 132592778102592 run_docker.py:263] File "/app/alphafold/run_alphafold.py", line 570 in <module>
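One detail in this log: XLA reports that ptxas exited with code 139 while autotuning a GEMM fusion, and the captured ptxas output is empty ("output: :"), so no file-write error is actually shown. Under the usual POSIX shell convention, an exit status above 128 means the process was killed by signal (status - 128), and 139 - 128 = 11 is SIGSEGV on Linux. If that convention applies here, ptxas itself crashed, and the filesystem-space text is XLA's generic hint rather than a diagnosis. The Python stack is the same as in the segmentation fault trace, again ending in backend_compile. The helper below only decodes the number; reading 139 this way is an assumption about how XLA reports the status.

```python
import signal
from typing import Optional


def signal_from_exit_status(status: int) -> Optional[str]:
    """Decode a shell-style exit status (128 + N) into a signal name.

    Returns None for ordinary exit codes (0-128) or unknown signal numbers.
    """
    if status <= 128:
        return None
    try:
        return signal.Signals(status - 128).name
    except ValueError:
        return None


print(signal_from_exit_status(139))  # SIGSEGV on Linux
print(signal_from_exit_status(1))    # None
```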

Successful inference:

I0802 12:20:21.280953 130830089832256 run_docker.py:262] I0802 12:20:21.280649 126517922906752 run_alphafold.py:276] Running model model_1_pred_0 on WP_277325438.1
I0802 12:20:22.182663 130830089832256 run_docker.py:262] 2024-08-02 12:20:22.182248: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:388] MLIR V1 optimization pass is not enabled
I0802 12:20:22.435071 130830089832256 run_docker.py:262] I0802 12:20:22.434748 126517922906752 model.py:165] Running predict with shape(feat) = {'aatype': (4, 57), 'residue_index': (4, 57), 'seq_length': (4,), 'template_aatype': (4, 4, 57), 'template_all_atom_masks': (4, 4, 57, 37), 'template_all_atom_positions': (4, 4, 57, 37, 3), 'template_sum_probs': (4, 4, 1), 'is_distillation': (4,), 'seq_mask': (4, 57), 'msa_mask': (4, 508, 57), 'msa_row_mask': (4, 508), 'random_crop_to_size_seed': (4, 2), 'template_mask': (4, 4), 'template_pseudo_beta': (4, 4, 57, 3), 'template_pseudo_beta_mask': (4, 4, 57), 'atom14_atom_exists': (4, 57, 14), 'residx_atom14_to_atom37': (4, 57, 14), 'residx_atom37_to_atom14': (4, 57, 37), 'atom37_atom_exists': (4, 57, 37), 'extra_msa': (4, 5120, 57), 'extra_msa_mask': (4, 5120, 57), 'extra_msa_row_mask': (4, 5120), 'bert_mask': (4, 508, 57), 'true_msa': (4, 508, 57), 'extra_has_deletion': (4, 5120, 57), 'extra_deletion_value': (4, 5120, 57), 'msa_feat': (4, 508, 57, 49), 'target_feat': (4, 57, 22)}
I0802 12:21:06.842577 130830089832256 run_docker.py:262] I0802 12:21:06.842006 126517922906752 model.py:175] Output shape was {'distogram': {'bin_edges': (63,), 'logits': (57, 57, 64)}, 'experimentally_resolved': {'logits': (57, 37)}, 'masked_msa': {'logits': (508, 57, 23)}, 'predicted_lddt': {'logits': (57, 50)}, 'structure_module': {'final_atom_mask': (57, 37), 'final_atom_positions': (57, 37, 3)}, 'plddt': (57,), 'ranking_confidence': ()}
I0802 12:21:06.842761 130830089832256 run_docker.py:262] I0802 12:21:06.842111 126517922906752 run_alphafold.py:288] Total JAX model model_1_pred_0 on WP_277325438.1 predict time (includes compilation time, see --benchmark): 44.4s
I0802 12:21:06.867249 130830089832256 run_docker.py:262] I0802 12:21:06.866905 126517922906752 run_alphafold.py:276] Running model model_2_pred_0 on WP_277325438.1
I0802 12:21:07.528902 130830089832256 run_docker.py:262] I0802 12:21:07.528476 126517922906752 model.py:165] Running predict with shape(feat) = {'aatype': (4, 57), 'residue_index': (4, 57), 'seq_length': (4,), 'template_aatype': (4, 4, 57), 'template_all_atom_masks': (4, 4, 57, 37), 'template_all_atom_positions': (4, 4, 57, 37, 3), 'template_sum_probs': (4, 4, 1), 'is_distillation': (4,), 'seq_mask': (4, 57), 'msa_mask': (4, 508, 57), 'msa_row_mask': (4, 508), 'random_crop_to_size_seed': (4, 2), 'template_mask': (4, 4), 'template_pseudo_beta': (4, 4, 57, 3), 'template_pseudo_beta_mask': (4, 4, 57), 'atom14_atom_exists': (4, 57, 14), 'residx_atom14_to_atom37': (4, 57, 14), 'residx_atom37_to_atom14': (4, 57, 37), 'atom37_atom_exists': (4, 57, 37), 'extra_msa': (4, 1024, 57), 'extra_msa_mask': (4, 1024, 57), 'extra_msa_row_mask': (4, 1024), 'bert_mask': (4, 508, 57), 'true_msa': (4, 508, 57), 'extra_has_deletion': (4, 1024, 57), 'extra_deletion_value': (4, 1024, 57), 'msa_feat': (4, 508, 57, 49), 'target_feat': (4, 57, 22)}
I0802 12:21:45.562296 130830089832256 run_docker.py:262] I0802 12:21:45.561828 126517922906752 model.py:175] Output shape was {'distogram': {'bin_edges': (63,), 'logits': (57, 57, 64)}, 'experimentally_resolved': {'logits': (57, 37)}, 'masked_msa': {'logits': (508, 57, 23)}, 'predicted_lddt': {'logits': (57, 50)}, 'structure_module': {'final_atom_mask': (57, 37), 'final_atom_positions': (57, 37, 3)}, 'plddt': (57,), 'ranking_confidence': ()}
I0802 12:21:45.562419 130830089832256 run_docker.py:262] I0802 12:21:45.561898 126517922906752 run_alphafold.py:288] Total JAX model model_2_pred_0 on WP_277325438.1 predict time (includes compilation time, see --benchmark): 38.0s
I0802 12:21:45.576066 130830089832256 run_docker.py:262] I0802 12:21:45.575660 126517922906752 run_alphafold.py:276] Running model model_3_pred_0 on WP_277325438.1
I0802 12:21:46.514541 130830089832256 run_docker.py:262] I0802 12:21:46.514255 126517922906752 model.py:165] Running predict with shape(feat) = {'aatype': (4, 57), 'residue_index': (4, 57), 'seq_length': (4,), 'is_distillation': (4,), 'seq_mask': (4, 57), 'msa_mask': (4, 512, 57), 'msa_row_mask': (4, 512), 'random_crop_to_size_seed': (4, 2), 'atom14_atom_exists': (4, 57, 14), 'residx_atom14_to_atom37': (4, 57, 14), 'residx_atom37_to_atom14': (4, 57, 37), 'atom37_atom_exists': (4, 57, 37), 'extra_msa': (4, 5120, 57), 'extra_msa_mask': (4, 5120, 57), 'extra_msa_row_mask': (4, 5120), 'bert_mask': (4, 512, 57), 'true_msa': (4, 512, 57), 'extra_has_deletion': (4, 5120, 57), 'extra_deletion_value': (4, 5120, 57), 'msa_feat': (4, 512, 57, 49), 'target_feat': (4, 57, 22)}
I0802 12:22:18.854233 130830089832256 run_docker.py:262] I0802 12:22:18.853717 126517922906752 model.py:175] Output shape was {'distogram': {'bin_edges': (63,), 'logits': (57, 57, 64)}, 'experimentally_resolved': {'logits': (57, 37)}, 'masked_msa': {'logits': (512, 57, 23)}, 'predicted_lddt': {'logits': (57, 50)}, 'structure_module': {'final_atom_mask': (57, 37), 'final_atom_positions': (57, 37, 3)}, 'plddt': (57,), 'ranking_confidence': ()}
I0802 12:22:18.854382 130830089832256 run_docker.py:262] I0802 12:22:18.853795 126517922906752 run_alphafold.py:288] Total JAX model model_3_pred_0 on WP_277325438.1 predict time (includes compilation time, see --benchmark): 32.3s
I0802 12:22:18.868918 130830089832256 run_docker.py:262] I0802 12:22:18.868620 126517922906752 run_alphafold.py:276] Running model model_4_pred_0 on WP_277325438.1
I0802 12:22:19.415913 130830089832256 run_docker.py:262] I0802 12:22:19.415531 126517922906752 model.py:165] Running predict with shape(feat) = {'aatype': (4, 57), 'residue_index': (4, 57), 'seq_length': (4,), 'is_distillation': (4,), 'seq_mask': (4, 57), 'msa_mask': (4, 512, 57), 'msa_row_mask': (4, 512), 'random_crop_to_size_seed': (4, 2), 'atom14_atom_exists': (4, 57, 14), 'residx_atom14_to_atom37': (4, 57, 14), 'residx_atom37_to_atom14': (4, 57, 37), 'atom37_atom_exists': (4, 57, 37), 'extra_msa': (4, 5120, 57), 'extra_msa_mask': (4, 5120, 57), 'extra_msa_row_mask': (4, 5120), 'bert_mask': (4, 512, 57), 'true_msa': (4, 512, 57), 'extra_has_deletion': (4, 5120, 57), 'extra_deletion_value': (4, 5120, 57), 'msa_feat': (4, 512, 57, 49), 'target_feat': (4, 57, 22)}
I0802 12:22:49.981101 130830089832256 run_docker.py:262] I0802 12:22:49.980713 126517922906752 model.py:175] Output shape was {'distogram': {'bin_edges': (63,), 'logits': (57, 57, 64)}, 'experimentally_resolved': {'logits': (57, 37)}, 'masked_msa': {'logits': (512, 57, 23)}, 'predicted_lddt': {'logits': (57, 50)}, 'structure_module': {'final_atom_mask': (57, 37), 'final_atom_positions': (57, 37, 3)}, 'plddt': (57,), 'ranking_confidence': ()}
I0802 12:22:49.981184 130830089832256 run_docker.py:262] I0802 12:22:49.980788 126517922906752 run_alphafold.py:288] Total JAX model model_4_pred_0 on WP_277325438.1 predict time (includes compilation time, see --benchmark): 30.6s
I0802 12:22:49.994777 130830089832256 run_docker.py:262] I0802 12:22:49.994589 126517922906752 run_alphafold.py:276] Running model model_5_pred_0 on WP_277325438.1
I0802 12:22:50.547898 130830089832256 run_docker.py:262] I0802 12:22:50.547599 126517922906752 model.py:165] Running predict with shape(feat) = {'aatype': (4, 57), 'residue_index': (4, 57), 'seq_length': (4,), 'is_distillation': (4,), 'seq_mask': (4, 57), 'msa_mask': (4, 512, 57), 'msa_row_mask': (4, 512), 'random_crop_to_size_seed': (4, 2), 'atom14_atom_exists': (4, 57, 14), 'residx_atom14_to_atom37': (4, 57, 14), 'residx_atom37_to_atom14': (4, 57, 37), 'atom37_atom_exists': (4, 57, 37), 'extra_msa': (4, 1024, 57), 'extra_msa_mask': (4, 1024, 57), 'extra_msa_row_mask': (4, 1024), 'bert_mask': (4, 512, 57), 'true_msa': (4, 512, 57), 'extra_has_deletion': (4, 1024, 57), 'extra_deletion_value': (4, 1024, 57), 'msa_feat': (4, 512, 57, 49), 'target_feat': (4, 57, 22)}
I0802 12:23:20.811322 130830089832256 run_docker.py:262] I0802 12:23:20.810725 126517922906752 model.py:175] Output shape was {'distogram': {'bin_edges': (63,), 'logits': (57, 57, 64)}, 'experimentally_resolved': {'logits': (57, 37)}, 'masked_msa': {'logits': (512, 57, 23)}, 'predicted_lddt': {'logits': (57, 50)}, 'structure_module': {'final_atom_mask': (57, 37), 'final_atom_positions': (57, 37, 3)}, 'plddt': (57,), 'ranking_confidence': ()}
I0802 12:23:20.811408 130830089832256 run_docker.py:262] I0802 12:23:20.810804 126517922906752 run_alphafold.py:288] Total JAX model model_5_pred_0 on WP_277325438.1 predict time (includes compilation time, see --benchmark): 30.3s
I0802 12:23:21.161775 130830089832256 run_docker.py:262] I0802 12:23:21.161459 126517922906752 amber_minimize.py:178] alterations info: {'nonstandard_residues': [], 'removed_heterogens': set(), 'missing_residues': {}, 'missing_heavy_atoms': {}, 'missing_terminals': {<Residue 56 (GLY) of chain 0>: ['OXT']}, 'Se_in_MET': [], 'removed_chains': {0: []}}
I0802 12:23:21.182899 130830089832256 run_docker.py:262] I0802 12:23:21.182727 126517922906752 amber_minimize.py:408] Minimizing protein, attempt 1 of 100.
I0802 12:23:21.218353 130830089832256 run_docker.py:262] I0802 12:23:21.217987 126517922906752 amber_minimize.py:69] Restraining 384 / 761 particles.
I0802 12:23:21.510061 130830089832256 run_docker.py:262] I0802 12:23:21.509803 126517922906752 amber_minimize.py:178] alterations info: {'nonstandard_residues': [], 'removed_heterogens': set(), 'missing_residues': {}, 'missing_heavy_atoms': {}, 'missing_terminals': {}, 'Se_in_MET': [], 'removed_chains': {0: []}}
I0802 12:23:23.018579 130830089832256 run_docker.py:262] I0802 12:23:23.018234 126517922906752 amber_minimize.py:500] Iteration completed: Einit 25378.80 Efinal -497.05 Time 0.13 s num residue violations 0 num residue exclusions 0
I0802 12:23:23.084537 130830089832256 run_docker.py:262] I0802 12:23:23.084163 126517922906752 run_alphafold.py:414] Final timings for WP_277325438.1: {'features': 9.396395921707153, 'process_features_model_1_pred_0': 1.1538565158843994, 'predict_and_compile_model_1_pred_0': 44.40755605697632, 'process_features_model_2_pred_0': 0.6614584922790527, 'predict_and_compile_model_2_pred_0': 38.0334849357605, 'process_features_model_3_pred_0': 0.938460111618042, 'predict_and_compile_model_3_pred_0': 32.33959984779358, 'process_features_model_4_pred_0': 0.5467894077301025, 'predict_and_compile_model_4_pred_0': 30.56532311439514, 'process_features_model_5_pred_0': 0.5529048442840576, 'predict_and_compile_model_5_pred_0': 30.263261079788208, 'relax_model_1_pred_0': 2.210644483566284}
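Because the outcome varies between attempts, tallying results over many repeated runs gives a clearer picture than a single run. A minimal sketch, assuming each run's console output has been saved as text: it labels a log using marker strings copied from the three logs in this report. The names `classify_run` and `tally` and the labels are hypothetical.

```python
from collections import Counter
from typing import Iterable

# Marker strings taken from the logs in this issue; first match wins.
MARKERS = (
    ("segfault", "Fatal Python error: Segmentation fault"),
    ("ptxas_abort", "ptxas exited with non-zero error code"),
    ("success", "Final timings for"),
)


def classify_run(log_text: str) -> str:
    """Label one run's log as segfault, ptxas_abort, success, or unknown."""
    for label, marker in MARKERS:
        if marker in log_text:
            return label
    return "unknown"


def tally(logs: Iterable[str]) -> Counter:
    """Count outcomes across several runs."""
    return Counter(classify_run(text) for text in logs)


sample_logs = [
    "run_docker.py:262] Fatal Python error: Segmentation fault",
    "gemm_fusion_autotuner.cc:780] ... ptxas exited with non-zero error code 139",
    "run_alphafold.py:414] Final timings for WP_277325438.1: {...}",
]
print(tally(sample_logs))
```

With real runs, the list would hold the contents of each saved log file instead of these excerpts.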

Expected Behavior:

Repeated inference runs should succeed consistently, as in the successful run above.

Troubleshooting Steps Taken:

  1. Verified that sufficient GPU memory is available.
  2. Checked that the Docker container has sufficient filesystem space.
  3. Monitored system memory usage to ensure it does not exceed limits.
  4. The problem persists even with reduced batch sizes in model/config.py.
  5. Installed AlphaFold on the local machine in a Conda environment to rule out Docker, but the behavior is the same.
  6. Tried the following flags, which change the amount of VRAM used but have no apparent effect on the inconsistent behavior:
'TF_FORCE_GPU_ALLOW_GROWTH': 'true',
'XLA_PYTHON_CLIENT_PREALLOCATE': 'false',
'XLA_PYTHON_CLIENT_MEM_FRACTION': '0.5'
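For the local Conda run in step 5, these variables only take effect if they are already in the environment when JAX initializes its GPU backend, so the safest approach is to export them in the shell or set them at the very top of the entry script, before importing jax. A minimal sketch of the second approach; `apply_gpu_memory_env` is a hypothetical helper, and the values are the ones listed above.

```python
import os
from typing import Dict, Mapping, MutableMapping, Optional

# Values from the troubleshooting list above.
TROUBLESHOOTING_ENV = {
    "TF_FORCE_GPU_ALLOW_GROWTH": "true",
    "XLA_PYTHON_CLIENT_PREALLOCATE": "false",
    "XLA_PYTHON_CLIENT_MEM_FRACTION": "0.5",
}


def apply_gpu_memory_env(
        overrides: Mapping[str, str],
        environ: MutableMapping[str, str] = os.environ,
) -> Dict[str, Optional[str]]:
    """Set the given variables and return their previous values (None if unset)."""
    previous = {key: environ.get(key) for key in overrides}
    environ.update(overrides)
    return previous


# Usage at the top of a local entry script, before anything imports jax:
#   apply_gpu_memory_env(TROUBLESHOOTING_ENV)
#   import jax
```

Returning the previous values makes it easy to restore the environment when comparing flag sets in one session, although JAX will not pick up changes after its backend is already initialized.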

Any guidance or suggestions for resolving these issues would be greatly appreciated.