Closed wangy9711 closed 1 year ago
In addition, I pinned specific versions of the dependent libraries to get the code to run, but found that GPU memory usage reached 73 GB during training. Is this normal?
Hey, which version of jax did you use to get the program running? I have the same issue with the new jax version changing its syntax.
Hi - index_update and friends were removed in jax version 0.3.2; see the CHANGELOG here: https://jax.readthedocs.io/en/latest/changelog.html#jax-0-3-2-march-16-2022
Instead of ops.index_update(x, idx, vals) you should use x.at[idx].set(vals).
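As a minimal sketch of the migration described above (the array values are made up for illustration), the removed `index_update` call maps onto the `.at[...]` property like this. Both forms are functional updates: they return a new array and leave `x` unchanged.

```python
import jax.numpy as jnp

x = jnp.zeros(5)
idx = jnp.array([1, 3])
vals = jnp.array([10.0, 20.0])

# Removed style: y = jax.ops.index_update(x, idx, vals)
# Replacement: a functional update that returns a new array.
y = x.at[idx].set(vals)

# The sibling helpers map the same way, e.g. index_add -> .at[idx].add(...)
z = x.at[idx].add(1.0)

print(y)
print(z)
```

If the code base has many such calls, replacing them one by one with the `.at` form should let the same source run on both older and newer jax releases, although that is worth verifying against the exact versions in use.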
I think these are the versions of the key third-party libraries. With these versions I can run fine-tuning normally:
jax 0.2.19
jaxlib 0.1.69+cuda111
dm-haiku 0.0.5
dm-tree 0.1.6
jmp 0.0.2
optax 0.1.3
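To check whether an existing environment matches these pins, a small standard-library sketch along the following lines may help. The helper name `find_mismatches` is hypothetical, and the pinned versions are simply copied from the comment above.

```python
from importlib import metadata

# Pins copied from the comment above.
PINNED = {
    "jax": "0.2.19",
    "jaxlib": "0.1.69+cuda111",
    "dm-haiku": "0.0.5",
    "dm-tree": "0.1.6",
    "jmp": "0.0.2",
    "optax": "0.1.3",
}


def find_mismatches(pinned):
    """Return {package: (expected, installed)} for pins that are not met.

    `installed` is None when the package is not installed at all.
    """
    mismatches = {}
    for name, expected in pinned.items():
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            installed = None
        if installed != expected:
            mismatches[name] = (expected, installed)
    return mismatches


if __name__ == "__main__":
    for name, (expected, installed) in find_mismatches(PINNED).items():
        print(f"{name}: expected {expected}, found {installed}")
```

Running it inside the activated environment prints one line per package that differs from the pins, and nothing if everything matches.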
Hey Wang, I got stuck when running the peptide-MHC fine-tuning (either on a tiny dataset or on the full model) as instructed on the author's GitHub page, and got:
FileNotFoundError: [Errno 2] No such file or directory: '/home/pbradley/csdat/alphafold/data/params/params_model_2_ptm.npz'
Have you ever had the same problem? Thank you for helping.
Hi all, thanks for the helpful discussions, and sorry I was not in town to reply sooner. For the above error, you need to use the --data_dir command line flag and provide the path to the directory that contains the AlphaFold params/ folder. I've updated the README to clarify this.
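A quick way to confirm that a candidate directory is suitable for --data_dir before launching a run is sketched below. The function `check_data_dir` is a hypothetical helper, not part of the repository. The file name comes from the error message above, and the params/ layout follows the description in this reply.

```python
from pathlib import Path

# File name taken from the FileNotFoundError quoted above.
PARAMS_FILE = "params_model_2_ptm.npz"


def check_data_dir(data_dir):
    """Return the parameter file path, or raise with a hint about --data_dir."""
    params_path = Path(data_dir) / "params" / PARAMS_FILE
    if not params_path.is_file():
        raise FileNotFoundError(
            f"{params_path} not found; pass --data_dir pointing at the "
            "directory that contains the AlphaFold params/ folder"
        )
    return params_path
```

Calling `check_data_dir("/path/to/alphafold/data")` with your own path either returns the location of the parameter file or fails early with a message that names the missing file.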
With regard to memory usage, in our experience the training should not take more than 11-12 GB of GPU memory.
Sorry for the trouble with jax/python/conda environments. I am not an expert in this, and it took some work to get things to work on our machines originally, but I think much of that was specific to the GPUs that we were using and the various versions of CUDA and associated libraries. Plus it took a combination of conda and pip and some more conda and some more pip. So, that's all to say that I'm not sure that our requirements.txt would be generally useful, and I'm not even sure it would contain all the relevant information given the conda/pip mixing.
Nevertheless, our jax version is 0.2.22.
Here is the result of running conda list --explicit:
# This file may be used to create an environment using:
# $ conda create --name <env> --file <this file>
# platform: linux-64
@EXPLICIT
https://conda.anaconda.org/conda-forge/linux-64/_libgcc_mutex-0.1-conda_forge.tar.bz2
https://conda.anaconda.org/conda-forge/linux-64/ca-certificates-2021.10.8-ha878542_0.tar.bz2
https://conda.anaconda.org/conda-forge/linux-64/ld_impl_linux-64-2.36.1-hea4e1c9_2.tar.bz2
https://conda.anaconda.org/conda-forge/linux-64/libstdcxx-ng-11.2.0-he4da1e4_11.tar.bz2
https://conda.anaconda.org/conda-forge/linux-64/libgomp-11.2.0-h1d223b6_11.tar.bz2
https://conda.anaconda.org/conda-forge/linux-64/_openmp_mutex-4.5-1_gnu.tar.bz2
https://conda.anaconda.org/conda-forge/linux-64/libgcc-ng-11.2.0-h1d223b6_11.tar.bz2
https://conda.anaconda.org/nvidia/linux-64/cudatoolkit-11.1.74-h6bb024c_0.tar.bz2
https://conda.anaconda.org/bioconda/linux-64/hmmer-3.3.2-h1b792b2_1.tar.bz2
https://conda.anaconda.org/bioconda/linux-64/kalign2-2.04-h779adbc_2.tar.bz2
https://conda.anaconda.org/conda-forge/linux-64/libffi-3.4.2-h9c3ff4c_4.tar.bz2
https://conda.anaconda.org/conda-forge/linux-64/libnsl-2.0.0-h7f98852_0.tar.bz2
https://conda.anaconda.org/conda-forge/linux-64/libzlib-1.2.11-h36c2ea0_1013.tar.bz2
https://conda.anaconda.org/conda-forge/linux-64/ncurses-6.2-h58526e2_4.tar.bz2
https://conda.anaconda.org/conda-forge/linux-64/openssl-3.0.0-h7f98852_1.tar.bz2
https://conda.anaconda.org/conda-forge/linux-64/perl-5.26.2-h36c2ea0_1008.tar.bz2
https://conda.anaconda.org/conda-forge/linux-64/xz-5.2.5-h516909a_1.tar.bz2
https://conda.anaconda.org/nvidia/linux-64/cudnn-8.0.4-cuda11.1_0.tar.bz2
https://conda.anaconda.org/conda-forge/linux-64/readline-8.1-h46c0cb4_0.tar.bz2
https://conda.anaconda.org/conda-forge/linux-64/zlib-1.2.11-h36c2ea0_1013.tar.bz2
https://conda.anaconda.org/conda-forge/linux-64/sqlite-3.36.0-h9cd32fc_2.tar.bz2
https://conda.anaconda.org/conda-forge/linux-64/tk-8.6.11-h27826a3_1.tar.bz2
https://conda.anaconda.org/conda-forge/linux-64/python-3.8.12-hf930737_2_cpython.tar.bz2
https://conda.anaconda.org/conda-forge/linux-64/python_abi-3.8-2_cp38.tar.bz2
https://conda.anaconda.org/conda-forge/noarch/wheel-0.37.0-pyhd8ed1ab_1.tar.bz2
https://conda.anaconda.org/conda-forge/linux-64/cudatoolkit-dev-11.4.0-py38h497a2fe_2.tar.bz2
https://conda.anaconda.org/bioconda/linux-64/hhsuite-3.3.0-py38pl5262hc37a69a_2.tar.bz2
https://conda.anaconda.org/conda-forge/linux-64/setuptools-58.2.0-py38h578d9bd_0.tar.bz2
https://conda.anaconda.org/conda-forge/noarch/pip-21.3-pyhd8ed1ab_0.tar.bz2
(Not sure why that doesn't show jax!?!)
And here is what conda list gives. This one does seem to show jax:
# packages in environment at /home/pbradley/anaconda2/envs/af2:
#
# Name Version Build Channel
_libgcc_mutex 0.1 conda_forge conda-forge
_openmp_mutex 4.5 1_gnu conda-forge
absl-py 0.13.0 pypi_0 pypi
alphafold 2.1.0 pypi_0 pypi
astunparse 1.6.3 pypi_0 pypi
backcall 0.2.0 pypi_0 pypi
biopython 1.79 pypi_0 pypi
ca-certificates 2021.10.8 ha878542_0 conda-forge
cachetools 4.2.4 pypi_0 pypi
certifi 2021.10.8 pypi_0 pypi
charset-normalizer 2.0.7 pypi_0 pypi
chex 0.0.7 pypi_0 pypi
contextlib2 21.6.0 pypi_0 pypi
cudatoolkit 11.1.74 h6bb024c_0 nvidia
cudatoolkit-dev 11.4.0 py38h497a2fe_2 conda-forge
cudnn 8.0.4 cuda11.1_0 nvidia
decorator 5.1.0 pypi_0 pypi
dm-haiku 0.0.4 pypi_0 pypi
dm-tree 0.1.6 pypi_0 pypi
docker 5.0.0 pypi_0 pypi
flatbuffers 1.12 pypi_0 pypi
gast 0.4.0 pypi_0 pypi
google-auth 2.3.0 pypi_0 pypi
google-auth-oauthlib 0.4.6 pypi_0 pypi
google-pasta 0.2.0 pypi_0 pypi
grpcio 1.34.1 pypi_0 pypi
h5py 3.1.0 pypi_0 pypi
hhsuite 3.3.0 py38pl5262hc37a69a_2 bioconda
hmmer 3.3.2 h1b792b2_1 bioconda
idna 3.3 pypi_0 pypi
immutabledict 2.0.0 pypi_0 pypi
ipython 7.28.0 pypi_0 pypi
jax 0.2.22 pypi_0 pypi
jaxlib 0.1.72+cuda111 pypi_0 pypi
jedi 0.18.0 pypi_0 pypi
kalign2 2.04 h779adbc_2 bioconda
keras-nightly 2.5.0.dev2021032900 pypi_0 pypi
keras-preprocessing 1.1.2 pypi_0 pypi
ld_impl_linux-64 2.36.1 hea4e1c9_2 conda-forge
libffi 3.4.2 h9c3ff4c_4 conda-forge
libgcc-ng 11.2.0 h1d223b6_11 conda-forge
libgomp 11.2.0 h1d223b6_11 conda-forge
libnsl 2.0.0 h7f98852_0 conda-forge
libstdcxx-ng 11.2.0 he4da1e4_11 conda-forge
libzlib 1.2.11 h36c2ea0_1013 conda-forge
markdown 3.3.4 pypi_0 pypi
matplotlib-inline 0.1.3 pypi_0 pypi
ml-collections 0.1.0 pypi_0 pypi
ncurses 6.2 h58526e2_4 conda-forge
numpy 1.19.5 pypi_0 pypi
oauthlib 3.1.1 pypi_0 pypi
openssl 3.0.0 h7f98852_1 conda-forge
opt-einsum 3.3.0 pypi_0 pypi
pandas 1.3.4 pypi_0 pypi
parso 0.8.2 pypi_0 pypi
perl 5.26.2 h36c2ea0_1008 conda-forge
pexpect 4.8.0 pypi_0 pypi
pickleshare 0.7.5 pypi_0 pypi
pip 21.3 pyhd8ed1ab_0 conda-forge
prompt-toolkit 3.0.20 pypi_0 pypi
protobuf 3.18.1 pypi_0 pypi
ptyprocess 0.7.0 pypi_0 pypi
pyasn1 0.4.8 pypi_0 pypi
pyasn1-modules 0.2.8 pypi_0 pypi
pygments 2.10.0 pypi_0 pypi
python 3.8.12 hf930737_2_cpython conda-forge
python-dateutil 2.8.2 pypi_0 pypi
python_abi 3.8 2_cp38 conda-forge
pytz 2021.3 pypi_0 pypi
pyyaml 6.0 pypi_0 pypi
readline 8.1 h46c0cb4_0 conda-forge
requests 2.26.0 pypi_0 pypi
requests-oauthlib 1.3.0 pypi_0 pypi
rsa 4.7.2 pypi_0 pypi
scipy 1.7.0 pypi_0 pypi
setuptools 58.2.0 py38h578d9bd_0 conda-forge
six 1.15.0 pypi_0 pypi
sqlite 3.36.0 h9cd32fc_2 conda-forge
tabulate 0.8.9 pypi_0 pypi
tensorboard 2.7.0 pypi_0 pypi
tensorboard-data-server 0.6.1 pypi_0 pypi
tensorboard-plugin-wit 1.8.0 pypi_0 pypi
tensorflow-cpu 2.5.0 pypi_0 pypi
tensorflow-estimator 2.5.0 pypi_0 pypi
termcolor 1.1.0 pypi_0 pypi
tk 8.6.11 h27826a3_1 conda-forge
toolz 0.11.1 pypi_0 pypi
traitlets 5.1.0 pypi_0 pypi
typing-extensions 3.7.4.3 pypi_0 pypi
urllib3 1.26.7 pypi_0 pypi
wcwidth 0.2.5 pypi_0 pypi
websocket-client 1.2.1 pypi_0 pypi
werkzeug 2.0.2 pypi_0 pypi
wheel 0.37.0 pyhd8ed1ab_1 conda-forge
wrapt 1.12.1 pypi_0 pypi
xz 5.2.5 h516909a_1 conda-forge
zlib 1.2.11 h36c2ea0_1013 conda-forge
Hope that helps!
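A likely explanation for jax missing from the --explicit output is that this listing covers only packages installed by conda, while the table above shows jax and jaxlib with channel pypi, meaning they were installed with pip. The sketch below, using a hypothetical helper `pip_installed` and a few rows copied from the table, pulls the pip-installed packages out of such a listing so they can be pinned separately.

```python
def pip_installed(conda_list_text):
    """Return {name: version} for rows whose channel column is 'pypi'."""
    packages = {}
    for line in conda_list_text.splitlines():
        line = line.strip()
        if not line or line.startswith("#"):
            continue
        fields = line.split()
        # Rows look like: name version build channel
        if len(fields) >= 4 and fields[3] == "pypi":
            packages[fields[0]] = fields[1]
    return packages


# A few rows copied from the listing above.
sample = """\
# Name Version Build Channel
cudnn 8.0.4 cuda11.1_0 nvidia
jax 0.2.22 pypi_0 pypi
jaxlib 0.1.72+cuda111 pypi_0 pypi
python 3.8.12 hf930737_2_cpython conda-forge
"""

print(pip_installed(sample))
```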
Thanks for this information. The main reason for the high memory consumption was that I forgot to configure the remat parameter (rematerialization, also called gradient or memory checkpointing) when changing the code. With the correct configuration, GPU memory consumption is also below 10 GB.
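For readers unfamiliar with rematerialization, here is a generic JAX sketch of the idea. It is not the repository's code and does not use its remat setting, whose exact name is not given in this thread. The toy `block` function and the array shapes are assumptions for illustration. Wrapping a function in `jax.checkpoint` asks JAX to recompute its intermediates during the backward pass instead of storing them, trading extra compute for lower memory while leaving the gradients unchanged.

```python
import jax
import jax.numpy as jnp


def block(x, w):
    # Stand-in for an expensive layer; its intermediates are what
    # rematerialization avoids keeping in memory.
    return jnp.tanh(x @ w)


def loss(w, x, use_remat):
    # jax.checkpoint recomputes block's intermediates in the backward pass.
    layer = jax.checkpoint(block) if use_remat else block
    for _ in range(4):
        x = layer(x, w)
    return jnp.sum(x ** 2)


x = jnp.ones((8, 16))
w = jnp.full((16, 16), 0.05)

g_plain = jax.grad(lambda w_: loss(w_, x, use_remat=False))(w)
g_remat = jax.grad(lambda w_: loss(w_, x, use_remat=True))(w)

# Both variants should give the same gradients up to numerical tolerance.
print(bool(jnp.allclose(g_plain, g_remat, atol=1e-5)))
```

The memory saving only becomes visible on large models, so this toy example demonstrates the mechanism rather than the 73 GB versus 10 GB difference reported above.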
Thank you for the great work!
I was able to run the program without errors in a conda environment created with the following options. You can create the conda environment from a yml file like this:
name: af_finetune
channels:
  - conda-forge
  - defaults
dependencies:
  - openmm=7.5.1
  - cudatoolkit=11.1.1
  - pdbfixer
  - pip
  - python=3.8
  - pip:
    - absl-py==0.10.0
    - biopython==1.80
    - dm-haiku==0.0.5
    - dm-tree==0.1.6
    - immutabledict==2.0.0
    - jax==0.2.22
    - ml-collections==0.1.0
    - numpy==1.19.5
    - pandas==1.3.4
    - protobuf==3.20.1
    - scipy==1.7.0
    - tensorflow==2.5.0
    - tensorflow-estimator==2.5.0
    - https://storage.googleapis.com/jax-releases/cuda111/jaxlib-0.1.72+cuda111-cp38-none-manylinux2010_x86_64.whl
    - chex==0.0.7
    - typing-extensions==3.7.4.3
    - optax==0.0.9
Thanks jsko-arontier for posting that helpful info! I realized that the info on environments that I posted above was for a slightly older variant than the one I used for the calculations in the paper. My apologies for that. Here is the output of conda env export with the correct environment. I would try the one in the previous post first, since it looks cleaner, but if it doesn't work this could be an alternative:
name: af2test
channels:
  - nvidia
  - conda-forge
  - bioconda
  - defaults
  - r
dependencies:
  - _libgcc_mutex=0.1=conda_forge
  - _openmp_mutex=4.5=1_gnu
  - ca-certificates=2021.10.8=ha878542_0
  - cudatoolkit=11.1.74=h6bb024c_0
  - cudatoolkit-dev=11.4.0=py38h497a2fe_2
  - cudnn=8.0.4=cuda11.1_0
  - hhsuite=3.3.0=py38pl5262hc37a69a_2
  - hmmer=3.3.2=h1b792b2_1
  - kalign2=2.04=h779adbc_2
  - ld_impl_linux-64=2.36.1=hea4e1c9_2
  - libffi=3.4.2=h9c3ff4c_4
  - libgcc-ng=11.2.0=h1d223b6_11
  - libgomp=11.2.0=h1d223b6_11
  - libnsl=2.0.0=h7f98852_0
  - libstdcxx-ng=11.2.0=he4da1e4_11
  - libzlib=1.2.11=h36c2ea0_1013
  - ncurses=6.2=h58526e2_4
  - openssl=3.0.0=h7f98852_1
  - perl=5.26.2=h36c2ea0_1008
  - pip=21.3=pyhd8ed1ab_0
  - python=3.8.12=hf930737_2_cpython
  - python_abi=3.8=2_cp38
  - readline=8.1=h46c0cb4_0
  - setuptools=58.2.0=py38h578d9bd_0
  - sqlite=3.36.0=h9cd32fc_2
  - tk=8.6.11=h27826a3_1
  - wheel=0.37.0=pyhd8ed1ab_1
  - xz=5.2.5=h516909a_1
  - zlib=1.2.11=h36c2ea0_1013
  - pip:
    - absl-py==0.13.0
    - alphafold==2.1.0
    - astunparse==1.6.3
    - backcall==0.2.0
    - biopython==1.79
    - cachetools==4.2.4
    - certifi==2021.10.8
    - charset-normalizer==2.0.7
    - chex==0.0.7
    - contextlib2==21.6.0
    - decorator==5.1.0
    - dm-haiku==0.0.6.dev0
    - dm-tree==0.1.6
    - docker==5.0.0
    - flatbuffers==1.12
    - gast==0.4.0
    - google-auth==2.3.0
    - google-auth-oauthlib==0.4.6
    - google-pasta==0.2.0
    - grpcio==1.34.1
    - h5py==3.1.0
    - idna==3.3
    - immutabledict==2.0.0
    - ipython==7.28.0
    - jax==0.2.22
    - jaxlib==0.1.72+cuda111
    - jedi==0.18.0
    - jmp==0.0.2
    - keras-nightly==2.5.0.dev2021032900
    - keras-preprocessing==1.1.2
    - markdown==3.3.4
    - matplotlib-inline==0.1.3
    - ml-collections==0.1.0
    - numpy==1.19.5
    - oauthlib==3.1.1
    - opt-einsum==3.3.0
    - optax==0.1.0
    - pandas==1.3.4
    - parso==0.8.2
    - pexpect==4.8.0
    - pickleshare==0.7.5
    - prompt-toolkit==3.0.20
    - protobuf==3.18.1
    - ptyprocess==0.7.0
    - pyasn1==0.4.8
    - pyasn1-modules==0.2.8
    - pygments==2.10.0
    - python-dateutil==2.8.2
    - pytz==2021.3
    - pyyaml==6.0
    - requests==2.26.0
    - requests-oauthlib==1.3.0
    - rsa==4.7.2
    - scipy==1.7.0
    - six==1.15.0
    - tabulate==0.8.9
    - tensorboard==2.7.0
    - tensorboard-data-server==0.6.1
    - tensorboard-plugin-wit==1.8.0
    - tensorflow-cpu==2.5.0
    - tensorflow-estimator==2.5.0
    - termcolor==1.1.0
    - toolz==0.11.1
    - torch==1.10.1
    - traitlets==5.1.0
    - typing-extensions==3.10.0.2
    - urllib3==1.26.7
    - wcwidth==0.2.5
    - websocket-client==1.2.1
    - werkzeug==2.0.2
    - wrapt==1.12.1
Hi, thanks for your great work! But when I try to run the finetune command with the default af2 conda environment, I found some errors that may be due to version issues of Haiku/Jax. So can you tell me the version of the dependency library? It would be great if you could add env.yaml or requirement.txt to the code! Thanks again.