Status: closed (neon5d closed this issue 4 years ago)
Setting CUDA_VISIBLE_DEVICES after importing jaxnet does not work:

    import os
    import jaxnet
    os.environ["CUDA_VISIBLE_DEVICES"] = "2"  # => not working
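Setting the environment variable before importing jax or jaxnet should fix this. CUDA_VISIBLE_DEVICES only takes effect if it is set before CUDA is initialized in the process, so the safe order is to set it first and import jax or jaxnet afterwards.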
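A minimal sketch of that ordering, assuming you want the check enforced in code. `select_gpu` is a hypothetical helper, not part of jax or jaxnet. It sets the variable and refuses to run if `jax` (which jaxnet imports) is already loaded. The check is deliberately conservative: JAX may not have initialized CUDA yet even after being imported, but once it is in `sys.modules` the setting is no longer guaranteed to be honored.

```python
import os
import sys


def select_gpu(index: str) -> None:
    """Hypothetical helper: restrict CUDA to one device before JAX loads."""
    # Conservative check: if jax is already imported, CUDA may already be
    # initialized, in which case CUDA_VISIBLE_DEVICES would be ignored.
    if "jax" in sys.modules:
        raise RuntimeError(
            "jax is already imported; set CUDA_VISIBLE_DEVICES before "
            "importing jax or jaxnet"
        )
    os.environ["CUDA_VISIBLE_DEVICES"] = index


# Call this first, at the very top of the entry-point script ...
select_gpu("2")

# ... and only then import the GPU libraries, e.g.:
# import jaxnet
```

With only GPU 2 visible, JAX should then report a single GPU device, which CUDA renumbers as index 0 inside the process.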