kingoflolz / mesh-transformer-jax

Model parallel transformers in JAX and Haiku
Apache License 2.0
6.27k stars · 890 forks

No TPU found, falling back to CPU #201

Closed 0x7o closed 2 years ago

0x7o commented 2 years ago

I'm trying to convert the model with the "to_hf_weights.py" script. I need to load the model on a TPU, so I'm using Google Colab.

The script emits a warning: WARNING:absl:No GPU/TPU found, falling back to CPU.

At the same time, JAX reports that the TPU devices are visible:

import jax
jax.devices()

[TpuDevice(id=0, host_id=0, coords=(0,0,0), core_on_chip=0), TpuDevice(id=1, host_id=0, coords=(0,0,0), core_on_chip=1), TpuDevice(id=2, host_id=0, coords=(1,0,0), core_on_chip=0), TpuDevice(id=3, host_id=0, coords=(1,0,0), core_on_chip=1), TpuDevice(id=4, host_id=0, coords=(0,1,0), core_on_chip=0), TpuDevice(id=5, host_id=0, coords=(0,1,0), core_on_chip=1), TpuDevice(id=6, host_id=0, coords=(1,1,0), core_on_chip=0), TpuDevice(id=7, host_id=0, coords=(1,1,0), core_on_chip=1)]
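As a side note (not from the thread): enumerating devices is not the same thing as having initialized the accelerator backend that the absl warning refers to. A minimal check, assuming a JAX version that exposes `jax.default_backend()`, is to ask which backend JAX actually selected for compilation:

```python
import jax

# jax.devices() lists the hardware JAX can see, but the absl warning
# ("No GPU/TPU found, falling back to CPU") is about the backend JAX
# actually initialized. jax.default_backend() reports that backend.
print(jax.devices())          # e.g. [TpuDevice(...), ...] on a TPU runtime
print(jax.default_backend())  # "tpu" if the TPU backend was initialized,
                              # "cpu" if JAX silently fell back to CPU
```

If `default_backend()` returns "cpu" while TpuDevice entries appear in `jax.devices()`, the runtime found the TPU but did not initialize it for compilation, which matches the symptom described in this issue.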

kingoflolz commented 2 years ago

to_hf_weights.py has not been tested on Colab; it is designed for use on TPU VMs, and hence does not initialize the TPUs properly on Colab.