Open xianyu-123 opened 7 months ago
JAX is crucial. To test whether JAX is properly configured, you can use the following code:

```python
import os

# Must be set before JAX is imported so that XLA picks it up.
os.environ["XLA_FLAGS"] = "--xla_gpu_force_compilation_parallelism=1"

import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

num_rows = 5
num_cols = 10
smf = jnp.array([jnp.inf, 0.1, 0.1, 0.1, 0.1])
par_init = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])
lb = jnp.array([0.1, 0.1, 0.1, 0.1, 0.1])
ub = jnp.array([10.0, 10.0, 10.0, 10.0, 10.0])
par = jnp.broadcast_to(par_init[:, None], (num_rows, num_cols))

# Rows with an infinite smf contribute one value; all other rows contribute num_cols.
kvals = jnp.where(jnp.isinf(smf), 1, num_cols)
kvals = jnp.insert(kvals, 0, 0)
kvals = list(jnp.cumsum(kvals))  # segment boundaries in the flat vectors

# Length of the flat vectors: num_rows * num_cols minus the values dropped for inf rows.
n_flat = int(num_rows * num_cols - (num_cols - 1) * jnp.sum(jnp.isinf(smf)))
par0_col = jnp.zeros(n_flat)
lb_col = jnp.zeros(n_flat)
ub_col = jnp.zeros(n_flat)

for i in range(num_rows):
    par0_col = par0_col.at[kvals[i]:kvals[i+1]].set(par[i, :kvals[i+1]-kvals[i]])
    lb_col = lb_col.at[kvals[i]:kvals[i+1]].set(lb[i])
    ub_col = ub_col.at[kvals[i]:kvals[i+1]].set(ub[i])

par_log = jnp.log10((par0_col - lb_col) / (1 - par0_col / ub_col))

@jax.jit
def compute(p):
    # Note: p is unused; par_log and kvals are closed over as constants.
    arr_1 = jnp.zeros(shape=(num_rows, num_cols))
    arr_2 = jnp.zeros(shape=(num_rows, num_cols))
    for i in range(num_rows):
        arr_1 = arr_1.at[i, :].set(par_log[kvals[i]:kvals[i+1]])
        arr_2 = arr_2.at[i, :].set(10 ** par_log[kvals[i]:kvals[i+1]])
    return arr_1

arr = compute(par_log)
print(arr)
```
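Before running the longer test, a quicker sanity check may help. The sketch below is not part of the original script; it only assumes that `jax` is importable. It reports which backend JAX selected and forces one small jitted computation through XLA. The helper name `describe_jax_backend` is made up for illustration.

```python
import jax
import jax.numpy as jnp


def describe_jax_backend():
    """Return the active backend name and the devices JAX can see."""
    backend = jax.default_backend()  # for example "cpu" or "gpu"
    devices = jax.devices()
    return backend, devices


backend, devices = describe_jax_backend()
print("backend:", backend)
print("devices:", devices)

# A tiny jitted function forces an XLA compilation on the active backend.
result = jax.jit(lambda x: (x * 2.0).sum())(jnp.arange(4.0))
print("jit result:", float(result))
```

If the backend printed here is the CPU on a machine with a GPU, the installed jaxlib build probably does not match the local CUDA/cuDNN setup.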
I used the installation environment below without any errors throughout the process, and you can do the same. I am using CUDA 11 and cuDNN 8.2 with Python 3.8. The jax releases can be found here: https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
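As a rough illustration of how the matching wheel could be installed, the sketch below builds a `pip install` command that pins `jax` and `jaxlib` to the versions in the package list in this post and points pip's `--find-links` option at the releases page linked above. The helper name `build_pip_command` is hypothetical, and the command is only printed, not executed.

```python
import sys

# Assumed values: version pins taken from the package list in this post,
# and the releases page linked above used as an extra place to look for wheels.
JAX_VERSION = "0.3.25"
JAXLIB_BUILD = "0.3.25+cuda11.cudnn82"
RELEASES_URL = "https://storage.googleapis.com/jax-releases/jax_cuda_releases.html"


def build_pip_command(python=sys.executable):
    """Build (but do not run) a pip command installing the pinned jax/jaxlib."""
    return [
        python, "-m", "pip", "install",
        f"jax=={JAX_VERSION}",
        f"jaxlib=={JAXLIB_BUILD}",
        "--find-links", RELEASES_URL,
    ]


cmd = build_pip_command()
print(" ".join(cmd))
# To actually install, one could run: subprocess.run(cmd, check=True)
```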
In addition, it is necessary to add `export XLA_FLAGS="--xla_gpu_force_compilation_parallelism=1"` to run_alphafold.sh after line 162.
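When the flag is set from Python instead of the shell script, assigning `os.environ["XLA_FLAGS"]` directly overwrites any flags that are already present. The sketch below shows one possible way to merge the flag into an existing value; the helper `with_xla_flag` is hypothetical and not part of JAX or AlphaFold. As in the test script above, the variable should be set before `jax` is imported.

```python
import os


def with_xla_flag(existing, flag):
    """Return an XLA_FLAGS string that contains `flag` and keeps other flags."""
    parts = existing.split() if existing else []
    name = flag.split("=", 1)[0]
    # Drop any earlier setting of the same flag so the new value wins.
    parts = [p for p in parts if p.split("=", 1)[0] != name]
    parts.append(flag)
    return " ".join(parts)


FLAG = "--xla_gpu_force_compilation_parallelism=1"
os.environ["XLA_FLAGS"] = with_xla_flag(os.environ.get("XLA_FLAGS", ""), FLAG)
print(os.environ["XLA_FLAGS"])
# Import jax only after this point.
```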
absl-py==2.1.0
asttokens==2.4.1
astunparse==1.6.3
backcall==0.2.0
biopython==1.79
Bottleneck==1.3.7
Brotli==1.0.9
cachetools==5.2.0
certifi==2024.2.2
charset-normalizer==2.1.1
chex==0.0.7
comm==0.2.2
contextlib2==21.6.0
cycler==0.11.0
debugpy==1.8.1
decorator==5.1.1
dm-haiku==0.0.9
dm-tree==0.1.6
docker==5.0.0
etils==0.7.1
executing==2.0.1
flatbuffers==1.12
fonttools==4.34.4
gast==0.4.0
google-auth==2.11.0
google-auth-oauthlib==0.4.6
google-pasta==0.2.0
grpcio==1.34.1
h5py==3.1.0
idna==3.4
immutabledict==2.0.0
importlib-metadata==4.12.0
importlib-resources==5.9.0
ipykernel==6.29.4
ipython==8.12.3
jax==0.3.25
jaxlib==0.3.25+cuda11.cudnn82
jedi==0.19.1
jmp==0.0.4
jupyter_client==8.6.1
jupyter_core==5.7.2
keras==2.9.0
keras-nightly==2.5.0.dev2021032900
Keras-Preprocessing==1.1.2
kiwisolver==1.4.4
libclang==14.0.6
Markdown==3.4.1
MarkupSafe==2.1.1
matplotlib==3.5.2
matplotlib-inline==0.1.7
ml-collections==0.1.0
ml-dtypes==0.2.0
mock==4.0.3
nest-asyncio==1.6.0
nnlib==0.1
numexpr==2.8.4
numpy==1.22.4
nvidia-htop==1.0.5
oauthlib==3.2.0
OpenMM==7.5.1
opt-einsum==3.3.0
packaging==23.2
pandas==2.0.3
parso==0.8.4
pdbfixer==1.7
pexpect==4.9.0
pickleshare==0.7.5
Pillow==9.2.0
pip==23.0.1
platformdirs==3.10.0
pooch==1.7.0
prompt-toolkit==3.0.43
protobuf==3.19.4
psutil==5.9.8
ptyprocess==0.7.0
pure-eval==0.2.2
pyasn1==0.4.8
pyasn1-modules==0.2.8
Pygments==2.17.2
pyparsing==3.0.9
PySocks==1.7.1
python-dateutil==2.8.2
pytz==2023.3.post1
PyYAML==6.0
pyzmq==26.0.0
requests==2.31.0
requests-oauthlib==1.3.1
rsa==4.9
scipy==1.7.0
setuptools==65.6.3
six==1.16.0
stack-data==0.6.3
tabulate==0.8.10
tensorboard==2.9.1
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.1
tensorflow==2.9.0
tensorflow-estimator==2.9.0
tensorflow-io-gcs-filesystem==0.26.0
termcolor==1.1.0
toolz==0.12.0
tornado==6.4
traitlets==5.14.2
typing-extensions==3.7.4.3
tzdata==2023.3
urllib3==2.1.0
wcwidth==0.2.13
websocket-client==1.3.3
Werkzeug==2.2.2
wheel==0.38.4
wrapt==1.12.1
zipp==3.8.1
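To compare a local environment against the pins listed above, something like the following could be used. It is a minimal sketch built only on the standard library's `importlib.metadata`; the helper names `parse_pins` and `find_mismatches` are made up for illustration, and only a few of the pins are included in the sample string.

```python
from importlib import metadata


def parse_pins(text):
    """Parse whitespace-separated 'name==version' tokens into a dict."""
    pins = {}
    for token in text.split():
        name, sep, version = token.partition("==")
        if sep:
            pins[name] = version
    return pins


def find_mismatches(pins, get_version=metadata.version):
    """Return {name: (wanted, found)} for packages that differ or are missing."""
    mismatches = {}
    for name, wanted in pins.items():
        try:
            found = get_version(name)
        except metadata.PackageNotFoundError:
            found = None  # package is not installed
        if found != wanted:
            mismatches[name] = (wanted, found)
    return mismatches


# A few pins copied from the list above; paste the full list to check everything.
sample = "jax==0.3.25 jaxlib==0.3.25+cuda11.cudnn82 dm-haiku==0.0.9 chex==0.0.7"
pins = parse_pins(sample)
for name, (wanted, found) in find_mismatches(pins).items():
    print(f"{name}: wanted {wanted}, found {found}")
```

An empty output means every sampled package matches its pinned version.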