jonbarron / camp_zipnerf


Jax Environment for CUDA 11.6 #2

Closed ingra14m closed 7 months ago

ingra14m commented 7 months ago

Thanks to the authors for the outstanding work. However, because JAX is tightly coupled to the CUDA version, the default environment is only compatible with CUDA 11.8 and CUDA 12.2. Here, I provide an environment that works with CUDA 11.3-CUDA 11.7.

conda create -n camp_zipnerf python=3.10
conda activate camp_zipnerf 

Then, pip install the following requirements:

numpy==1.26.3
jax==0.4.6
jaxlib==0.4.6
flax==0.6.1
opencv-python==4.9.0.80
pillow==10.2.0
tensorboard==2.10.1
tensorflow==2.10.0
gin-config==0.5.0
dm-pix==0.4.2
rawpy==0.19.0
mediapy==1.2.0
immutabledict==4.1.0
ml_collections
jaxcam==0.1.1
chex==0.1.7

Finally, it's necessary to install the CUDA version of jaxlib to enable GPU-accelerated training.

# python 3.10
wget https://storage.googleapis.com/jax-releases/cuda11/jaxlib-0.4.6+cuda11.cudnn82-cp310-cp310-manylinux2014_x86_64.whl
pip install jaxlib-0.4.6+cuda11.cudnn82-cp310-cp310-manylinux2014_x86_64.whl

# python 3.11
wget https://storage.googleapis.com/jax-releases/cuda11/jaxlib-0.4.6+cuda11.cudnn82-cp311-cp311-manylinux2014_x86_64.whl
pip install jaxlib-0.4.6+cuda11.cudnn82-cp311-cp311-manylinux2014_x86_64.whl
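
To confirm that the CUDA build of jaxlib is actually the one being used (rather than a silent fallback to CPU), a quick sanity check like the following can help; the expected outputs are assumptions based on the versions above.

# Sanity check: run after installing the CUDA jaxlib wheel.
# If the GPU backend is active, jax.default_backend() returns "gpu"
# and jax.devices() lists one entry per visible GPU.
import jax

print(jax.__version__)        # expect 0.4.6 with the setup above
print(jax.default_backend())  # expect "gpu"; "cpu" means the CUDA wheel is not in use
print(jax.devices())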

After that, you may need to replace flax.core.copy with flax.core.FrozenDict.copy in internal/train_utils.py to save checkpoints correctly.
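
For reference, here is a minimal sketch of that substitution; the variable names are hypothetical and the exact call site in internal/train_utils.py may differ. flax.core.copy is not available in older flax releases such as 0.6.1, while the FrozenDict.copy method is, so the two calls below do the same merge on a FrozenDict.

# Hypothetical example of the substitution described above; not the exact
# line from internal/train_utils.py.
import flax

params = flax.core.FrozenDict({"w": 1.0})
updates = {"b": 0.0}

# Newer flax (the call used in the repo, unavailable in flax 0.6.1):
# merged = flax.core.copy(params, updates)

# flax 0.6.1 equivalent via the FrozenDict method:
merged = flax.core.FrozenDict.copy(params, updates)
print(merged)  # contains both 'w' and 'b'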

jonbarron commented 7 months ago

Thanks for this! If you want to push an update to the README with this, feel free.

ingra14m commented 7 months ago

Thanks so much. I recognize that everyone's machine is different, and since JAX depends heavily on the CUDA environment, my setup may well not work for every CUDA 11.6 installation.

So thank you very much; keeping it in this issue is fine. I hope this environment can help some developers working with CUDA 11.6.

Yang-Xijie commented 7 months ago

Yeah... I put together much the same thing using CUDA 11.3 (and cuDNN 8.3), but hit a runtime error t_t.

conda create -y --name zipnerf047 python=3.11
conda activate zipnerf047

wget https://storage.googleapis.com/jax-releases/cuda11/jaxlib-0.4.7+cuda11.cudnn82-cp311-cp311-manylinux2014_x86_64.whl
pip install jaxlib-0.4.7+cuda11.cudnn82-cp311-cp311-manylinux2014_x86_64.whl jax==0.4.7 flax opencv-python==4.9.0.80 pillow==10.2.0 tensorboard==2.15.1 tensorflow==2.15.0.post1 gin-config==0.5.0 dm-pix==0.4.2 rawpy==0.19.0 mediapy==1.2.0 immutabledict==4.1.0 ml_collections jaxcam==0.1.1 chex ffmpeg
# flax==0.7.4, chex==0.1.82 are chosen by pip automatically

When training on the mipnerf360 dataset, iterations start fine and PSNR increases until iteration 10000 (config.checkpoint_every), at which point saving the checkpoint fails:

  File "xxx/codes/train.py", line 491, in main
    checkpoints.save_checkpoint_multiprocess(
  File "xxx/miniconda3/envs/zipnerf047/lib/python3.11/site-packages/flax/training/checkpoints.py", line 806, in save_checkpoint_multiprocess
    ocp.PyTreeCheckpointHandler(restore_with_serialized_types=False)
TypeError: PyTreeCheckpointHandler.__init__() got an unexpected keyword argument 'restore_with_serialized_types'

Do you have any idea on how to solve this problem?

ingra14m commented 7 months ago

Hi, I know this problem: your flax version is too high. You may need to reinstall flax==0.6.1 and chex==0.1.7. After that, you may need to replace flax.core.copy with flax.core.FrozenDict.copy in internal/train_utils.py.
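
If it's unclear which versions end up active after the downgrade (for example on a cluster with multiple site-packages), a quick check like this confirms it; the expected values are just the versions recommended above.

# Confirm the downgraded packages are the ones actually imported.
import flax
import chex

print(flax.__version__)  # expect 0.6.1
print(chex.__version__)  # expect 0.1.7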

jonbarron commented 7 months ago

Huh weird, feel free to submit a PR that fixes this.

Yang-Xijie commented 7 months ago

I tried the package versions you gave, and now I can train zipnerf and save checkpoints using CUDA 11.3. Thank you!

However, it is hard for me to imagine how much effort it took you to work out all these versions... and fix the bugs in the code. I still think that having CUDA 11.8 installed on the machine or cluster is the better choice...

Some notes when using cuda 11.8 on a cluster:

ingra14m commented 7 months ago

Very glad that you solved the environment issue @Yang-Xijie .

@jonbarron I think the bug with saving checkpoints originates from the Flax version. Since I used a lower version of JAX to adapt to CUDA 11.6, I also needed to downgrade Flax so it would install properly. As for the PR, I don't think it's necessary to add a dedicated environment configuration for CUDA 11.3-CUDA 11.7 to the README, but it would be great if you could link to this issue from the README to help users on CUDA 11.3-CUDA 11.7.

Thanks once again for this outstanding work.