Zuricho / ParallelFold

Modified version of AlphaFold that separates the CPU part (MSA and template searching) from the GPU part. This can accelerate AlphaFold when predicting multiple structures.

Parafold run failing since pulling latest changes #27

Open gauravdiwan89 opened 1 year ago

gauravdiwan89 commented 1 year ago


I pulled the latest Parafold changes and created a new environment following the suggested installation steps. Next, I ran the following command to run AlphaFold:

(parafold)[ParallelFold]$ ./run_alphafold.sh \
-d ../alphafold_data \
-o ../alphafold_output/ \
-m model_1,model_2,model_3,model_4,model_5 \
-p monomer \
-i ../alphafold_input/IFT57.fasta \
-t 1800-01-01 \
-g true \
-u all

Unfortunately, I get the following error:

I0228 13:58:03.541854 22820960240128 templates.py:857] Using precomputed obsolete pdbs ../alphafold_data/pdb_mmcif/obsolete.dat.
I0228 13:58:03.663074 22820960240128 xla_bridge.py:353] Unable to initialize backend 'tpu_driver': NOT_FOUND: Unable to find driver in registry given worker: 
I0228 13:58:04.104593 22820960240128 xla_bridge.py:353] Unable to initialize backend 'rocm': NOT_FOUND: Could not find registered platform with name: "rocm". Available platform names are: Interpreter CUDA Host
I0228 13:58:04.104862 22820960240128 xla_bridge.py:353] Unable to initialize backend 'tpu': module 'jaxlib.xla_extension' has no attribute 'get_tpu_client'
I0228 13:58:04.104926 22820960240128 xla_bridge.py:353] Unable to initialize backend 'plugin': xla_extension has no attributes named get_plugin_device_client. Compile TensorFlow with //tensorflow/compiler/xla/python:enable_plugin_device set to true (defaults to false) to enable this.
./run_alphafold.sh: line 244: 3718078 Killed                  python $alphafold_script --fasta_paths=$fasta_path --model_names=$model_selection --parameter_path=$parameter_path --output_dir=$output_dir --jackhmmer_binary_path=$jackhmmer_binary_path --hhblits_binary_path=$hhblits_binary_path --hhsearch_binary_path=$hhsearch_binary_path --hmmsearch_binary_path=$hmmsearch_binary_path --hmmbuild_binary_path=$hmmbuild_binary_path --kalign_binary_path=$kalign_binary_path --uniref90_database_path=$uniref90_database_path --mgnify_database_path=$mgnify_database_path --bfd_database_path=$bfd_database_path --small_bfd_database_path=$small_bfd_database_path --uniref30_database_path=$uniref30_database_path --uniprot_database_path=$uniprot_database_path --pdb70_database_path=$pdb70_database_path --pdb_seqres_database_path=$pdb_seqres_database_path --template_mmcif_dir=$template_mmcif_dir --max_template_date=$max_template_date --obsolete_pdbs_path=$obsolete_pdbs_path --db_preset=$db_preset --model_preset=$model_preset --benchmark=$benchmark --models_to_relax=$models_to_relax --use_gpu_relax=$use_gpu_relax --recycling=$recycling --run_feature=$run_feature --logtostderr
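The log above does not end with a Python traceback; instead, bash reports "Killed" for the python process started on line 244. Bash prints that when a child process ends on SIGKILL, and on Linux a frequent (though not the only) sender of SIGKILL is the kernel's out-of-memory killer or a batch scheduler enforcing a memory limit. A minimal sketch, assuming a POSIX system, of how one might confirm which signal ended a command when launching it from Python; describe_exit is a hypothetical helper and the child process is only a stand-in:

```python
import signal
import subprocess
import sys


def describe_exit(cmd):
    """Run cmd and describe how it terminated (POSIX only)."""
    proc = subprocess.run(cmd)
    # subprocess encodes "terminated by signal N" as the negative return code -N.
    if proc.returncode < 0:
        return f"terminated by {signal.Signals(-proc.returncode).name}"
    return f"exited with status {proc.returncode}"


if __name__ == "__main__":
    # Stand-in child process that sends itself SIGKILL, mimicking the "Killed" line.
    child = [sys.executable, "-c",
             "import os, signal; os.kill(os.getpid(), signal.SIGKILL)"]
    print(describe_exit(child))  # terminated by SIGKILL
```

If the real command reports SIGKILL with no traceback, memory limits are a reasonable thing to check next.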

I searched for the error elsewhere, and some sources suggested that my jax/jaxlib versions may not be compatible with the CUDA and cuDNN versions on my machines. However, I checked this, and the versions appear to be correct, since running jax.devices() in Python detects my GPU.
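An interactive jax.devices() call only shows what JAX sees in that particular shell. If run_alphafold.sh runs as a batch job on another node, or with different modules loaded, the result there can differ. A small sketch (jax_environment is a hypothetical name) that could be run with the same Python interpreter inside the job itself to record what JAX actually sees; it only uses jax.__version__, jaxlib.version.__version__, jax.default_backend() and jax.devices():

```python
def jax_environment():
    """Collect jax/jaxlib versions and the devices JAX can see, if jax is installed."""
    try:
        import jax
        import jaxlib.version
    except ImportError:
        return {"jax": None, "jaxlib": None, "backend": None, "devices": []}
    return {
        "jax": jax.__version__,
        "jaxlib": jaxlib.version.__version__,
        # Usually "gpu" when a CUDA device was initialised, "cpu" after a fallback.
        "backend": jax.default_backend(),
        "devices": [str(d) for d in jax.devices()],
    }


if __name__ == "__main__":
    for key, value in jax_environment().items():
        print(f"{key}: {value}")
```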

So I am puzzled as to why the software no longer runs. Can you please help me with this?

I was able to run AlphaFold successfully before the latest changes (with version 2.2).

Zuricho commented 1 year ago

Could you send me your CUDA version and your jax/jaxlib versions? I think you are correct that "jax/jaxlib versions may not be compatible for the CUDA".

gauravdiwan89 commented 1 year ago

CUDA version: 11.6
jax: 0.3.25
jaxlib: 0.3.25+cuda11.cudnn82
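The "+cuda11.cudnn82" suffix on the jaxlib version is a local version tag. As far as I know, it records the CUDA major version and the cuDNN version the wheel was built against (here CUDA 11 and cuDNN 8.2), and the installed cuDNN generally needs to be at least that new. A small sketch for pulling those numbers out so they can be compared against the loaded modules; parse_jaxlib_tag is a hypothetical helper, and the tag format is assumed from the version strings quoted in this thread:

```python
import re

# Matches CUDA jaxlib versions such as "0.3.25+cuda11.cudnn82".
_CUDA_TAG = re.compile(r"^(?P<version>[\d.]+)\+cuda(?P<cuda>\d+)\.cudnn(?P<cudnn>\d+)$")


def parse_jaxlib_tag(tag):
    """Split a CUDA jaxlib version string; return None if it has no +cuda tag."""
    match = _CUDA_TAG.match(tag)
    if match is None:
        return None
    digits = match.group("cudnn")
    return {
        "version": match.group("version"),
        "cuda_major": int(match.group("cuda")),
        # Assumes one digit per component: "82" -> "8.2", "805" -> "8.0.5".
        "cudnn": ".".join(digits),
    }
```

For the two jaxlib versions quoted in this thread, this gives cuDNN 8.2 for 0.3.25+cuda11.cudnn82 and cuDNN 8.6 for 0.4.4+cuda11.cudnn86. If the cuDNN 8.4.1 mentioned later were still the one loaded, it would be older than the 8.6 the newer wheel appears to target; whether that matters for this failure is not established here.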

Zuricho commented 1 year ago

My environment is similar to yours:

gauravdiwan89 commented 1 year ago

I see; then it must be something else. I tried again with the latest versions of cudatoolkit (11.8) and cuDNN (8.4.1), but it still fails.

I am also running the program on an HPC cluster where CUDA and cuDNN are loaded as modules and are not in a standard path such as /usr/local/. Do you think this may be why it fails?
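Environment modules commonly make CUDA and cuDNN visible by prepending their library directories to LD_LIBRARY_PATH, so one hedged way to confirm the modules are active inside the job is to look for the shared libraries there. A sketch, assuming the library file names libcudart.so* and libcudnn.so*; find_libraries is a hypothetical helper. The dynamic loader also searches the ldconfig cache and default directories, so an empty result is a hint rather than proof:

```python
import glob
import os


def find_libraries(patterns=("libcudart.so*", "libcudnn.so*"), env=None):
    """Map each file-name pattern to matching files in LD_LIBRARY_PATH directories."""
    env = os.environ if env is None else env
    dirs = [d for d in env.get("LD_LIBRARY_PATH", "").split(os.pathsep) if d]
    return {
        pattern: [hit for d in dirs for hit in sorted(glob.glob(os.path.join(d, pattern)))]
        for pattern in patterns
    }


if __name__ == "__main__":
    for pattern, hits in find_libraries().items():
        print(pattern, hits or "not found on LD_LIBRARY_PATH")
```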

Zuricho commented 1 year ago

There may be some other problem. The non-standard install path should not be the cause.

I have a suggestion: can you try running this pipeline on the CPU? Also, double-check that sufficient memory is available.
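One way to try a CPU-only run without editing any code is to hide the GPUs from the process through environment variables. Setting CUDA_VISIBLE_DEVICES to an empty string hides all CUDA devices, and JAX_PLATFORMS=cpu asks JAX for its CPU backend (recognised by recent JAX releases; older ones used JAX_PLATFORM_NAME). Whether run_alphafold.sh has its own CPU switch is not shown in this thread, so this sketch relies on the environment only; cpu_only_env and run_on_cpu are hypothetical helpers:

```python
import os
import subprocess


def cpu_only_env(base=None):
    """Copy of the environment in which CUDA devices are hidden from child processes."""
    env = dict(os.environ if base is None else base)
    env["CUDA_VISIBLE_DEVICES"] = ""  # empty list: the CUDA runtime sees no GPUs
    env["JAX_PLATFORMS"] = "cpu"      # ask JAX for its CPU backend only
    return env


def run_on_cpu(cmd):
    """Run cmd with GPUs hidden and return its exit status."""
    return subprocess.run(cmd, env=cpu_only_env()).returncode
```

The cmd argument would be the same argument list as the run_alphafold.sh invocation above, starting with "./run_alphafold.sh". A CUDA_ERROR_NO_DEVICE message followed by "falling back to CPU" is what JAX typically logs when no GPU is visible.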

gauravdiwan89 commented 1 year ago

Unfortunately, that does not work either. I get the following error:

I0302 12:01:53.762366 22486551396480 templates.py:857] Using precomputed obsolete pdbs ../alphafold_data/pdb_mmcif/obsolete.dat.
I0302 12:01:54.058316 22486551396480 xla_bridge.py:353] Unable to initialize backend 'tpu_driver': NOT_FOUND: Unable to find driver in registry given worker: 
2023-03-02 12:01:54.275184: E external/org_tensorflow/tensorflow/compiler/xla/stream_executor/cuda/cuda_driver.cc:267] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected
I0302 12:01:54.275487 22486551396480 xla_bridge.py:353] Unable to initialize backend 'cuda': FAILED_PRECONDITION: No visible GPU devices.
I0302 12:01:54.275785 22486551396480 xla_bridge.py:353] Unable to initialize backend 'rocm': NOT_FOUND: Could not find registered platform with name: "rocm". Available platform names are: Interpreter Host CUDA
I0302 12:01:54.276108 22486551396480 xla_bridge.py:353] Unable to initialize backend 'tpu': module 'jaxlib.xla_extension' has no attribute 'get_tpu_client'
I0302 12:01:54.276158 22486551396480 xla_bridge.py:353] Unable to initialize backend 'plugin': xla_extension has no attributes named get_plugin_device_client. Compile TensorFlow with //tensorflow/compiler/xla/python:enable_plugin_device set to true (defaults to false) to enable this.
W0302 12:01:54.276229 22486551396480 xla_bridge.py:360] No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
./run_alphafold.sh: line 244: 1077277 Killed                  python $alphafold_script --fasta_paths=$fasta_path --model_names=$model_selection --parameter_path=$parameter_path --output_dir=$output_dir --jackhmmer_binary_path=$jackhmmer_binary_path --hhblits_binary_path=$hhblits_binary_path --hhsearch_binary_path=$hhsearch_binary_path --hmmsearch_binary_path=$hmmsearch_binary_path --hmmbuild_binary_path=$hmmbuild_binary_path --kalign_binary_path=$kalign_binary_path --uniref90_database_path=$uniref90_database_path --mgnify_database_path=$mgnify_database_path --bfd_database_path=$bfd_database_path --small_bfd_database_path=$small_bfd_database_path --uniref30_database_path=$uniref30_database_path --uniprot_database_path=$uniprot_database_path --pdb70_database_path=$pdb70_database_path --pdb_seqres_database_path=$pdb_seqres_database_path --template_mmcif_dir=$template_mmcif_dir --max_template_date=$max_template_date --obsolete_pdbs_path=$obsolete_pdbs_path --db_preset=$db_preset --model_preset=$model_preset --benchmark=$benchmark --models_to_relax=$models_to_relax --use_gpu_relax=$use_gpu_relax --recycling=$recycling --run_feature=$run_feature --logtostderr

As for memory, I think I have 256 GB.
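On a shared cluster, the RAM installed in a node and the memory granted to a particular job are not necessarily the same, and a job that exceeds its allocation can be killed even when the node has more. A partial check, assuming Linux and (as a guess, since the scheduler is not named in this thread) SLURM: it reads MemTotal and MemAvailable from /proc/meminfo and the SLURM_MEM_PER_NODE variable, which SLURM sets in megabytes when a job is submitted with --mem. It does not inspect cgroup limits. parse_meminfo and memory_report are hypothetical helpers:

```python
import os


def parse_meminfo(text):
    """Parse /proc/meminfo-style text into {field: value}; sizes are in kB (KiB)."""
    info = {}
    for line in text.splitlines():
        key, _, rest = line.partition(":")
        fields = rest.split()
        if fields and fields[0].isdigit():
            info[key.strip()] = int(fields[0])
    return info


def memory_report(meminfo_path="/proc/meminfo"):
    """Summarise node memory in GiB plus the SLURM per-node request, if any."""
    with open(meminfo_path) as fh:
        info = parse_meminfo(fh.read())
    # Set by SLURM (in MB) when the job was submitted with --mem; absent otherwise.
    slurm_mb = os.environ.get("SLURM_MEM_PER_NODE", "")
    return {
        "node_total_gib": info.get("MemTotal", 0) / 1024**2,
        "node_available_gib": info.get("MemAvailable", 0) / 1024**2,
        "slurm_mem_per_node_gib": int(slurm_mb) / 1024 if slurm_mb.isdigit() else None,
    }
```

Calling memory_report() from inside the job, rather than on a login node, would show the figures that apply to the run.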

gauravdiwan89 commented 1 year ago

I seem to have solved the jax/jaxlib version issue: I no longer get the rocm and plugin errors. However, the run still gets killed at line 244 of ./run_alphafold.sh. I will now check whether any of the command's arguments are problematic.
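As the "Killed" lines show, the python command on line 244 of run_alphafold.sh is assembled from shell variables, so an unset variable turns into an empty "--name=" argument. Running the script with bash -x prints each command after expansion, which gives the exact argument list. A sketch for scanning such a list; check_flags is a hypothetical helper, and treating flags whose names end in _path, _paths or _dir as filesystem paths is an assumption based on the flag names in the log:

```python
import glob
import os


def check_flags(argv):
    """Report --name=value arguments whose value is empty or names a missing path."""
    problems = []
    for arg in argv:
        if not arg.startswith("--") or "=" not in arg:
            continue  # skip positional arguments and bare switches like --logtostderr
        name, value = arg[2:].split("=", 1)
        if value == "":
            problems.append(f"{name}: empty value")
        elif name.endswith(("_path", "_paths", "_dir")):
            # fasta_paths may be a comma-separated list of files.
            for path in value.split(","):
                # HH-suite databases (e.g. pdb70, uniref30) are passed as file-name
                # prefixes, so accept any existing file that starts with the value.
                if not (os.path.exists(path) or glob.glob(glob.escape(path) + "*")):
                    problems.append(f"{name}: {path} does not exist")
    return problems
```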

gauravdiwan89 commented 1 year ago

In the end, I was only able to run the Python script run_alphafold.py directly, using the following parameters:

python run_alphafold.py \
--fasta_paths=../alphafold_input/IFT57.fasta \
--output_dir=../alphafold_output \
--parameter_path=../alphafold_data/params/ \
--uniref90_database_path=../alphafold_data/uniref90/uniref90.fasta \
--mgnify_database_path=../alphafold_data/mgnify/mgy_clusters_2018_12.fa \
--template_mmcif_dir=../alphafold_data/pdb_mmcif/mmcif_files/ \
--obsolete_pdbs_path=../alphafold_data/pdb_mmcif/obsolete.dat \
--bfd_database_path=../alphafold_data/bfd/bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt \
--uniref30_database_path=../alphafold_data/uniclust30/uniclust30_2020_06/UniRef30_2020_06 \
--pdb70_database_path=../alphafold_data/pdb70/pdb70 \
--max_template_date='1800-01-01'

I don't know where the error comes from when I run the bash script; the environment variables seem fine.
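Since the direct run_alphafold.py call works and the wrapper's call does not, comparing the flag sets of the two invocations narrows down what to test next. A sketch with hypothetical helpers flag_names and diff_commands; each argument list could be copied from the expanded commands in this thread:

```python
def flag_names(argv):
    """Return the names of all --name or --name=value arguments."""
    return {arg[2:].split("=", 1)[0] for arg in argv if arg.startswith("--")}


def diff_commands(failing, working):
    """List flags passed only by the failing command or only by the working one."""
    failing_flags, working_flags = flag_names(failing), flag_names(working)
    return {
        "only_in_failing": sorted(failing_flags - working_flags),
        "only_in_working": sorted(working_flags - failing_flags),
    }
```

For the two commands in this thread, every flag of the working command also appears in the failing one, so only_in_working would be empty and the candidates are the extra flags the wrapper passes, such as model_names, model_preset, db_preset, recycling, run_feature and the *_binary_path flags.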

My jax and jaxlib versions are the latest: 0.4.4 and 0.4.4+cuda11.cudnn86, respectively.