When I run train.py, I get these warnings:
I0808 20:59:00.842178 139857030612800 xla_bridge.py:328] Unable to initialize backend 'tpu_driver': NOT_FOUND: Unable to find driver in registry given worker:
I0808 20:59:00.963761 139857030612800 xla_bridge.py:328] Unable to initialize backend 'rocm': NOT_FOUND: Could not find registered platform with name: "rocm". Available platform names are: CUDA Host Interpreter
I0808 20:59:00.964296 139857030612800 xla_bridge.py:328] Unable to initialize backend 'tpu': module 'jaxlib.xla_extension' has no attribute 'get_tpu_client'
Is there a config setting I need to change somewhere to tell JAX that I want to use a GPU and not a TPU? Or will it fall back to the GPU automatically, so that I can disregard these warnings?
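One quick way to see which backend JAX actually selected is to ask it directly. This is a minimal sketch, assuming a JAX version that exposes `jax.default_backend()` and `jax.devices()`. The helper name `describe_backend` is hypothetical and only used for illustration. Note that the log lines above carry the `I` prefix, which marks INFO-level messages in the absl/glog log format.

```python
import jax


def describe_backend():
    """Hypothetical helper: report the default JAX backend and its devices."""
    # Platform name of the backend JAX uses by default, e.g. "gpu" or "cpu".
    backend = jax.default_backend()
    # Devices belonging to that default backend.
    devices = jax.devices()
    return backend, devices


if __name__ == "__main__":
    backend, devices = describe_backend()
    print("default backend:", backend)
    for device in devices:
        print("  ", device)
```

If this reports `gpu` and lists the CUDA devices, JAX has picked up the GPU despite the messages about `tpu_driver`, `rocm` and `tpu`. Forcing a specific platform is usually done with an environment variable set before JAX is imported, for example `JAX_PLATFORMS=cuda` in newer releases or `JAX_PLATFORM_NAME=gpu` in older ones; which variable applies depends on the installed version, so check it against the JAX configuration docs.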