YannickWehr closed this issue 4 years ago.
Can you provide your `jax` and `jaxlib` versions, and confirm that you're using `tpu_driver_nightly` as your TPU version? At first glance this looks like a version mismatch.
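A minimal sketch for collecting the requested versions. It uses only the standard library (`importlib.metadata`, Python 3.8+), so it still works if `import jax` itself fails. It also reports the interpreter version, because pip selects wheels that match the running Python. `collect_versions` is a hypothetical helper. It does not report which TPU driver is in use; that still has to be confirmed from the tutorial setup.

```python
import platform
from importlib.metadata import PackageNotFoundError, version


def collect_versions(packages=("jax", "jaxlib")):
    """Return the installed versions of `packages` plus the Python version.

    Hypothetical helper for bug reports: a package that is not installed
    maps to None instead of raising an exception.
    """
    info = {"python": platform.python_version()}
    for name in packages:
        try:
            info[name] = version(name)
        except PackageNotFoundError:
            info[name] = None
    return info


if __name__ == "__main__":
    # Print one "name: version" line per entry for pasting into the issue.
    for name, ver in collect_versions().items():
        print(f"{name}: {ver}")
```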
When re-checking the versions, I noticed I had been using the pre-installed Python, which is 3.5. Upgrading to Python 3.7 solved the issue. I think it would make sense to update the tutorial to use an image that ships Python 3.7 by default; the image it currently uses installs TF 1.14:
```
export ZONE=us-central1-c
gcloud compute instances create $USER-user-vm-0001 \
  --machine-type=n1-standard-1 \
  --image-project=ml-images \
  --image-family=tf-1-14 \
  --boot-disk-size=200GB \
  --scopes=cloud-platform \
  --zone=$ZONE
```
For example, setting `--image-family=tf2-2` gives an image that uses Python 3.7 by default.
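Because the fix here was switching interpreters, a fail-fast check at the top of a script can surface the problem before any TPU call is made. This is a sketch under stated assumptions. The `(3, 7)` default reflects only what worked in this thread (3.5 failed, 3.7 worked) and is not a documented JAX minimum. `require_python` is a hypothetical helper. It uses `str.format` rather than f-strings so that an old interpreter such as 3.5 can still parse it and report the error.

```python
import sys


def require_python(minimum=(3, 7), current=None):
    """Raise RuntimeError if the interpreter is older than `minimum`.

    The (3, 7) default is an assumption taken from this thread, not an
    official requirement. `current` exists for testing and defaults to
    the running interpreter's (major, minor) version.
    """
    if current is None:
        current = tuple(sys.version_info[:2])
    else:
        current = tuple(current)[:2]
    if current < tuple(minimum):
        have = ".".join(map(str, current))
        need = ".".join(map(str, minimum))
        # Include the executable path: a pre-installed system Python is
        # easy to pick up by accident.
        raise RuntimeError(
            "Python {}+ expected, but running {} ({})".format(need, have, sys.executable)
        )
    return current
```

Calling `require_python()` before `import jax` makes a wrong interpreter fail with a readable message instead of an obscure error deeper in the stack.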
Hi, I tried to get JAX running on a Cloud TPU as described here: https://github.com/google/jax/tree/master/cloud_tpu_colabs#running-jax-on-a-cloud-tpu-from-a-gce-vm
But I received the following error after running `key = random.PRNGKey(0)`:

```
Traceback (most recent call last):
  File "", line 1, in <module>
  File "/home/yanmarka_yannick/.local/lib/python3.5/site-packages/jax/random.py", line 65, in PRNGKey
    k1 = convert(onp.bitwise_and(onp.right_shift(seed, 32), 0xFFFFFFFF))
  File "/home/yanmarka_yannick/.local/lib/python3.5/site-packages/jax/random.py", line 61, in <lambda>
    convert = lambda k: lax.reshape(lax.convert_element_type(k, onp.uint32), [1])
  File "/home/yanmarka_yannick/.local/lib/python3.5/site-packages/jax/lax/lax.py", line 380, in convert_element_type
    operand, new_dtype=new_dtype, old_dtype=old_dtype)
  File "/home/yanmarka_yannick/.local/lib/python3.5/site-packages/jax/core.py", line 196, in bind
    return self.impl(*args, **kwargs)
  File "/home/yanmarka_yannick/.local/lib/python3.5/site-packages/jax/interpreters/xla.py", line 166, in apply_primitive
    return compiled_fun(*args)
  File "/home/yanmarka_yannick/.local/lib/python3.5/site-packages/jax/interpreters/xla.py", line 252, in _execute_compiled_primitive
    out_bufs = compiled.Execute(input_bufs, tuple_arguments=tuple_args)
TypeError: Execute(): incompatible function arguments. The following argument types are supported:

Invoked with: <jaxlib.tpu_client_extension.TpuExecutable object at 0x7fa146825928>, [<jaxlib.tpu_client_extension.PyTpuBuffer object at 0x7fa146d181b8>]; kwargs: tuple_arguments=False
```
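For context on where this fails: the frames show `PRNGKey` splitting the integer seed into 32-bit words on the host, then casting each word to `uint32` with `lax.convert_element_type`. That cast is dispatched as a compiled computation, and the `TypeError` is raised at the resulting `compiled.Execute(...)` call into `jaxlib`, before the key is even assembled. The pure-Python sketch below mirrors only the host-side arithmetic. The high-word expression comes from `random.py` line 65 in the traceback. The low-word mask `seed & 0xFFFFFFFF` is an assumption about the same file and does not appear above. `split_seed` is a hypothetical name, not a JAX function.

```python
def split_seed(seed):
    """Split an integer seed into (high, low) unsigned 32-bit words.

    high mirrors the traceback: (seed >> 32) & 0xFFFFFFFF
    low is the assumed companion mask: seed & 0xFFFFFFFF
    Python ints are arbitrary precision, so the masks keep each word
    in the uint32 range [0, 2**32).
    """
    high = (seed >> 32) & 0xFFFFFFFF
    low = seed & 0xFFFFFFFF
    return high, low


print(split_seed(0))          # (0, 0)
print(split_seed(2**32 + 5))  # (1, 5)
```

This arithmetic is ordinary host code. The failure in the traceback comes from the Python/`jaxlib` boundary (`Execute()` rejecting the `tuple_arguments` keyword), not from the seed handling.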