Closed: yurilla56 closed this issue 3 years ago
@skye Hello, please let me know if you need more data for troubleshooting.
I believe you need to first configure the TPU by sending a POST request with the desired tpu driver version.
Here is an example of how the TPU is configured in Colab; I believe Kaggle has similar requirements: https://github.com/google/jax/blob/master/cloud_tpu_colabs/JAX_demo.ipynb
Hello @jakevdp, thank you for the feedback. How should I define the desired TPU driver version?
You can follow the example in the notebook I linked to. I'm not sure what the relevant environment variables are in Kaggle, but that shouldn't be too hard to figure out.
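For anyone trying to work out which variables Kaggle sets, one option is to list every environment variable whose name mentions TPU. This is only a minimal sketch: the helper name `find_tpu_env_vars` is made up for illustration, and whether Kaggle exposes `TPU_NAME` at all is an assumption to check against the output.

```python
import os


def find_tpu_env_vars(environ=None):
    """Return a dict of environment variables whose names mention 'TPU'."""
    # Hypothetical helper: defaults to the real environment, but accepts any
    # mapping so it can be tried out on sample data.
    env = os.environ if environ is None else environ
    return {name: value for name, value in env.items() if "TPU" in name.upper()}


if __name__ == "__main__":
    for name, value in sorted(find_tpu_env_vars().items()):
        print(f"{name}={value}")
```

If nothing is printed, the notebook is probably not attached to a TPU accelerator, or the variables use a name that does not contain "TPU".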
@jakevdp Thank you! It looks like it works now.
@yurilla56 Could you share the workaround here? I'm also stuck on this 😃
import os
if 'TPU_NAME' in os.environ:
    import requests
    if 'TPU_DRIVER_MODE' not in globals():
        url = 'http:' + os.environ['TPU_NAME'].split(':')[1] + ':8475/requestversion/tpu_driver_nightly'
        resp = requests.post(url)
        TPU_DRIVER_MODE = 1
    from jax.config import config
    config.FLAGS.jax_xla_backend = "tpu_driver"
    config.FLAGS.jax_backend_target = os.environ['TPU_NAME']
    print('Registered TPU:', config.FLAGS.jax_backend_target)
else:
    print('No TPU detected. Can be changed under "Runtime/Change runtime type".')
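The URL construction in the snippet above is terse, so here is a small sketch of what it does: it takes the host out of a `grpc://host:port` address and targets port 8475 with the `requestversion` path. The function name `driver_request_url` and the sample address are hypothetical; the port and path are copied from the snippet above, not from any Kaggle documentation.

```python
def driver_request_url(tpu_name, version="tpu_driver_nightly"):
    """Build the driver-version request URL from a 'grpc://host:port' address."""
    # 'grpc://10.0.0.2:8470'.split(':') -> ['grpc', '//10.0.0.2', '8470']
    host_part = tpu_name.split(":")[1]  # keeps the leading '//'
    return "http:" + host_part + ":8475/requestversion/" + version


# Sample address made up for illustration only.
print(driver_request_url("grpc://10.0.0.2:8470"))
# http://10.0.0.2:8475/requestversion/tpu_driver_nightly
```

Printing the URL before posting to it is a cheap way to confirm that `TPU_NAME` has the expected `grpc://host:port` shape.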
Thanks @yurilla56, I got the TPU at grpc://10.X.X.X::84XX
Hi @jakevdp @skye, something goes wrong after running a basic JAX op: it raises a RuntimeError (check the end of the output ⬇️):
import jax
import jax.numpy as jnp
from jax import random
key = random.PRNGKey(0)
x = random.normal(key, (5000, 5000))
Output:
---------------------------------------------------------------------------
FilteredStackTrace Traceback (most recent call last)
<ipython-input-10-ba5b12a21b0a> in <module>
1 key = random.PRNGKey(0)
----> 2 x = random.normal(key, (5000, 5000))
/opt/conda/lib/python3.7/site-packages/jax/random.py in normal(key, shape, dtype)
643 shape = abstract_arrays.canonicalize_shape(shape)
--> 644 return _normal(key, shape, dtype) # type: ignore
645
FilteredStackTrace: RuntimeError: Failed precondition: Expected comparison type FLOAT or TOTALORDER.
actual: UNSIGNED
operand: f32[5000,5000]
The stack trace above excludes JAX-internal frames.
The following is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
RuntimeError Traceback (most recent call last)
<ipython-input-10-ba5b12a21b0a> in <module>
1 key = random.PRNGKey(0)
----> 2 x = random.normal(key, (5000, 5000))
/opt/conda/lib/python3.7/site-packages/jax/random.py in normal(key, shape, dtype)
642 dtype = dtypes.canonicalize_dtype(dtype)
643 shape = abstract_arrays.canonicalize_shape(shape)
--> 644 return _normal(key, shape, dtype) # type: ignore
645
646 @partial(jit, static_argnums=(1, 2))
/opt/conda/lib/python3.7/site-packages/jax/_src/traceback_util.py in reraise_with_filtered_traceback(*args, **kwargs)
131 def reraise_with_filtered_traceback(*args, **kwargs):
132 try:
--> 133 return fun(*args, **kwargs)
134 except Exception as e:
135 if not is_under_reraiser(e):
/opt/conda/lib/python3.7/site-packages/jax/api.py in f_jitted(*args, **kwargs)
222 backend=backend,
223 name=flat_fun.__name__,
--> 224 donated_invars=donated_invars)
225 return tree_unflatten(out_tree(), out)
226
/opt/conda/lib/python3.7/site-packages/jax/core.py in bind(self, fun, *args, **params)
1187
1188 def bind(self, fun, *args, **params):
-> 1189 return call_bind(self, fun, *args, **params)
1190
1191 def process(self, trace, fun, tracers, params):
/opt/conda/lib/python3.7/site-packages/jax/core.py in call_bind(primitive, fun, *args, **params)
1178 tracers = map(top_trace.full_raise, args)
1179 with maybe_new_sublevel(top_trace):
-> 1180 outs = primitive.process(top_trace, fun, tracers, params)
1181 return map(full_lower, apply_todos(env_trace_todo(), outs))
1182
/opt/conda/lib/python3.7/site-packages/jax/core.py in process(self, trace, fun, tracers, params)
1190
1191 def process(self, trace, fun, tracers, params):
-> 1192 return trace.process_call(self, fun, tracers, params)
1193
1194 def post_process(self, trace, out_tracers, params):
/opt/conda/lib/python3.7/site-packages/jax/core.py in process_call(self, primitive, f, tracers, params)
581
582 def process_call(self, primitive, f, tracers, params):
--> 583 return primitive.impl(f, *tracers, **params)
584 process_map = process_call
585
/opt/conda/lib/python3.7/site-packages/jax/interpreters/xla.py in _xla_call_impl(fun, device, backend, name, donated_invars, *args)
559 def _xla_call_impl(fun: lu.WrappedFun, *args, device, backend, name, donated_invars):
560 compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,
--> 561 *unsafe_map(arg_spec, args))
562 try:
563 return compiled_fun(*args)
/opt/conda/lib/python3.7/site-packages/jax/linear_util.py in memoized_fun(fun, *args)
249 fun.populate_stores(stores)
250 else:
--> 251 ans = call(fun, *args)
252 cache[key] = (ans, fun.stores)
253
/opt/conda/lib/python3.7/site-packages/jax/interpreters/xla.py in _xla_callable(fun, device, backend, name, donated_invars, *arg_specs)
705 device_assignment=(device.id,) if device else None)
706 options.parameter_is_tupled_arguments = tuple_args
--> 707 compiled = backend_compile(backend, built, options)
708 if nreps == 1:
709 return partial(_execute_compiled, compiled, out_avals, result_handlers)
/opt/conda/lib/python3.7/site-packages/jax/interpreters/xla.py in backend_compile(backend, built_c, options)
344 # we use a separate function call to ensure that XLA compilation appears
345 # separately in Python profiling results
--> 346 return backend.compile(built_c, compile_options=options)
347
348 def _execute_compiled_primitive(prim, compiled, result_handler, *args):
RuntimeError: Failed precondition: Expected comparison type FLOAT or TOTALORDER.
actual: UNSIGNED
operand: f32[5000,5000]
@jakevdp @skye Update: I've reset the runtime and installed the stable jaxlib from PyPI (pip install jaxlib: jaxlib-0.1.57-cp37-none-manylinux2010_x86_64.whl) instead of jaxlib-0.1.55+tpu20200928-cp37-none-manylinux2010_x86_64.whl.
Is jaxlib-0.1.57 TPU-compatible? jnp.dot(x, x).block_until_ready() works OK now, with no RuntimeError.
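When comparing wheels like this, it can help to print exactly which versions ended up installed after the runtime reset. A minimal sketch, assuming Python 3.8+ for `importlib.metadata` (the notebook in the traceback is on 3.7, where the `importlib_metadata` backport would be needed instead); the helper name `installed_versions` is made up here.

```python
from importlib import metadata


def installed_versions(packages=("jax", "jaxlib")):
    """Map each package name to its installed version, or None if missing."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            # Package is not installed in this environment.
            versions[name] = None
    return versions


print(installed_versions())
```

This only reports what is installed; it does not say whether a given jaxlib build is compatible with the TPU driver.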
Hello, could you please advise why the TPU is not detected by JAX in a Kaggle notebook? Below is my code.