miss-rain closed this issue 2 years ago.
Hi,
Thanks for your message. Could you please provide a more detailed error message so that we can track the problem down more easily?
I have followed your code, but it does not work on my GPU; it only runs on the CPU.
@miss-rain Hi, I have the same problem. Have you solved it?
I just checked this issue again. I think the root problem might be that JAX could not find your GPU because of a version mismatch. I have updated the README with instructions on how to install the JAX version that corresponds to your CUDA version. Please give it a try, thanks!
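A quick way to confirm whether JAX actually sees the GPU is to ask it which backend and devices it picked. This is a minimal sketch using the public `jax.default_backend()` and `jax.devices()` calls; if the installed jaxlib does not match the CUDA setup, JAX typically falls back to the CPU backend.

```python
import jax

# Report which backend JAX selected and which devices it can see.
# A working CUDA install should report "gpu"; a CPU-only jaxlib or a
# jaxlib/CUDA version mismatch usually results in "cpu".
backend = jax.default_backend()
devices = jax.devices()

print("backend:", backend)
for d in devices:
    print(d.platform, d)

if backend != "gpu":
    print("JAX is running without a GPU; check that jaxlib matches your CUDA version.")
```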
Thanks for the update and the reply, it's been a long time (^.^)
I have solved the JAX problem on my NVIDIA GPU.
But now the code does not run to completion on my 4 Tesla T4 GPUs: it fails with an "out of memory" error, even when I set batchsize=1.
Could you release a lite version?
Thanks again.
@miss-rain I found in the official documentation that when JAX executes its first operation, it preallocates 90% of the available GPU memory. As described there, you can either disable preallocation or reduce the preallocation fraction, which allowed me to run the ViT-Base model. With that change, I was able to run it on my GPUs (8 RTX 3090s and 8 A5000s).
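Both options from JAX's GPU memory allocation documentation are controlled by environment variables that must be set before JAX initializes its GPU backend, so the sketch below sets them before the first `import jax`. The tiny `jnp.ones` computation is only a hypothetical smoke test, and the `.50` fraction is an arbitrary example value; exporting the same variables in the shell before launching the training script should work equally well.

```python
import os

# Must happen before JAX initializes its backend (i.e. before importing jax).
# Option 1: do not preallocate; allocate GPU memory on demand instead.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
# Option 2 (alternative): keep preallocation but reserve a smaller share.
# os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".50"  # example value

import jax
import jax.numpy as jnp

# Hypothetical smoke test: run a small computation on the default device.
x = jnp.ones((1024, 1024))
print(jax.default_backend(), float(x.sum()))
```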
This code cannot run with the package versions listed in requirement.txt, even though jaxlib, CUDA, and cuDNN are set up correctly on my Ubuntu machine.
Please check again!
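When reporting a dependency problem like this, including the exact installed versions makes it easier to compare against the pinned ones. A minimal sketch using only the standard library's `importlib.metadata`; the helper name `installed_version` is hypothetical.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_version(pkg):
    """Return the installed version string of pkg, or None if it is missing."""
    try:
        return version(pkg)
    except PackageNotFoundError:
        return None


# Collect the versions that matter for GPU support in JAX.
report = {pkg: installed_version(pkg) for pkg in ("jax", "jaxlib")}
for pkg, ver in report.items():
    print(f"{pkg}: {ver or 'not installed'}")
```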