Closed: ghost closed this issue 3 years ago.
It seems I'm having a JAX installation issue.
This is a standard message. By default, JAX first tries to run on TPU. If it can't find one (as the second and third lines of the log show), it falls back to GPU and then to CPU.
>>> import jax
>>> jax.local_devices()
will show which devices JAX is running on.
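A slightly expanded version of that check, as a minimal sketch: `jax.default_backend()` names the backend JAX settled on after trying TPU, then GPU, then CPU, and `jax.local_devices()` lists the devices of that backend on this host. The variable names are only for illustration.

```python
import jax
import jax.numpy as jnp

# Name of the backend JAX selected: "tpu", "gpu", or "cpu".
backend = jax.default_backend()
print("default backend:", backend)

# Devices attached to this host for that default backend.
devices = jax.local_devices()
for d in devices:
    print(d.platform, d.id, d.device_kind)

# A small computation; by default it runs on the first of those devices.
x = jnp.ones(3) + 1.0
print(x)  # [2. 2. 2.]
```

If the first line prints `cpu` on a machine that has a GPU, the installed jaxlib most likely lacks CUDA support, which would point back to the installation issue.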
I'm trying to run this example (JAX branch):
At train.train(cfg), the code seems to be running on TPU by default. How can I change it to run on a single GPU instead?
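One possible approach, sketched under the assumption that a CUDA-enabled jaxlib is installed: limit the visible GPUs and the platforms JAX initializes before `jax` is first imported, for example `CUDA_VISIBLE_DEVICES=0 JAX_PLATFORMS=cuda python your_script.py` in recent JAX releases (older releases read `JAX_PLATFORM_NAME=gpu` instead; `your_script.py` is a placeholder). The code below is an in-process variant that falls back to CPU so it still runs without a GPU. `pick_device` is a hypothetical helper, not part of JAX or of the example repo, and `train.train(cfg)` only appears in a comment because the thread does not identify that repository.

```python
import os

# Must be set before JAX initializes its backends: expose only GPU 0.
# Has no effect on a machine without CUDA GPUs.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0")

import jax
import jax.numpy as jnp


def pick_device(preferred="gpu"):
    """Hypothetical helper: first device of `preferred`, else the first CPU device."""
    try:
        return jax.devices(preferred)[0]
    except RuntimeError:
        # jax.devices raises RuntimeError when the requested backend is unavailable.
        return jax.devices("cpu")[0]


device = pick_device("gpu")
print("running on:", device.platform, device)

# Arrays created and computations run inside this block default to `device`.
with jax.default_device(device):
    x = jnp.arange(4.0)
    y = (x * 2.0).sum()
    # train.train(cfg) would go here (cfg and train come from the example repo).

print(float(y))  # 12.0
```

Setting the environment variables before launch is the more robust option, since `jax.default_device` only changes where new arrays and computations are placed; it does not stop JAX from initializing a TPU backend it finds.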