Closed: ingra14m closed this issue 9 months ago.
Thanks for this! If you want to push an update to the README with this, feel free.
Thanks so much. I recognize that everyone's computer environment is different. As JAX is a framework highly dependent on the CUDA environment, it's very likely that my setup won't suit every CUDA 11.6 environment.
So thank you very much; posting it in the issue is fine. My hope is that this environment can assist some developers working in CUDA 11.6 environments.
Yeah... I tried to set up the same thing using CUDA 11.3 (and cuDNN 8.3), but hit a runtime error t_t. My setup:
```bash
conda create -y --name zipnerf047 python=3.11
conda activate zipnerf047
wget https://storage.googleapis.com/jax-releases/cuda11/jaxlib-0.4.7+cuda11.cudnn82-cp311-cp311-manylinux2014_x86_64.whl
pip install jaxlib-0.4.7+cuda11.cudnn82-cp311-cp311-manylinux2014_x86_64.whl jax==0.4.7 flax opencv-python==4.9.0.80 pillow==10.2.0 tensorboard==2.15.1 tensorflow==2.15.0.post1 gin-config==0.5.0 dm-pix==0.4.2 rawpy==0.19.0 mediapy==1.2.0 immutabledict==4.1.0 ml_collections jaxcam==0.1.1 chex ffmpeg
# flax==0.7.4, chex==0.1.82 are chosen by pip automatically
```
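A quick sanity check that the CUDA build of jaxlib is actually active (a minimal snippet I'd run after installing; not part of the original setup):

```python
# Verify that JAX picked up the GPU backend rather than falling back to CPU.
import jax

print(jax.default_backend())  # expect 'gpu'
print(jax.devices())          # expect CUDA devices, not CPU devices
```

If this prints 'cpu', the CPU-only jaxlib has probably shadowed the CUDA wheel.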
When training on the mipnerf360 dataset, iterations run and PSNR increases until iteration 10000 (config.checkpoint_every), at which point saving the checkpoint fails:
File "xxx/codes/train.py", line 491, in main
checkpoints.save_checkpoint_multiprocess(
File "xxx/miniconda3/envs/zipnerf047/lib/python3.11/site-packages/flax/training/checkpoints.py", line 806, in save_checkpoint_multiprocess
ocp. PyTreeCheckpointHandler(restore_with_serialized_types=False)
TypeError: PyTreeCheckpointHandler.__init_() got an unexpected keyword argument 'restore_with_serialized_types'
Do you have any idea on how to solve this problem?
Hi, I know this problem. Your flax version is too high. You need to reinstall with flax==0.6.1 and chex==0.1.7. After that, you may also need to replace flax.core.copy with flax.core.FrozenDict.copy in internal/train_utils.py; see the sketch below.
Huh weird, feel free to submit a PR that fixes this.
I tried the package versions you gave, and now I can train zipnerf and save checkpoints using CUDA 11.3. Thank you!
However, it's hard for me to imagine how much effort it took you to derive all these versions... and fix the bugs in the code. I still think that having CUDA 11.8 installed on the machine or cluster is the better choice...
Some notes when using CUDA 11.8 on a cluster: the image nvidia/cuda:11.8.0-cudnn8-devel-ubuntu22.04 works, and it can be pulled through the mirror nvidia.docker.mirrors.sjtug.sjtu.edu.cn/nvidia/cuda:11.8.0-cudnn8-devel-ubuntu22.04 (reference on Docker registry mirrors: https://gist.github.com/y0ngb1n/7e8f16af3242c7815e7ca2f0833d3ea6).
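For example (a minimal pull/run sketch, assuming Docker with the NVIDIA container runtime is available on the cluster; not from the original notes):

```bash
# Pull the CUDA 11.8 + cuDNN 8 devel image through the mirror
docker pull nvidia.docker.mirrors.sjtug.sjtu.edu.cn/nvidia/cuda:11.8.0-cudnn8-devel-ubuntu22.04

# Start an interactive container with GPU access
docker run --gpus all -it \
    nvidia.docker.mirrors.sjtug.sjtu.edu.cn/nvidia/cuda:11.8.0-cudnn8-devel-ubuntu22.04 bash
```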
Very glad that you solved the environment issue @Yang-Xijie.
@jonbarron I think the bug with saving checkpoints originates from the version of Flax. Since I used a lower version of JAX to adapt to CUDA 11.6, I also needed to downgrade Flax to install everything properly. As for the PR, I don't think it's necessary to write a specific environment configuration for CUDA 11.3-11.7 into the README. But it would be great if you could link to this solution from the README to help users on CUDA 11.3-11.7.
Thanks once again for this outstanding work.
Thanks to the authors for this outstanding work. However, due to JAX's tight dependence on CUDA, the default environment is only compatible with CUDA 11.8 and CUDA 12.2. Here, I provide an environment that works with CUDA 11.3-11.7.
Then, pip install the following requirements:
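Presumably along these lines (reconstructed from the package set that worked earlier in this thread; pin versions to match your setup):

```bash
pip install flax opencv-python pillow tensorboard tensorflow gin-config \
    dm-pix rawpy mediapy immutabledict ml_collections jaxcam chex ffmpeg
```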
Finally, it's necessary to install the CUDA version of jaxlib to enable GPU-accelerated training.
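Presumably a command of this form, matching the jax-releases wheel naming used above (adjust the jaxlib/CUDA/cuDNN/Python tags to your environment):

```bash
# Fetch a CUDA-enabled jaxlib wheel and install it over the CPU build
wget https://storage.googleapis.com/jax-releases/cuda11/jaxlib-0.4.7+cuda11.cudnn82-cp311-cp311-manylinux2014_x86_64.whl
pip install jaxlib-0.4.7+cuda11.cudnn82-cp311-cp311-manylinux2014_x86_64.whl
```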
After that, you may need to replace flax.core.copy with flax.core.FrozenDict.copy in internal/train_utils.py to save checkpoints correctly.