jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

TPU not detected by jax in Kaggle notebook #5031

yurilla56 closed this issue 3 years ago

yurilla56 commented 3 years ago

Hello, could you please advise why the TPU is not detected by JAX in my Kaggle notebook? Below is my code:

!pip install --upgrade "https://storage.googleapis.com/jax-releases/tpu/jaxlib-0.1.55+tpu20200928-cp37-none-manylinux2010_x86_64.whl"
!pip install --upgrade jax
import jax
import os

jax.config.update("jax_xla_backend", "tpu_driver")
jax.config.update("jax_backend_target", os.environ["TPU_NAME"])

print(jax.devices())

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-11-e5d92d94d629> in <module>
      5 jax.config.update("jax_backend_target", os.environ["TPU_NAME"])
      6 
----> 7 print(jax.devices())

/opt/conda/lib/python3.7/site-packages/jax/lib/xla_bridge.py in devices(backend)
    222     List of Device subclasses.
    223   """
--> 224   return get_backend(backend).devices()
    225 
    226 

/opt/conda/lib/python3.7/site-packages/jax/lib/xla_bridge.py in get_backend(platform)
    170       msg = 'Unknown jax_xla_backend value "{}".'
    171       raise ValueError(msg.format(FLAGS.jax_xla_backend))
--> 172     return backend(platform)
    173 
    174 

/opt/conda/lib/python3.7/site-packages/jax/lib/xla_bridge.py in _get_tpu_driver_backend(***failed resolving arguments***)
    148       raise ValueError('When using TPU Driver as the backend, you must specify '
    149                        '--jax_backend_target=<hostname>:8470.')
--> 150     _tpu_backend = tpu_client.TpuBackend.create(worker=backend_target)
    151   return _tpu_backend
    152 

/opt/conda/lib/python3.7/site-packages/jaxlib/tpu_client.py in create(worker, force)
     57     else:
     58       # We do not cache for non-local backends.
---> 59       return _tpu_client.TpuClient.Get(worker)

RuntimeError: Unimplemented: Failed to connect to remote server at address: grpc://10.0.0.2:8470. Error from gRPC: . Details: 
yurilla56 commented 3 years ago

@skye Hello, please let me know if you need more data for troubleshooting.

jakevdp commented 3 years ago

I believe you first need to configure the TPU by sending a POST request with the desired TPU driver version.

Here is an example of how the TPU is configured in Colab; I believe Kaggle has similar requirements: https://github.com/google/jax/blob/master/cloud_tpu_colabs/JAX_demo.ipynb
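A minimal sketch of that request using only the Python standard library. It assumes, as the workaround posted later in this thread does, that TPU_NAME looks like grpc://<host>:8470 and that the TPU host accepts a POST to /requestversion/<version> on port 8475. The helper names (driver_version_url, request_tpu_driver_version) and the send parameter are hypothetical; send exists only so the logic can be exercised without a TPU.

```python
import os
import urllib.request
from urllib.parse import urlparse


def driver_version_url(tpu_name, version="tpu_driver_nightly", port=8475):
    """Build the driver-version URL from a TPU_NAME such as grpc://10.0.0.2:8470."""
    host = urlparse(tpu_name).hostname  # "grpc://10.0.0.2:8470" -> "10.0.0.2"
    if host is None:
        raise ValueError(f"cannot parse a host from TPU_NAME={tpu_name!r}")
    return f"http://{host}:{port}/requestversion/{version}"


def _http_post(url):
    # data=b"" with method="POST" sends an empty-body POST; urlopen raises
    # urllib.error.HTTPError for 4xx/5xx responses.
    request = urllib.request.Request(url, data=b"", method="POST")
    with urllib.request.urlopen(request, timeout=30) as response:
        return response.status


def request_tpu_driver_version(tpu_name=None, version="tpu_driver_nightly", send=_http_post):
    """POST the desired driver version to the TPU host and return the HTTP status."""
    tpu_name = tpu_name or os.environ["TPU_NAME"]
    return send(driver_version_url(tpu_name, version))
```

Once the request succeeds, the backend would still be selected with the two jax.config.update calls from the original question.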

yurilla56 commented 3 years ago

Hello @jakevdp, thank you for the feedback. How should I specify the desired TPU driver version?

jakevdp commented 3 years ago

You can follow the example in the notebook I linked to. I'm not sure what the relevant environment variables are in Kaggle, but that shouldn't be too hard to figure out.
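One way to see which TPU-related environment variables a Kaggle kernel exposes is to filter os.environ by name. This is a sketch with a hypothetical helper name; the only variable this thread confirms is TPU_NAME, which holds the grpc://<host>:8470 address shown in the error above.

```python
import os


def tpu_related_env(environ=None):
    """Return the environment variables whose names mention TPU."""
    environ = os.environ if environ is None else environ
    return {key: value for key, value in environ.items() if "TPU" in key.upper()}


# Print each matching variable on its own line, sorted by name.
for key, value in sorted(tpu_related_env().items()):
    print(f"{key}={value}")
```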

yurilla56 commented 3 years ago

@jakevdp Thank you! It looks like it works now.

8bitmp3 commented 3 years ago

@yurilla56 Could you share the workaround here? I'm also stuck on this 😃

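yurilla56 commented 3 years ago

import os

if 'TPU_NAME' in os.environ:
  import requests
  # Request a TPU driver version once per kernel session; the global
  # TPU_DRIVER_MODE records that the request has already been sent.
  if 'TPU_DRIVER_MODE' not in globals():
    # TPU_NAME looks like grpc://<host>:8470, so split(':')[1] is '//<host>' and
    # this builds http://<host>:8475/requestversion/tpu_driver_nightly.
    url = 'http:' + os.environ['TPU_NAME'].split(':')[1] + ':8475/requestversion/tpu_driver_nightly'
    resp = requests.post(url)
    resp.raise_for_status()  # fail early if the version request was rejected
    TPU_DRIVER_MODE = 1

  # Point JAX at the remote TPU through the TPU driver backend.
  from jax.config import config
  config.FLAGS.jax_xla_backend = "tpu_driver"
  config.FLAGS.jax_backend_target = os.environ['TPU_NAME']
  print('Registered TPU:', config.FLAGS.jax_backend_target)
else:
  # Message carried over from the Colab demo; in Kaggle the TPU is enabled in the notebook's accelerator settings.
  print('No TPU detected. Can be changed under "Runtime/Change runtime type".')
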
8bitmp3 commented 3 years ago

Thanks @yurilla56, I got the TPU at grpc://10.X.X.X:84XX

Hi @jakevdp @skye, running a basic JAX op now raises a RuntimeError; see the end of the output ⬇️:

import jax
import jax.numpy as jnp
from jax import random

key = random.PRNGKey(0)
x = random.normal(key, (5000, 5000))

Output:

---------------------------------------------------------------------------
FilteredStackTrace                        Traceback (most recent call last)
<ipython-input-10-ba5b12a21b0a> in <module>
      1 key = random.PRNGKey(0)
----> 2 x = random.normal(key, (5000, 5000))

/opt/conda/lib/python3.7/site-packages/jax/random.py in normal(key, shape, dtype)
    643   shape = abstract_arrays.canonicalize_shape(shape)
--> 644   return _normal(key, shape, dtype)  # type: ignore
    645 

FilteredStackTrace: RuntimeError: Failed precondition: Expected comparison type FLOAT or TOTALORDER.
actual: UNSIGNED
operand: f32[5000,5000]

The stack trace above excludes JAX-internal frames.
The following is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

RuntimeError                              Traceback (most recent call last)
<ipython-input-10-ba5b12a21b0a> in <module>
      1 key = random.PRNGKey(0)
----> 2 x = random.normal(key, (5000, 5000))

/opt/conda/lib/python3.7/site-packages/jax/random.py in normal(key, shape, dtype)
    642   dtype = dtypes.canonicalize_dtype(dtype)
    643   shape = abstract_arrays.canonicalize_shape(shape)
--> 644   return _normal(key, shape, dtype)  # type: ignore
    645 
    646 @partial(jit, static_argnums=(1, 2))

/opt/conda/lib/python3.7/site-packages/jax/_src/traceback_util.py in reraise_with_filtered_traceback(*args, **kwargs)
    131   def reraise_with_filtered_traceback(*args, **kwargs):
    132     try:
--> 133       return fun(*args, **kwargs)
    134     except Exception as e:
    135       if not is_under_reraiser(e):

/opt/conda/lib/python3.7/site-packages/jax/api.py in f_jitted(*args, **kwargs)
    222         backend=backend,
    223         name=flat_fun.__name__,
--> 224         donated_invars=donated_invars)
    225     return tree_unflatten(out_tree(), out)
    226 

/opt/conda/lib/python3.7/site-packages/jax/core.py in bind(self, fun, *args, **params)
   1187 
   1188   def bind(self, fun, *args, **params):
-> 1189     return call_bind(self, fun, *args, **params)
   1190 
   1191   def process(self, trace, fun, tracers, params):

/opt/conda/lib/python3.7/site-packages/jax/core.py in call_bind(primitive, fun, *args, **params)
   1178   tracers = map(top_trace.full_raise, args)
   1179   with maybe_new_sublevel(top_trace):
-> 1180     outs = primitive.process(top_trace, fun, tracers, params)
   1181   return map(full_lower, apply_todos(env_trace_todo(), outs))
   1182 

/opt/conda/lib/python3.7/site-packages/jax/core.py in process(self, trace, fun, tracers, params)
   1190 
   1191   def process(self, trace, fun, tracers, params):
-> 1192     return trace.process_call(self, fun, tracers, params)
   1193 
   1194   def post_process(self, trace, out_tracers, params):

/opt/conda/lib/python3.7/site-packages/jax/core.py in process_call(self, primitive, f, tracers, params)
    581 
    582   def process_call(self, primitive, f, tracers, params):
--> 583     return primitive.impl(f, *tracers, **params)
    584   process_map = process_call
    585 

/opt/conda/lib/python3.7/site-packages/jax/interpreters/xla.py in _xla_call_impl(fun, device, backend, name, donated_invars, *args)
    559 def _xla_call_impl(fun: lu.WrappedFun, *args, device, backend, name, donated_invars):
    560   compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,
--> 561                                *unsafe_map(arg_spec, args))
    562   try:
    563     return compiled_fun(*args)

/opt/conda/lib/python3.7/site-packages/jax/linear_util.py in memoized_fun(fun, *args)
    249       fun.populate_stores(stores)
    250     else:
--> 251       ans = call(fun, *args)
    252       cache[key] = (ans, fun.stores)
    253 

/opt/conda/lib/python3.7/site-packages/jax/interpreters/xla.py in _xla_callable(fun, device, backend, name, donated_invars, *arg_specs)
    705       device_assignment=(device.id,) if device else None)
    706   options.parameter_is_tupled_arguments = tuple_args
--> 707   compiled = backend_compile(backend, built, options)
    708   if nreps == 1:
    709     return partial(_execute_compiled, compiled, out_avals, result_handlers)

/opt/conda/lib/python3.7/site-packages/jax/interpreters/xla.py in backend_compile(backend, built_c, options)
    344   # we use a separate function call to ensure that XLA compilation appears
    345   # separately in Python profiling results
--> 346   return backend.compile(built_c, compile_options=options)
    347 
    348 def _execute_compiled_primitive(prim, compiled, result_handler, *args):

RuntimeError: Failed precondition: Expected comparison type FLOAT or TOTALORDER.
actual: UNSIGNED
operand: f32[5000,5000]
8bitmp3 commented 3 years ago

@jakevdp @skye Update: I reset the runtime and installed the stable jaxlib from PyPI (pip install jaxlib, which fetched jaxlib-0.1.57-cp37-none-manylinux2010_x86_64.whl) instead of jaxlib-0.1.55+tpu20200928-cp37-none-manylinux2010_x86_64.whl.

Is jaxlib-0.1.57 TPU-compatible? With it, jnp.dot(x, x).block_until_ready() now runs without the RuntimeError.
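When comparing setups like the two above, it helps to record which jax and jaxlib distributions the kernel actually has installed. A small sketch with a hypothetical helper name, using importlib.metadata from the standard library (Python 3.8+; the notebook in this thread ran Python 3.7, where the importlib_metadata backport would be needed instead):

```python
from importlib import metadata


def installed_versions(packages=("jax", "jaxlib")):
    """Return {package: version string, or None if it is not installed}."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None  # distribution not installed in this environment
    return versions


print(installed_versions())
```

This reports what pip installed, not which backend JAX selected at runtime; jax.devices(), as used in the original question, remains the check for that.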