kirtishrinkhala opened this issue 6 months ago
Check that your jaxlib version matches your CUDA version.
I installed it with this command:
pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
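One way to do that check is the minimal Python sketch below. It is not part of the repository: the helper names installed_version and report are hypothetical. It reads the installed jax and jaxlib versions through importlib.metadata and, if JAX imports cleanly, reports which backend and devices it sees. If the backend comes back as cpu on a GPU machine, the CUDA-enabled jaxlib was probably not picked up. Compare the printed versions against the CUDA version your driver reports (for example via nvidia-smi).

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional


def installed_version(package: str) -> Optional[str]:
    """Return the installed version of `package`, or None if it is not installed."""
    try:
        return version(package)
    except PackageNotFoundError:
        return None


def report() -> dict:
    """Collect jax/jaxlib versions plus the backend and devices JAX can see (hypothetical helper)."""
    info = {pkg: installed_version(pkg) for pkg in ("jax", "jaxlib")}
    try:
        import jax

        # "gpu" means a CUDA-enabled jaxlib was loaded; "cpu" means it fell back to CPU.
        info["backend"] = jax.default_backend()
        info["devices"] = [str(d) for d in jax.devices()]
    except Exception as exc:  # ImportError, or a RuntimeError from a jax/jaxlib mismatch
        info["error"] = repr(exc)
    return info


if __name__ == "__main__":
    for key, value in report().items():
        print(f"{key}: {value}")
```

Catching the import error instead of letting it propagate keeps the version numbers visible even when jax and jaxlib are mismatched, which is exactly the case this check is meant to diagnose.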
If you get a CUDA out-of-memory (OOM) error, disable JAX's GPU memory preallocation before launching main.py:
export XLA_PYTHON_CLIENT_PREALLOCATE=false && python -m torch.distributed.launch --nproc_per_node=4 --use_env main.py
This worked for me.
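The same setting can also be applied from Python, as long as it happens before JAX initializes its GPU backend (safest: before the first import jax). By default JAX preallocates a large fraction of GPU memory in each process. With several processes sharing GPUs (for example --nproc_per_node=4), that could plausibly exhaust memory, but this is an assumption about this particular OOM, not a confirmed diagnosis. This is a minimal sketch: jax_preallocation_disabled is a hypothetical helper, not part of the repository.

```python
import os

# Must run before JAX initializes its GPU backend; setting it after the first
# JAX operation has no effect on the already-created client.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"


def jax_preallocation_disabled() -> bool:
    """Report whether the env var that disables JAX GPU preallocation is set (hypothetical helper)."""
    return os.environ.get("XLA_PYTHON_CLIENT_PREALLOCATE", "").lower() == "false"


if __name__ == "__main__":
    print("preallocation disabled:", jax_preallocation_disabled())
    try:
        import jax  # imported only after the env var is set

        print("devices:", jax.devices())
    except (ImportError, RuntimeError) as exc:
        print("could not initialize jax:", exc)
```

Setting the variable at the top of the entry script has the same effect as the export in the shell command, but it travels with the code instead of depending on how the job is launched.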
Hi, I am following the steps in the README to run the model. My goal is to run the model on my own inputs; I don't want to train it.
I did the following:
On running the command:
I get the following error:
Any pointers on what is causing this?