Zuricho / ParallelFold

A modified version of AlphaFold that separates the CPU part (MSA and template searching) from the GPU part. This can accelerate AlphaFold when predicting multiple structures.
https://parafold.sjtu.edu.cn

Parafold run failing since pulling latest changes #27

Open gauravdiwan89 opened 1 year ago

gauravdiwan89 commented 1 year ago

Hello.

I pulled the latest Parafold changes and created a new environment following the suggested installation steps. Then I ran the following command to run AlphaFold:

(parafold)[ParallelFold]$ ./run_alphafold.sh \
-d ../alphafold_data \
-o ../alphafold_output/ \
-m model_1,model_2,model_3,model_4,model_5 \
-p monomer \
-i ../alphafold_input/IFT57.fasta \
-t 1800-01-01 \
-g true \
-u all

Unfortunately, I get the following error:

I0228 13:58:03.541854 22820960240128 templates.py:857] Using precomputed obsolete pdbs ../alphafold_data/pdb_mmcif/obsolete.dat.
I0228 13:58:03.663074 22820960240128 xla_bridge.py:353] Unable to initialize backend 'tpu_driver': NOT_FOUND: Unable to find driver in registry given worker: 
I0228 13:58:04.104593 22820960240128 xla_bridge.py:353] Unable to initialize backend 'rocm': NOT_FOUND: Could not find registered platform with name: "rocm". Available platform names are: Interpreter CUDA Host
I0228 13:58:04.104862 22820960240128 xla_bridge.py:353] Unable to initialize backend 'tpu': module 'jaxlib.xla_extension' has no attribute 'get_tpu_client'
I0228 13:58:04.104926 22820960240128 xla_bridge.py:353] Unable to initialize backend 'plugin': xla_extension has no attributes named get_plugin_device_client. Compile TensorFlow with //tensorflow/compiler/xla/python:enable_plugin_device set to true (defaults to false) to enable this.
./run_alphafold.sh: line 244: 3718078 Killed                  python $alphafold_script --fasta_paths=$fasta_path --model_names=$model_selection --parameter_path=$parameter_path --output_dir=$output_dir --jackhmmer_binary_path=$jackhmmer_binary_path --hhblits_binary_path=$hhblits_binary_path --hhsearch_binary_path=$hhsearch_binary_path --hmmsearch_binary_path=$hmmsearch_binary_path --hmmbuild_binary_path=$hmmbuild_binary_path --kalign_binary_path=$kalign_binary_path --uniref90_database_path=$uniref90_database_path --mgnify_database_path=$mgnify_database_path --bfd_database_path=$bfd_database_path --small_bfd_database_path=$small_bfd_database_path --uniref30_database_path=$uniref30_database_path --uniprot_database_path=$uniprot_database_path --pdb70_database_path=$pdb70_database_path --pdb_seqres_database_path=$pdb_seqres_database_path --template_mmcif_dir=$template_mmcif_dir --max_template_date=$max_template_date --obsolete_pdbs_path=$obsolete_pdbs_path --db_preset=$db_preset --model_preset=$model_preset --benchmark=$benchmark --models_to_relax=$models_to_relax --use_gpu_relax=$use_gpu_relax --recycling=$recycling --run_feature=$run_feature --logtostderr

I searched for the error elsewhere, and some posts suggested that my jax/jaxlib versions may not be compatible with the CUDA and cuDNN versions on my machines. However, I checked this and the versions appear to be correct, since running jax.devices() in Python detects my GPU.

So I am puzzled as to why the software no longer runs. Can you please help me with this?

I was able to run AlphaFold successfully before the latest changes (with version 2.2).

Zuricho commented 1 year ago

Could you send me your CUDA version and your jax/jaxlib versions? I think you are right that your jax/jaxlib versions may not be compatible with your CUDA version.

gauravdiwan89 commented 1 year ago

CUDA version: 11.6
jax: 0.3.25
jaxlib: 0.3.25+cuda11.cudnn82

Zuricho commented 1 year ago

My environment is similar to yours:

gauravdiwan89 commented 1 year ago

I see; then it must be something else. I tried again with the latest versions of cudatoolkit (11.8) and cuDNN (8.4.1), but it still fails.

I am also running the program on an HPC cluster where CUDA and cuDNN are loaded as modules and are not in a standard path such as /usr/local/. Do you think this may be why it fails?
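One way to check whether the module-loaded CUDA and cuDNN libraries are actually visible to the process is to look for them on `LD_LIBRARY_PATH`, which is how environment modules on HPC systems typically expose them (an assumption about this cluster). The bash function below is a hypothetical helper, not part of ParallelFold. It only inspects `LD_LIBRARY_PATH`, not the `ldconfig` cache or RPATHs, so an empty result is a hint rather than proof.

```shell
#!/usr/bin/env bash
# Hypothetical helper (not part of ParallelFold): print every directory on
# LD_LIBRARY_PATH that contains a file whose name starts with the given prefix.
# Only LD_LIBRARY_PATH is searched; the ldconfig cache and RPATHs are ignored.
find_in_ld_path() {
  local name=$1 dir found=1
  local -a dirs
  # Split LD_LIBRARY_PATH on colons into an array without changing the global IFS.
  IFS=: read -ra dirs <<< "${LD_LIBRARY_PATH:-}"
  for dir in "${dirs[@]}"; do
    [ -z "$dir" ] && continue          # skip empty entries such as "a::b"
    if compgen -G "$dir/$name*" > /dev/null; then
      echo "$dir"
      found=0
    fi
  done
  return "$found"                      # 0 if at least one directory matched
}
```

For example, after loading the CUDA and cuDNN modules (module names vary by cluster), `find_in_ld_path libcudnn.so` and `find_in_ld_path libcudart.so` should each print at least one directory.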

Zuricho commented 1 year ago

Maybe there is some other problem. The non-standard installation path should not be the reason.

I have two suggestions. Can you try running this pipeline on the CPU? Also, double-check whether sufficient memory is available.
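The `Killed` message at line 244 of `run_alphafold.sh` is printed by bash when the Python process it started is terminated by SIGKILL. On Linux that is the signal the kernel's out-of-memory (OOM) killer sends, and a batch scheduler enforcing a job's memory limit may do the same, so a large node does not rule memory out if the job itself was allocated less (an assumption about the cluster setup). A minimal bash sketch that decodes the exit status, which bash sets to 128 plus the signal number (137 for SIGKILL), is below; `describe_exit` is a hypothetical helper, not part of ParallelFold.

```shell
#!/usr/bin/env bash
# Hypothetical helper: translate a bash exit status into a short description.
# Bash reports a child killed by signal N as status 128+N (SIGKILL -> 137).
# Assumes statuses above 128 come from signals; a program that calls
# exit(200) itself would be misreported.
describe_exit() {
  local status=$1
  if (( status == 0 )); then
    echo "success"
  elif (( status > 128 )); then
    local sig=$(( status - 128 ))
    echo "terminated by signal $sig (SIG$(kill -l "$sig"))"
  else
    echo "exited with status $status"
  fi
}
```

Call it immediately after the wrapper, e.g. `./run_alphafold.sh ...; describe_exit $?`. If it reports SIGKILL, `dmesg | grep -i oom` (where permitted) or, if the cluster uses SLURM, `sacct -j <jobid> --format=JobID,MaxRSS,ReqMem,State` can show whether memory was the cause.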

gauravdiwan89 commented 1 year ago

Unfortunately, that does not work either. I get the following error:

I0302 12:01:53.762366 22486551396480 templates.py:857] Using precomputed obsolete pdbs ../alphafold_data/pdb_mmcif/obsolete.dat.
I0302 12:01:54.058316 22486551396480 xla_bridge.py:353] Unable to initialize backend 'tpu_driver': NOT_FOUND: Unable to find driver in registry given worker: 
2023-03-02 12:01:54.275184: E external/org_tensorflow/tensorflow/compiler/xla/stream_executor/cuda/cuda_driver.cc:267] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected
I0302 12:01:54.275487 22486551396480 xla_bridge.py:353] Unable to initialize backend 'cuda': FAILED_PRECONDITION: No visible GPU devices.
I0302 12:01:54.275785 22486551396480 xla_bridge.py:353] Unable to initialize backend 'rocm': NOT_FOUND: Could not find registered platform with name: "rocm". Available platform names are: Interpreter Host CUDA
I0302 12:01:54.276108 22486551396480 xla_bridge.py:353] Unable to initialize backend 'tpu': module 'jaxlib.xla_extension' has no attribute 'get_tpu_client'
I0302 12:01:54.276158 22486551396480 xla_bridge.py:353] Unable to initialize backend 'plugin': xla_extension has no attributes named get_plugin_device_client. Compile TensorFlow with //tensorflow/compiler/xla/python:enable_plugin_device set to true (defaults to false) to enable this.
W0302 12:01:54.276229 22486551396480 xla_bridge.py:360] No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
./run_alphafold.sh: line 244: 1077277 Killed                  python $alphafold_script --fasta_paths=$fasta_path --model_names=$model_selection --parameter_path=$parameter_path --output_dir=$output_dir --jackhmmer_binary_path=$jackhmmer_binary_path --hhblits_binary_path=$hhblits_binary_path --hhsearch_binary_path=$hhsearch_binary_path --hmmsearch_binary_path=$hmmsearch_binary_path --hmmbuild_binary_path=$hmmbuild_binary_path --kalign_binary_path=$kalign_binary_path --uniref90_database_path=$uniref90_database_path --mgnify_database_path=$mgnify_database_path --bfd_database_path=$bfd_database_path --small_bfd_database_path=$small_bfd_database_path --uniref30_database_path=$uniref30_database_path --uniprot_database_path=$uniprot_database_path --pdb70_database_path=$pdb70_database_path --pdb_seqres_database_path=$pdb_seqres_database_path --template_mmcif_dir=$template_mmcif_dir --max_template_date=$max_template_date --obsolete_pdbs_path=$obsolete_pdbs_path --db_preset=$db_preset --model_preset=$model_preset --benchmark=$benchmark --models_to_relax=$models_to_relax --use_gpu_relax=$use_gpu_relax --recycling=$recycling --run_feature=$run_feature --logtostderr

Also, I believe I have 256 GB of memory.

gauravdiwan89 commented 1 year ago

I seem to have solved the jax/jaxlib version issue; I no longer get the rocm and plugin errors. However, the run still gets killed at line 244 of ./run_alphafold.sh. I will now check whether any of the command's arguments are problematic.

gauravdiwan89 commented 1 year ago

In the end, I was only able to run the Python script run_alphafold.py directly, using the following parameters:

python run_alphafold.py \
--fasta_paths=../alphafold_input/IFT57.fasta \
--output_dir=../alphafold_output \
--parameter_path=../alphafold_data/params/  \
--uniref90_database_path=../alphafold_data/uniref90/uniref90.fasta \
--mgnify_database_path=../alphafold_data/mgnify/mgy_clusters_2018_12.fa \
--template_mmcif_dir=../alphafold_data/pdb_mmcif/mmcif_files/ \
--obsolete_pdbs_path=../alphafold_data/pdb_mmcif/obsolete.dat \
--bfd_database_path=../alphafold_data/bfd/bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt \
--uniref30_database_path=../alphafold_data/uniclust30/uniclust30_2020_06/UniRef30_2020_06 \
--pdb70_database_path=../alphafold_data/pdb70/pdb70 \
--max_template_date='1800-01-01' \
--use_gpu_relax=True

I don't know where the error comes from when I run the bash script; the environment variables seem fine.

My jax and jaxlib versions are the latest: 0.4.4 and 0.4.4+cuda11.cudnn86, respectively.
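Since the direct `run_alphafold.py` call works but the wrapper does not, two cheap checks are to trace the exact command the wrapper builds and to confirm that every database path it passes exists. Running `bash -x ./run_alphafold.sh ...` prints each command after variable expansion, so the `python` invocation at line 244 can be compared with the working call; note that the wrapper passes flags the direct call omits, such as `--small_bfd_database_path`, `--uniprot_database_path` and `--pdb_seqres_database_path`. For the path check, below is a hypothetical bash helper (not part of ParallelFold). The BFD, UniRef30 and PDB70 flags take HH-suite database prefixes rather than single files, so it accepts a path if any file starting with it exists.

```shell
#!/usr/bin/env bash
# Hypothetical helper (not part of ParallelFold): report whether each database
# path exists. HH-suite databases (BFD, UniRef30, PDB70) are given as prefixes,
# so a path also counts as present if any file name starts with it.
# Empty arguments (e.g. an unset variable) are always reported as missing.
check_db_paths() {
  local p missing=0
  for p in "$@"; do
    if [ -n "$p" ] && { [ -e "$p" ] || compgen -G "${p}*" > /dev/null; }; then
      echo "ok: $p"
    else
      echo "missing: $p"
      missing=1
    fi
  done
  return "$missing"                    # non-zero if any path was missing
}
```

It can be given the paths from the working call above, or called temporarily inside `run_alphafold.sh` just before line 244 with variables such as `$small_bfd_database_path` and `$pdb_seqres_database_path`.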