google-deepmind / ferminet

An implementation of the Fermionic Neural Network for ab-initio electronic structure calculations
Apache License 2.0

how to make JAX code run on single GPU instead of TPU? #33

Closed ghost closed 3 years ago

ghost commented 3 years ago

I'm trying to run this example (JAX branch):

import sys

from absl import logging
from ferminet.utils import system
from ferminet import base_config
from ferminet import train

# Optional, for also printing training progress to STDOUT.
# If running a script, you can also just use the --alsologtostderr flag.
logging.get_absl_handler().python_handler.stream = sys.stdout
logging.set_verbosity(logging.INFO)

# Define H2 molecule
cfg = base_config.default()
cfg.system.electrons = (1,1)  # (alpha electrons, beta electrons)
cfg.system.molecule = [system.Atom('H', (0, 0, -1)), system.Atom('H', (0, 0, 1))]

# Set training parameters
cfg.batch_size = 256
cfg.pretrain.iterations = 100

train.train(cfg)

At train.train(cfg), the code seems to try running on TPU by default. How can I change it to run on a single GPU instead?

INFO:absl:Starting the local TPU driver.
INFO:absl:Unable to initialize backend 'tpu_driver': Not found: Unable to find driver in registry given worker: local://
INFO:absl:Unable to initialize backend 'tpu': Invalid argument: TpuPlatform is not available.
INFO:absl:Starting QMC with 1 XLA devices

ghost commented 3 years ago

Seems I'm having a JAX installation issue.

jsspencer commented 3 years ago

This is a standard message. By default, JAX first attempts to run on TPU; if it can't find one (which the second and third lines show), it falls back to GPU and then CPU.

>>> import jax
>>> jax.local_devices()

will show which devices JAX is running on.
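For completeness, a common way to steer JAX onto a specific backend is the JAX_PLATFORM_NAME environment variable, set before the Python process starts. This is a sketch: train_h2.py is a placeholder name for a script containing the training code above, and it assumes a CUDA-enabled jaxlib build is installed.

```shell
# Force JAX onto the GPU backend; JAX will error at startup if no GPU is found.
export JAX_PLATFORM_NAME=gpu
# On a multi-GPU machine, restrict JAX to a single GPU (CUDA device 0).
export CUDA_VISIBLE_DEVICES=0
# Run the training script (placeholder name for the snippet in the question).
python train_h2.py
```

The same effect is available from inside Python via jax.config.update("jax_platform_name", "gpu"), as long as it is called before any JAX computation runs.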