Closed by whopezx 2 months ago
I tried it on an NVIDIA A100 and the program runs correctly. The A30 may have some setting that stops the program.
You might be hitting an OOM issue or an XLA compilation issue; it is hard to know without careful debugging on your own setup.
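One way to start narrowing this down is sketched below. It is only a sketch under assumptions: the helper names `arm_hang_watchdog` and `disarm_hang_watchdog` are hypothetical and not part of this repository. The script sets `XLA_PYTHON_CLIENT_PREALLOCATE=false`, which JAX reads to turn off its default GPU memory preallocation (useful when checking for OOM), and uses Python's standard `faulthandler` module to periodically dump the Python stack of every thread, so a hang shows which Python frame was active when the process stopped making progress.

```python
import faulthandler
import os
import sys

# Assumption: this runs before `import jax`, otherwise the setting has no effect.
# Disables JAX's default GPU memory preallocation so memory use can be observed.
os.environ.setdefault("XLA_PYTHON_CLIENT_PREALLOCATE", "false")


def arm_hang_watchdog(timeout_s: float = 300.0, file=sys.stderr) -> None:
    """Hypothetical helper: dump all Python thread stacks to `file`
    every `timeout_s` seconds until the watchdog is disarmed."""
    faulthandler.dump_traceback_later(timeout_s, repeat=True, file=file)


def disarm_hang_watchdog() -> None:
    """Hypothetical helper: cancel the periodic stack dump."""
    faulthandler.cancel_dump_traceback_later()


if __name__ == "__main__":
    arm_hang_watchdog(timeout_s=300.0)
    # ... import jax and run the training entry point here ...
    disarm_hang_watchdog()
```

If the dumped stack always ends in the same JAX call, that points toward compilation or device execution rather than the training script itself; the dump only shows Python frames, not native ones.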
Hello,
When I run the code on an NVIDIA A30, the program gets stuck in `pretrain.py` after this function (defined in the function `make_pretrain_step`) returns, but it does not get stuck when running on CPU. I used `pdb` to check where the code gets stuck and found that, after this function returns, the program calls some jax internal code. When it reaches this line, the program hangs. I used `top` to check, and the process state is sl+ (sleeping). I wonder whether this is because my jax installation is incorrect? Below are the jax and jaxlib versions:

Because the CUDA version is 12.1, I use jax 0.4.12. I installed jax with `pip install jax==0.4.12` and jaxlib with `pip install ./jaxlib-0.4.12+cuda12.cudnn89-cp311-cp311-manylinux2014_x86_64.whl`.
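To check whether the installation itself is the problem, a small script along these lines may help. This is a minimal sketch, not a definitive diagnostic: `installed_version` and `report` are hypothetical helper names. It prints the installed `jax` and `jaxlib` versions and, if jax is importable, the default backend and the devices it sees. If the backend is reported as `cpu` on a GPU machine, the CUDA-enabled jaxlib is probably not being picked up.

```python
import importlib.metadata as md
from typing import Optional


def installed_version(package: str) -> Optional[str]:
    """Return the installed version of `package`, or None if it is not installed."""
    try:
        return md.version(package)
    except md.PackageNotFoundError:
        return None


def report() -> dict:
    """Collect jax/jaxlib versions and, if jax is installed, its backend and devices."""
    info = {
        "jax": installed_version("jax"),
        "jaxlib": installed_version("jaxlib"),
    }
    if info["jax"] is not None:
        import jax  # imported lazily so the script still runs without jax

        info["backend"] = jax.default_backend()
        info["devices"] = [str(d) for d in jax.devices()]
    return info


if __name__ == "__main__":
    for key, value in report().items():
        print(f"{key}: {value}")
```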