jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

TPU Error #4206

Closed: YannickWehr closed this issue 4 years ago

YannickWehr commented 4 years ago

Hi, I tried to get JAX running on a Cloud TPU as described here: https://github.com/google/jax/tree/master/cloud_tpu_colabs#running-jax-on-a-cloud-tpu-from-a-gce-vm

But I received the following error after running key = random.PRNGKey(0):

    Traceback (most recent call last):
      File "", line 1, in
      File "/home/yanmarka_yannick/.local/lib/python3.5/site-packages/jax/random.py", line 65, in PRNGKey
        k1 = convert(onp.bitwise_and(onp.right_shift(seed, 32), 0xFFFFFFFF))
      File "/home/yanmarka_yannick/.local/lib/python3.5/site-packages/jax/random.py", line 61, in
        convert = lambda k: lax.reshape(lax.convert_element_type(k, onp.uint32), [1])
      File "/home/yanmarka_yannick/.local/lib/python3.5/site-packages/jax/lax/lax.py", line 380, in convert_element_type
        operand, new_dtype=new_dtype, old_dtype=old_dtype)
      File "/home/yanmarka_yannick/.local/lib/python3.5/site-packages/jax/core.py", line 196, in bind
        return self.impl(*args, **kwargs)
      File "/home/yanmarka_yannick/.local/lib/python3.5/site-packages/jax/interpreters/xla.py", line 166, in apply_primitive
        return compiled_fun(*args)
      File "/home/yanmarka_yannick/.local/lib/python3.5/site-packages/jax/interpreters/xla.py", line 252, in _execute_compiled_primitive
        out_bufs = compiled.Execute(input_bufs, tuple_arguments=tuple_args)
    TypeError: Execute(): incompatible function arguments. The following argument types are supported:
        1. (self: jaxlib.tpu_client_extension.TpuExecutable, arguments: Span[jaxlib.tpu_client_extension.PyTpuBuffer]) -> StatusOr[jaxlib.tpu_client_extension.PyTpuBuffer]

    Invoked with: <jaxlib.tpu_client_extension.TpuExecutable object at 0x7fa146825928>, [<jaxlib.tpu_client_extension.PyTpuBuffer object at 0x7fa146d181b8>]; kwargs: tuple_arguments=False
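The traceback shows that PRNGKey fails on its very first device operation: converting the high 32 bits of the seed to uint32 via lax.convert_element_type. The host-only NumPy sketch below reproduces that seed split to show how small the failing step is, which points at the compiled.Execute call into the TPU client rather than at the PRNG logic. Assumptions: the helper name split_seed is hypothetical, and the low-word line (masking the seed with 0xFFFFFFFF) is inferred; only the high-word line appears in the traceback.

```python
import numpy as np


def split_seed(seed: int) -> np.ndarray:
    """Hypothetical host-side sketch of splitting a 64-bit seed into two uint32 words."""
    seed64 = np.int64(seed)
    # High word, mirroring the traceback line:
    # onp.bitwise_and(onp.right_shift(seed, 32), 0xFFFFFFFF)
    high = np.bitwise_and(np.right_shift(seed64, 32), 0xFFFFFFFF)
    # Low word (assumption: computed with the same mask, no shift)
    low = np.bitwise_and(seed64, 0xFFFFFFFF)
    # Both words fit in 32 bits, so the cast to uint32 is lossless
    return np.array([high, low], dtype=np.uint32)


print(split_seed(0))
print(split_seed(2**32 + 7))
```

Nothing here touches an accelerator, so if this arithmetic works but the same conversion fails on the TPU backend, a jax/jaxlib/runtime mismatch is the more likely culprit.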

jekbradbury commented 4 years ago

Can you provide your jax and jaxlib versions, and confirm that you're using tpu_driver_nightly as your TPU version? At first glance this looks like a version mismatch.
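To answer this kind of request, a short script can collect the interpreter, jax, and jaxlib versions in one go. This is a sketch, not an official JAX tool: report_versions is a hypothetical helper, and importlib.metadata needs Python 3.8 or newer. On an older interpreter such as the 3.5 in this thread, pip show jax jaxlib reports the package versions instead.

```python
import sys
from importlib import metadata  # available in Python 3.8+


def report_versions():
    """Hypothetical helper: gather the versions usually asked for in a bug report."""
    versions = {"python": "{}.{}.{}".format(*sys.version_info[:3])}
    for dist in ("jax", "jaxlib"):
        try:
            versions[dist] = metadata.version(dist)
        except metadata.PackageNotFoundError:
            versions[dist] = "not installed"
    return versions


if __name__ == "__main__":
    for name, version in report_versions().items():
        print("{}: {}".format(name, version))
```

Reading installed distribution metadata avoids importing jax itself, so the report works even when the installation is broken.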

YannickWehr commented 4 years ago

When re-checking the versions, I noticed I was using the pre-installed Python, which is 3.5. Upgrading to Python 3.7 solved the issue. I think it would make sense to update the tutorial to use an image whose default Python is 3.7; the current command uses the tf-1-14 image family (TensorFlow 1.14):

    export ZONE=us-central1-c
    gcloud compute instances create $USER-user-vm-0001 \
      --machine-type=n1-standard-1 \
      --image-project=ml-images \
      --image-family=tf-1-14 \
      --boot-disk-size=200GB \
      --scopes=cloud-platform \
      --zone=$ZONE

Setting --image-family=tf2-2, for example, gives an image whose default Python is 3.7.
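A small guard, run with the interpreter you plan to use on the VM before installing jax and jaxlib, would have surfaced the Python 3.5 default early. This is a sketch under an explicit assumption: REQUIRED = (3, 7) is taken from the version that fixed the issue in this thread, not from a documented JAX minimum, and check_python is a hypothetical helper. It avoids f-strings so it still runs on 3.5.

```python
import sys

# Assumption: 3.7 is the version that resolved this issue for the reporter;
# it is not necessarily the minimum JAX supported at the time.
REQUIRED = (3, 7)


def check_python(version_info=sys.version_info, required=REQUIRED):
    """Return None if the interpreter is new enough, else an explanatory message."""
    if tuple(version_info[:2]) >= tuple(required):
        return None
    return ("Python {}.{} is older than the expected {}.{}; consider a VM image "
            "family whose default python3 is newer (e.g. tf2-2 instead of "
            "tf-1-14).").format(version_info[0], version_info[1], *required)


if __name__ == "__main__":
    problem = check_python()
    if problem is not None:
        sys.exit(problem)
    print("Python version OK")
```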