phbradley / alphafold_finetune

Python code for fine-tuning AlphaFold to perform protein-peptide binding predictions
Apache License 2.0

version issues of Haiku/Jax #2

Closed: wangy9711 closed this issue 1 year ago

wangy9711 commented 2 years ago

Hi, thanks for your great work! However, when I tried to run the fine-tuning command in the default af2 conda environment, I got some errors that may be caused by Haiku/JAX version issues. Could you tell me which versions of the dependencies you used? It would be great if you could add an env.yaml or requirements.txt to the repository. Thanks again.

wangy9711 commented 2 years ago

In addition, I pinned specific versions of the dependencies to get the code running, but found that training used 73 GB of GPU memory. Is this normal?

ena2016 commented 1 year ago

Hey, which version of jax did you use to get the program running? I have the same issue, since newer jax versions changed the syntax.

Hi - index_update and friends were removed in jax version 0.3.2; see the CHANGELOG here: https://jax.readthedocs.io/en/latest/changelog.html#jax-0-3-2-march-16-2022

Instead of ops.index_update(x, idx, vals) you should use x.at[idx].set(vals).
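A minimal before/after sketch of that migration, assuming a recent JAX in which `jax.ops.index_update` no longer exists. The helper name `set_entries` and the toy arrays are hypothetical, for illustration only:

```python
import jax.numpy as jnp

def set_entries(x, idx, vals):
    # Removed in JAX 0.3.2:  jax.ops.index_update(x, idx, vals)
    # Replacement: a functional indexed update that returns a new array
    return x.at[idx].set(vals)

x = jnp.zeros(5)
y = set_entries(x, 2, 7.0)  # y == [0, 0, 7, 0, 0]

# The other removed helpers map the same way, for example:
#   jax.ops.index_add(x, idx, v) -> x.at[idx].add(v)
#   jax.ops.index_mul(x, idx, v) -> x.at[idx].multiply(v)
z = x.at[1:3].add(1.0)      # z == [0, 1, 1, 0, 0]
```

Because JAX arrays are immutable, `.at[...]` never modifies `x` in place; the result has to be assigned, exactly as with the old `index_update`.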

wangy9711 commented 1 year ago

Here are the versions of some key third-party libraries; with these, fine-tuning runs normally for me:

jax 0.2.19
jaxlib 0.1.69+cuda111
dm-haiku 0.0.5
dm-tree 0.1.6
jmp 0.0.2
optax 0.1.3
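To compare an existing environment against these pins, a small sketch using the standard-library `importlib.metadata` (Python 3.8+) could look like this. The helper name `report_versions` is hypothetical, and the pin table simply copies the list above:

```python
from importlib.metadata import PackageNotFoundError, version

# Versions reported above as working for fine-tuning
PINS = {
    "jax": "0.2.19",
    "jaxlib": "0.1.69+cuda111",
    "dm-haiku": "0.0.5",
    "dm-tree": "0.1.6",
    "jmp": "0.0.2",
    "optax": "0.1.3",
}

def report_versions(pins):
    """Return {package: (expected, installed_or_None)} for every mismatch."""
    mismatches = {}
    for name, expected in pins.items():
        try:
            installed = version(name)
        except PackageNotFoundError:
            installed = None  # package is not installed in this environment
        if installed != expected:
            mismatches[name] = (expected, installed)
    return mismatches

if __name__ == "__main__":
    for name, (want, have) in report_versions(PINS).items():
        print(f"{name}: expected {want}, found {have}")
```

Unlike `conda list --explicit`, which only records conda packages, this also sees pip-installed packages such as jax.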

MeiMunick commented 1 year ago

Hey Wang, I got stuck when running the peptide-MHC fine-tuning (on either the tiny dataset or the full model) as described on the author's GitHub page, and got: FileNotFoundError: [Errno 2] No such file or directory: '/home/pbradley/csdat/alphafold/data/params/params_model_2_ptm.npz'. Have you ever had the same problem? Thanks for your help.

phbradley commented 1 year ago

Hi all, thanks for the helpful discussions and sorry I was not in town to reply sooner. For the above error, you need to use the --data_dir command line flag and provide the path to the directory that contains the AlphaFold params/ folder. I've updated the README to clarify this.
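Judging from the path in the error message, the code appears to look for `<data_dir>/params/params_model_2_ptm.npz`. A minimal pre-flight check under that assumption; the helper names `expected_params_path` and `check_data_dir` are hypothetical and not part of the repo:

```python
import os

def expected_params_path(data_dir, model_name="model_2_ptm"):
    # Assumed layout, inferred from the error path above:
    #   <data_dir>/params/params_<model_name>.npz
    return os.path.join(data_dir, "params", f"params_{model_name}.npz")

def check_data_dir(data_dir, model_name="model_2_ptm"):
    """Return the params path if the file exists, else print a hint and return None."""
    path = expected_params_path(data_dir, model_name)
    if os.path.isfile(path):
        return path
    print(f"missing {path}: pass --data_dir pointing at the directory "
          "that contains the AlphaFold params/ folder")
    return None
```

With `data_dir="/home/pbradley/csdat/alphafold/data"` this reproduces the path from the error above, so on another machine `--data_dir` should presumably point at the local directory that holds `params/`.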

phbradley commented 1 year ago

With regard to memory usage, in our experience training should not take more than 11-12 GB of GPU memory.

phbradley commented 1 year ago

Sorry for the trouble with jax/python/conda environments. I am not an expert in this, and it took some work to get things running on our machines originally, but I think much of that was specific to the GPUs we were using and the particular versions of CUDA and associated libraries. It also took a combination of conda and pip, then some more conda and some more pip. All that is to say that I'm not sure our requirements.txt would be generally useful, and I'm not even sure it would capture all the relevant information given the conda/pip mixing.

Nevertheless, our jax version is 0.2.22.

Here is the output of conda list --explicit:

# This file may be used to create an environment using:
# $ conda create --name <env> --file <this file>
# platform: linux-64
@EXPLICIT
https://conda.anaconda.org/conda-forge/linux-64/_libgcc_mutex-0.1-conda_forge.tar.bz2
https://conda.anaconda.org/conda-forge/linux-64/ca-certificates-2021.10.8-ha878542_0.tar.bz2
https://conda.anaconda.org/conda-forge/linux-64/ld_impl_linux-64-2.36.1-hea4e1c9_2.tar.bz2
https://conda.anaconda.org/conda-forge/linux-64/libstdcxx-ng-11.2.0-he4da1e4_11.tar.bz2
https://conda.anaconda.org/conda-forge/linux-64/libgomp-11.2.0-h1d223b6_11.tar.bz2
https://conda.anaconda.org/conda-forge/linux-64/_openmp_mutex-4.5-1_gnu.tar.bz2
https://conda.anaconda.org/conda-forge/linux-64/libgcc-ng-11.2.0-h1d223b6_11.tar.bz2
https://conda.anaconda.org/nvidia/linux-64/cudatoolkit-11.1.74-h6bb024c_0.tar.bz2
https://conda.anaconda.org/bioconda/linux-64/hmmer-3.3.2-h1b792b2_1.tar.bz2
https://conda.anaconda.org/bioconda/linux-64/kalign2-2.04-h779adbc_2.tar.bz2
https://conda.anaconda.org/conda-forge/linux-64/libffi-3.4.2-h9c3ff4c_4.tar.bz2
https://conda.anaconda.org/conda-forge/linux-64/libnsl-2.0.0-h7f98852_0.tar.bz2
https://conda.anaconda.org/conda-forge/linux-64/libzlib-1.2.11-h36c2ea0_1013.tar.bz2
https://conda.anaconda.org/conda-forge/linux-64/ncurses-6.2-h58526e2_4.tar.bz2
https://conda.anaconda.org/conda-forge/linux-64/openssl-3.0.0-h7f98852_1.tar.bz2
https://conda.anaconda.org/conda-forge/linux-64/perl-5.26.2-h36c2ea0_1008.tar.bz2
https://conda.anaconda.org/conda-forge/linux-64/xz-5.2.5-h516909a_1.tar.bz2
https://conda.anaconda.org/nvidia/linux-64/cudnn-8.0.4-cuda11.1_0.tar.bz2
https://conda.anaconda.org/conda-forge/linux-64/readline-8.1-h46c0cb4_0.tar.bz2
https://conda.anaconda.org/conda-forge/linux-64/zlib-1.2.11-h36c2ea0_1013.tar.bz2
https://conda.anaconda.org/conda-forge/linux-64/sqlite-3.36.0-h9cd32fc_2.tar.bz2
https://conda.anaconda.org/conda-forge/linux-64/tk-8.6.11-h27826a3_1.tar.bz2
https://conda.anaconda.org/conda-forge/linux-64/python-3.8.12-hf930737_2_cpython.tar.bz2
https://conda.anaconda.org/conda-forge/linux-64/python_abi-3.8-2_cp38.tar.bz2
https://conda.anaconda.org/conda-forge/noarch/wheel-0.37.0-pyhd8ed1ab_1.tar.bz2
https://conda.anaconda.org/conda-forge/linux-64/cudatoolkit-dev-11.4.0-py38h497a2fe_2.tar.bz2
https://conda.anaconda.org/bioconda/linux-64/hhsuite-3.3.0-py38pl5262hc37a69a_2.tar.bz2
https://conda.anaconda.org/conda-forge/linux-64/setuptools-58.2.0-py38h578d9bd_0.tar.bz2
https://conda.anaconda.org/conda-forge/noarch/pip-21.3-pyhd8ed1ab_0.tar.bz2

(not sure why that doesn't show jax!?!)

and here is what conda list gives. This one does seem to show jax:

# packages in environment at /home/pbradley/anaconda2/envs/af2:
#
# Name                    Version                   Build  Channel
_libgcc_mutex             0.1                 conda_forge    conda-forge
_openmp_mutex             4.5                       1_gnu    conda-forge
absl-py                   0.13.0                   pypi_0    pypi
alphafold                 2.1.0                    pypi_0    pypi
astunparse                1.6.3                    pypi_0    pypi
backcall                  0.2.0                    pypi_0    pypi
biopython                 1.79                     pypi_0    pypi
ca-certificates           2021.10.8            ha878542_0    conda-forge
cachetools                4.2.4                    pypi_0    pypi
certifi                   2021.10.8                pypi_0    pypi
charset-normalizer        2.0.7                    pypi_0    pypi
chex                      0.0.7                    pypi_0    pypi
contextlib2               21.6.0                   pypi_0    pypi
cudatoolkit               11.1.74              h6bb024c_0    nvidia
cudatoolkit-dev           11.4.0           py38h497a2fe_2    conda-forge
cudnn                     8.0.4                cuda11.1_0    nvidia
decorator                 5.1.0                    pypi_0    pypi
dm-haiku                  0.0.4                    pypi_0    pypi
dm-tree                   0.1.6                    pypi_0    pypi
docker                    5.0.0                    pypi_0    pypi
flatbuffers               1.12                     pypi_0    pypi
gast                      0.4.0                    pypi_0    pypi
google-auth               2.3.0                    pypi_0    pypi
google-auth-oauthlib      0.4.6                    pypi_0    pypi
google-pasta              0.2.0                    pypi_0    pypi
grpcio                    1.34.1                   pypi_0    pypi
h5py                      3.1.0                    pypi_0    pypi
hhsuite                   3.3.0           py38pl5262hc37a69a_2    bioconda
hmmer                     3.3.2                h1b792b2_1    bioconda
idna                      3.3                      pypi_0    pypi
immutabledict             2.0.0                    pypi_0    pypi
ipython                   7.28.0                   pypi_0    pypi
jax                       0.2.22                   pypi_0    pypi
jaxlib                    0.1.72+cuda111           pypi_0    pypi
jedi                      0.18.0                   pypi_0    pypi
kalign2                   2.04                 h779adbc_2    bioconda
keras-nightly             2.5.0.dev2021032900          pypi_0    pypi
keras-preprocessing       1.1.2                    pypi_0    pypi
ld_impl_linux-64          2.36.1               hea4e1c9_2    conda-forge
libffi                    3.4.2                h9c3ff4c_4    conda-forge
libgcc-ng                 11.2.0              h1d223b6_11    conda-forge
libgomp                   11.2.0              h1d223b6_11    conda-forge
libnsl                    2.0.0                h7f98852_0    conda-forge
libstdcxx-ng              11.2.0              he4da1e4_11    conda-forge
libzlib                   1.2.11            h36c2ea0_1013    conda-forge
markdown                  3.3.4                    pypi_0    pypi
matplotlib-inline         0.1.3                    pypi_0    pypi
ml-collections            0.1.0                    pypi_0    pypi
ncurses                   6.2                  h58526e2_4    conda-forge
numpy                     1.19.5                   pypi_0    pypi
oauthlib                  3.1.1                    pypi_0    pypi
openssl                   3.0.0                h7f98852_1    conda-forge
opt-einsum                3.3.0                    pypi_0    pypi
pandas                    1.3.4                    pypi_0    pypi
parso                     0.8.2                    pypi_0    pypi
perl                      5.26.2            h36c2ea0_1008    conda-forge
pexpect                   4.8.0                    pypi_0    pypi
pickleshare               0.7.5                    pypi_0    pypi
pip                       21.3               pyhd8ed1ab_0    conda-forge
prompt-toolkit            3.0.20                   pypi_0    pypi
protobuf                  3.18.1                   pypi_0    pypi
ptyprocess                0.7.0                    pypi_0    pypi
pyasn1                    0.4.8                    pypi_0    pypi
pyasn1-modules            0.2.8                    pypi_0    pypi
pygments                  2.10.0                   pypi_0    pypi
python                    3.8.12          hf930737_2_cpython    conda-forge
python-dateutil           2.8.2                    pypi_0    pypi
python_abi                3.8                      2_cp38    conda-forge
pytz                      2021.3                   pypi_0    pypi
pyyaml                    6.0                      pypi_0    pypi
readline                  8.1                  h46c0cb4_0    conda-forge
requests                  2.26.0                   pypi_0    pypi
requests-oauthlib         1.3.0                    pypi_0    pypi
rsa                       4.7.2                    pypi_0    pypi
scipy                     1.7.0                    pypi_0    pypi
setuptools                58.2.0           py38h578d9bd_0    conda-forge
six                       1.15.0                   pypi_0    pypi
sqlite                    3.36.0               h9cd32fc_2    conda-forge
tabulate                  0.8.9                    pypi_0    pypi
tensorboard               2.7.0                    pypi_0    pypi
tensorboard-data-server   0.6.1                    pypi_0    pypi
tensorboard-plugin-wit    1.8.0                    pypi_0    pypi
tensorflow-cpu            2.5.0                    pypi_0    pypi
tensorflow-estimator      2.5.0                    pypi_0    pypi
termcolor                 1.1.0                    pypi_0    pypi
tk                        8.6.11               h27826a3_1    conda-forge
toolz                     0.11.1                   pypi_0    pypi
traitlets                 5.1.0                    pypi_0    pypi
typing-extensions         3.7.4.3                  pypi_0    pypi
urllib3                   1.26.7                   pypi_0    pypi
wcwidth                   0.2.5                    pypi_0    pypi
websocket-client          1.2.1                    pypi_0    pypi
werkzeug                  2.0.2                    pypi_0    pypi
wheel                     0.37.0             pyhd8ed1ab_1    conda-forge
wrapt                     1.12.1                   pypi_0    pypi
xz                        5.2.5                h516909a_1    conda-forge
zlib                      1.2.11            h36c2ea0_1013    conda-forge

Hope that helps!

wangy9711 commented 1 year ago

Thanks for the information on memory usage. The main cause of the high memory consumption was that I forgot to enable the remat option (gradient checkpointing) when modifying the code. With it configured correctly, GPU memory usage is now below 10 GB.
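For readers unfamiliar with remat: it is JAX's gradient checkpointing (`jax.checkpoint`, also exported as `jax.remat`; Haiku wraps it as `hk.remat`). The toy sketch below uses a made-up `block` function as a stand-in for an expensive layer and shows that checkpointing changes what is stored for the backward pass, not the gradients. In upstream AlphaFold the switch appears to be `use_remat` in the model's `global_config`; treat that name as something to verify in your own copy.

```python
import jax
import jax.numpy as jnp

def block(x, w):
    # Toy stand-in for an expensive layer (e.g. one Evoformer block)
    return jnp.tanh(x @ w)

# Checkpointed version: intermediate activations of `block` are not kept for
# the backward pass; they are recomputed when gradients are taken.
block_remat = jax.checkpoint(block)

def make_loss(layer, n_layers=4):
    def loss(w, x):
        for _ in range(n_layers):
            x = layer(x, w)
        return jnp.sum(x ** 2)
    return loss

x = jnp.full((8, 16), 0.1)
w = jnp.eye(16) * 0.5
g_plain = jax.grad(make_loss(block))(w, x)
g_remat = jax.grad(make_loss(block_remat))(w, x)
```

The trade-off is extra forward computation during backprop in exchange for lower peak memory, which is why forgetting to enable it can inflate GPU usage so much on a model as deep as AlphaFold.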

jsko-arontier commented 1 year ago

Thank you for the great work!

I was able to run the program without errors in a conda environment created from the following specification. To reproduce it, save the text below as a .yml file and run conda env create -f <file>.yml.

name: af_finetune
channels:
  - conda-forge
  - defaults
dependencies:
  - openmm=7.5.1
  - cudatoolkit=11.1.1
  - pdbfixer
  - pip
  - python=3.8
  - pip:
    - absl-py==0.10.0
    - biopython==1.80
    - dm-haiku==0.0.5
    - dm-tree==0.1.6
    - immutabledict==2.0.0
    - jax==0.2.22
    - ml-collections==0.1.0
    - numpy==1.19.5
    - pandas==1.3.4
    - protobuf==3.20.1
    - scipy==1.7.0
    - tensorflow==2.5.0
    - tensorflow-estimator==2.5.0 
    - https://storage.googleapis.com/jax-releases/cuda111/jaxlib-0.1.72+cuda111-cp38-none-manylinux2010_x86_64.whl
    - chex==0.0.7
    - typing-extensions==3.7.4.3
    - optax==0.0.9
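
Once the environment is created, a quick way to confirm that the CUDA build of jaxlib from that wheel is actually in use (rather than a CPU-only build) is to ask JAX which platform its default devices belong to. A minimal sketch; the helper name `describe_backend` is hypothetical, and on a working GPU setup one would expect the platform to be `gpu`:

```python
import jax

def describe_backend():
    # Devices of the default backend that JAX selected at import time
    devices = jax.devices()
    # Platform name of those devices, e.g. "cpu", "gpu" or "tpu"
    platform = devices[0].platform
    return platform, [str(d) for d in devices]

if __name__ == "__main__":
    platform, devices = describe_backend()
    print("platform:", platform)
    print("devices:", devices)
    if platform == "cpu":
        print("warning: JAX is running on CPU; check the jaxlib/CUDA install")
```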
phbradley commented 1 year ago

Thanks jsko-arontier for posting that helpful info! I realized that the environment info I posted above was for a slightly older variant than the one I used for the calculations in the paper. My apologies for that; here is the output of conda env export for the correct environment. I would try the one in the previous post first, since it looks cleaner, but if it doesn't work this could be an alternative:

name: af2test
channels:
  - nvidia
  - conda-forge
  - bioconda
  - defaults
  - r
dependencies:
  - _libgcc_mutex=0.1=conda_forge
  - _openmp_mutex=4.5=1_gnu
  - ca-certificates=2021.10.8=ha878542_0
  - cudatoolkit=11.1.74=h6bb024c_0
  - cudatoolkit-dev=11.4.0=py38h497a2fe_2
  - cudnn=8.0.4=cuda11.1_0
  - hhsuite=3.3.0=py38pl5262hc37a69a_2
  - hmmer=3.3.2=h1b792b2_1
  - kalign2=2.04=h779adbc_2
  - ld_impl_linux-64=2.36.1=hea4e1c9_2
  - libffi=3.4.2=h9c3ff4c_4
  - libgcc-ng=11.2.0=h1d223b6_11
  - libgomp=11.2.0=h1d223b6_11
  - libnsl=2.0.0=h7f98852_0
  - libstdcxx-ng=11.2.0=he4da1e4_11
  - libzlib=1.2.11=h36c2ea0_1013
  - ncurses=6.2=h58526e2_4
  - openssl=3.0.0=h7f98852_1
  - perl=5.26.2=h36c2ea0_1008
  - pip=21.3=pyhd8ed1ab_0
  - python=3.8.12=hf930737_2_cpython
  - python_abi=3.8=2_cp38
  - readline=8.1=h46c0cb4_0
  - setuptools=58.2.0=py38h578d9bd_0
  - sqlite=3.36.0=h9cd32fc_2
  - tk=8.6.11=h27826a3_1
  - wheel=0.37.0=pyhd8ed1ab_1
  - xz=5.2.5=h516909a_1
  - zlib=1.2.11=h36c2ea0_1013
  - pip:
    - absl-py==0.13.0
    - alphafold==2.1.0
    - astunparse==1.6.3
    - backcall==0.2.0
    - biopython==1.79
    - cachetools==4.2.4
    - certifi==2021.10.8
    - charset-normalizer==2.0.7
    - chex==0.0.7
    - contextlib2==21.6.0
    - decorator==5.1.0
    - dm-haiku==0.0.6.dev0
    - dm-tree==0.1.6
    - docker==5.0.0
    - flatbuffers==1.12
    - gast==0.4.0
    - google-auth==2.3.0
    - google-auth-oauthlib==0.4.6
    - google-pasta==0.2.0
    - grpcio==1.34.1
    - h5py==3.1.0
    - idna==3.3
    - immutabledict==2.0.0
    - ipython==7.28.0
    - jax==0.2.22
    - jaxlib==0.1.72+cuda111
    - jedi==0.18.0
    - jmp==0.0.2
    - keras-nightly==2.5.0.dev2021032900
    - keras-preprocessing==1.1.2
    - markdown==3.3.4
    - matplotlib-inline==0.1.3
    - ml-collections==0.1.0
    - numpy==1.19.5
    - oauthlib==3.1.1
    - opt-einsum==3.3.0
    - optax==0.1.0
    - pandas==1.3.4
    - parso==0.8.2
    - pexpect==4.8.0
    - pickleshare==0.7.5
    - prompt-toolkit==3.0.20
    - protobuf==3.18.1
    - ptyprocess==0.7.0
    - pyasn1==0.4.8
    - pyasn1-modules==0.2.8
    - pygments==2.10.0
    - python-dateutil==2.8.2
    - pytz==2021.3
    - pyyaml==6.0
    - requests==2.26.0
    - requests-oauthlib==1.3.0
    - rsa==4.7.2
    - scipy==1.7.0
    - six==1.15.0
    - tabulate==0.8.9
    - tensorboard==2.7.0
    - tensorboard-data-server==0.6.1
    - tensorboard-plugin-wit==1.8.0
    - tensorflow-cpu==2.5.0
    - tensorflow-estimator==2.5.0
    - termcolor==1.1.0
    - toolz==0.11.1
    - torch==1.10.1
    - traitlets==5.1.0
    - typing-extensions==3.10.0.2
    - urllib3==1.26.7
    - wcwidth==0.2.5
    - websocket-client==1.2.1
    - werkzeug==2.0.2
    - wrapt==1.12.1