NightMachinery closed this issue 4 months ago.
This bug does not occur on Colab when using the latest versions (jax 0.2.26, jaxlib 0.1.75). Feel free to close the issue, though I think Colab's JAX version should be upgraded.
The latest JAX versions don't seem to work with Colab's TPU though:
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-25-05dddd30ebb4> in <module>()
----> 1 books_bytes = ([jnp.array(b['bytes']) for b in books.values()])
13 frames
/usr/local/lib/python3.7/dist-packages/jax/_src/lib/xla_bridge.py in _get_backend_uncached(platform)
285 if backend is None:
286 if platform in _backends_errors:
--> 287 raise RuntimeError(f"Backend '{platform}' failed to initialize: "
288 f"{_backends_errors[platform]}")
289 raise RuntimeError(f"Unknown backend {platform}")
RuntimeError: Backend 'tpu_driver' failed to initialize: NOT_FOUND: Unable to find driver in registry given worker:
The problem occurs with the latest versions on Colab's TPUv2.
```
!pip install "jax[tpu]>=0.2.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

import jax.numpy as jnp
import jax
import jaxlib
print(jax.version.__version__)
print(jaxlib.version.__version__)
from jax.experimental.host_callback import id_print

@jax.jit
def t1():
    id_print(jnp.arange(4))

t1()
```
Output:
0.2.26
0.1.75
---------------------------------------------------------------------------
JaxStackTraceBeforeTransformation Traceback (most recent call last)
/usr/lib/python3.7/runpy.py in _run_module_as_main(***failed resolving arguments***)
192 return _run_code(code, main_globals, None,
--> 193 "__main__", mod_spec)
194
44 frames
JaxStackTraceBeforeTransformation: AttributeError: 'NoneType' object has no attribute 'add_outfeed'
The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.
--------------------
The above exception was the direct cause of the following exception:
UnfilteredStackTrace Traceback (most recent call last)
UnfilteredStackTrace: AttributeError: 'NoneType' object has no attribute 'add_outfeed'
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
AttributeError Traceback (most recent call last)
/usr/local/lib/python3.7/dist-packages/jax/experimental/host_callback.py in _outside_call_translation_rule(ctx, avals_in, avals_out, has_token, identity, flat_results_aval, *args_op, **params)
922 flat_results_aval=flat_results_aval,
923 **params))
--> 924 next_token = _callback_handler_data.receiver.add_outfeed(comp, current_token,
925 callback_id,
926 args_to_outfeed)
AttributeError: 'NoneType' object has no attribute 'add_outfeed'
I couldn't reproduce the error. Here's my Colab: https://colab.research.google.com/drive/1I61dvqk0TbRyQwFdc0CceW9ll9M7kklV?usp=sharing
I ran this in the top cell to set up the Colab TPU:
```
import jax.tools.colab_tpu
jax.tools.colab_tpu.setup_tpu()
```
@zhangqiaorjc I ran your Colab notebook, and I got the error.
https://colab.research.google.com/gist/batbone/4231c11e5b31269d25bb2040508ef9da/untitled0.ipynb
I also noticed that in your successful run, you seemed to have used the CPU:
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
@NightMachinery In my run, I used TPUs. See:
jax.devices() [TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1), TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0), TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1), TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0), TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1), TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0), TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]
I can't reproduce your error. It might be an external Colab issue.
Maybe @jakevdp and @skye know why the exact same code behaves differently for you.
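The absl warning quoted above suggests that one run may have silently fallen back to CPU. A quick way to check this is to ask JAX which backend it initialized. The sketch below is minimal and assumes a JAX version that provides `jax.default_backend()`. It relies on `jax.devices("tpu")` raising `RuntimeError` when no TPU backend is available.

```python
import jax

# Platform JAX chose as its default backend: "cpu", "gpu", or "tpu".
backend = jax.default_backend()
devices = jax.devices()  # devices belonging to the default backend
print(f"default backend: {backend}, {len(devices)} device(s)")

# Ask for TPUs explicitly; JAX raises RuntimeError if no TPU backend initialized.
try:
    tpu_devices = jax.devices("tpu")
except RuntimeError:
    tpu_devices = []

if not tpu_devices:
    print("No TPU backend available; computations run on", backend)
```

If this prints `cpu` in a notebook that is supposed to be on a TPU runtime, a successful run there says nothing about the TPU behaviour.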
@zhangqiaorjc @jakevdp @skye Any updates? I ran that notebook with another Gmail account of mine, and the error was the same. Without id_print, debugging JAX is close to impossible.
A note for any fellow NumPyro users who find this issue by googling: this error arises when running NumPyro in Colab with parallelized MCMC chains (i.e. num_chains > 1 and parallelization enabled).
An apparent workaround is to set progress_bar=False when instantiating the MCMC object, since NumPyro's MCMC runners seem to call id_print only to display parallelized progress bars.
I'm getting this error when using host_callback.call in Colab. It works with the Colab CPU or GPU runtimes, and elsewhere outside of Colab, so I think it's a problem specific to Colab TPUs. This is kind of annoying, since Colab is nice for testing or even training TPU models, but call is an important part of my program. To reproduce:
```
import jax.numpy as jnp
from jax.experimental.host_callback import call

def ptest(x):
    print(x)
    #print(device)
    #print(device.id)
    return x + 1

def test(x):
    return call(ptest, x, result_shape=x, call_with_device=False)

x = jnp.array([1,2,3], dtype=jnp.float32)
test(x)
```
Full stack trace:
---------------------------------------------------------------------------
JaxStackTraceBeforeTransformation Traceback (most recent call last)
/usr/lib/python3.7/runpy.py in _run_module_as_main(***failed resolving arguments***)
192 return _run_code(code, main_globals, None,
--> 193 "__main__", mod_spec)
194
51 frames
/usr/lib/python3.7/runpy.py in _run_code(***failed resolving arguments***)
84 __spec__ = mod_spec)
---> 85 exec(code, run_globals)
86 return run_globals
/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py in <module>()
15 from ipykernel import kernelapp as app
---> 16 app.launch_new_instance()
/usr/local/lib/python3.7/dist-packages/traitlets/config/application.py in launch_instance(***failed resolving arguments***)
845 app.initialize(argv)
--> 846 app.start()
847
/usr/local/lib/python3.7/dist-packages/ipykernel/kernelapp.py in start(***failed resolving arguments***)
498 try:
--> 499 self.io_loop.start()
500 except KeyboardInterrupt:
/usr/local/lib/python3.7/dist-packages/tornado/platform/asyncio.py in start(***failed resolving arguments***)
131 asyncio.set_event_loop(self.asyncio_loop)
--> 132 self.asyncio_loop.run_forever()
133 finally:
/usr/lib/python3.7/asyncio/base_events.py in run_forever(***failed resolving arguments***)
540 while True:
--> 541 self._run_once()
542 if self._stopping:
/usr/lib/python3.7/asyncio/base_events.py in _run_once(***failed resolving arguments***)
1785 else:
-> 1786 handle._run()
1787 handle = None # Needed to break cycles when an exception occurs.
/usr/lib/python3.7/asyncio/events.py in _run(***failed resolving arguments***)
87 try:
---> 88 self._context.run(self._callback, *self._args)
89 except Exception as exc:
/usr/local/lib/python3.7/dist-packages/tornado/platform/asyncio.py in _handle_events(***failed resolving arguments***)
121 fileobj, handler_func = self.handlers[fd]
--> 122 handler_func(fileobj, events)
123
/usr/local/lib/python3.7/dist-packages/tornado/stack_context.py in null_wrapper(***failed resolving arguments***)
299 _state.contexts = cap_contexts[0]
--> 300 return fn(*args, **kwargs)
301 finally:
/usr/local/lib/python3.7/dist-packages/zmq/eventloop/zmqstream.py in _handle_events(***failed resolving arguments***)
451 if zmq_events & zmq.POLLIN and self.receiving():
--> 452 self._handle_recv()
453 if not self.socket:
/usr/local/lib/python3.7/dist-packages/zmq/eventloop/zmqstream.py in _handle_recv(***failed resolving arguments***)
480 callback = self._recv_callback
--> 481 self._run_callback(callback, msg)
482
/usr/local/lib/python3.7/dist-packages/zmq/eventloop/zmqstream.py in _run_callback(***failed resolving arguments***)
430 # inside our blanket exception handler rather than outside.
--> 431 callback(*args, **kwargs)
432 except Exception:
/usr/local/lib/python3.7/dist-packages/tornado/stack_context.py in null_wrapper(***failed resolving arguments***)
299 _state.contexts = cap_contexts[0]
--> 300 return fn(*args, **kwargs)
301 finally:
/usr/local/lib/python3.7/dist-packages/ipykernel/kernelbase.py in dispatcher(***failed resolving arguments***)
282 def dispatcher(msg):
--> 283 return self.dispatch_shell(stream, msg)
284 return dispatcher
/usr/local/lib/python3.7/dist-packages/ipykernel/kernelbase.py in dispatch_shell(***failed resolving arguments***)
232 try:
--> 233 handler(stream, idents, msg)
234 except Exception:
/usr/local/lib/python3.7/dist-packages/ipykernel/kernelbase.py in execute_request(***failed resolving arguments***)
398 reply_content = self.do_execute(code, silent, store_history,
--> 399 user_expressions, allow_stdin)
400
/usr/local/lib/python3.7/dist-packages/ipykernel/ipkernel.py in do_execute(***failed resolving arguments***)
207 try:
--> 208 res = shell.run_cell(code, store_history=store_history, silent=silent)
209 finally:
/usr/local/lib/python3.7/dist-packages/ipykernel/zmqshell.py in run_cell(***failed resolving arguments***)
536 self._last_traceback = None
--> 537 return super(ZMQInteractiveShell, self).run_cell(*args, **kwargs)
538
/usr/local/lib/python3.7/dist-packages/IPython/core/interactiveshell.py in run_cell(***failed resolving arguments***)
2717 has_raised = self.run_ast_nodes(code_ast.body, cell_name,
-> 2718 interactivity=interactivity, compiler=compiler, result=result)
2719
/usr/local/lib/python3.7/dist-packages/IPython/core/interactiveshell.py in run_ast_nodes(***failed resolving arguments***)
2827 code = compiler(mod, cell_name, "single")
-> 2828 if self.run_code(code, result):
2829 return True
/usr/local/lib/python3.7/dist-packages/IPython/core/interactiveshell.py in run_code(***failed resolving arguments***)
2881 #rprint('Running code', repr(code_obj)) # dbg
-> 2882 exec(code_obj, self.user_global_ns, self.user_ns)
2883 finally:
<ipython-input-44-eaef6fa9c89c> in <module>()
12 x = jnp.array([1,2,3], dtype=jnp.float32)
---> 13 test(x)
<ipython-input-44-eaef6fa9c89c> in test(***failed resolving arguments***)
8 def test(x):
----> 9 return call(ptest, x, result_shape=x, call_with_device=False)
10
/usr/local/lib/python3.7/dist-packages/jax/experimental/host_callback.py in call(***failed resolving arguments***)
689 return _call(callback_func, arg, result_shape=result_shape,
--> 690 call_with_device=call_with_device, identity=False)
691
/usr/local/lib/python3.7/dist-packages/jax/experimental/host_callback.py in _call(***failed resolving arguments***)
738 params["flat_results_aval"] = tuple(flat_results_aval)
--> 739 flat_results = outside_call_p.bind(*flat_args, **params)
740 return result_treedef.unflatten(flat_results) if not identity else arg_treedef.unflatten(flat_results)
/usr/local/lib/python3.7/dist-packages/jax/experimental/host_callback.py in _outside_call_impl(***failed resolving arguments***)
959 # different threads.
--> 960 return dispatch.apply_primitive(outside_call_p, *args, **params)
961
JaxStackTraceBeforeTransformation: AttributeError: 'NoneType' object has no attribute 'add_outfeed'
The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.
--------------------
The above exception was the direct cause of the following exception:
AttributeError Traceback (most recent call last)
<ipython-input-44-eaef6fa9c89c> in <module>()
11
12 x = jnp.array([1,2,3], dtype=jnp.float32)
---> 13 test(x)
<ipython-input-44-eaef6fa9c89c> in test(x)
7
8 def test(x):
----> 9 return call(ptest, x, result_shape=x, call_with_device=False)
10
11
/usr/local/lib/python3.7/dist-packages/jax/experimental/host_callback.py in call(callback_func, arg, result_shape, call_with_device)
688 """
689 return _call(callback_func, arg, result_shape=result_shape,
--> 690 call_with_device=call_with_device, identity=False)
691
692
/usr/local/lib/python3.7/dist-packages/jax/experimental/host_callback.py in _call(callback_func, arg, result_shape, call_with_device, identity)
737 params["result_treedef"] = result_treedef
738 params["flat_results_aval"] = tuple(flat_results_aval)
--> 739 flat_results = outside_call_p.bind(*flat_args, **params)
740 return result_treedef.unflatten(flat_results) if not identity else arg_treedef.unflatten(flat_results)
741
/usr/local/lib/python3.7/dist-packages/jax/core.py in bind(self, *args, **params)
284 assert (not config.jax_enable_checks or
285 all(isinstance(arg, Tracer) or valid_jaxtype(arg) for arg in args)), args
--> 286 return self.bind_with_trace(find_top_trace(args), args, params)
287
288 def bind_with_trace(self, trace, args, params):
/usr/local/lib/python3.7/dist-packages/jax/core.py in bind_with_trace(self, trace, args, params)
287
288 def bind_with_trace(self, trace, args, params):
--> 289 out = trace.process_primitive(self, map(trace.full_raise, args), params)
290 return map(full_lower, out) if self.multiple_results else full_lower(out)
291
/usr/local/lib/python3.7/dist-packages/jax/core.py in process_primitive(self, primitive, tracers, params)
609
610 def process_primitive(self, primitive, tracers, params):
--> 611 return primitive.impl(*tracers, **params)
612
613 def process_call(self, primitive, f, tracers, params):
/usr/local/lib/python3.7/dist-packages/jax/experimental/host_callback.py in _outside_call_impl(*args, **params)
958 # It would be confusing to process a sequence "id_tap; while" in two
959 # different threads.
--> 960 return dispatch.apply_primitive(outside_call_p, *args, **params)
961
962
/usr/local/lib/python3.7/dist-packages/jax/_src/dispatch.py in apply_primitive(prim, *args, **params)
91 """Impl rule that compiles and runs a single primitive 'prim' using XLA."""
92 compiled_fun = xla_primitive_callable(prim, *unsafe_map(arg_spec, args),
---> 93 **params)
94 return compiled_fun(*args)
95
/usr/local/lib/python3.7/dist-packages/jax/_src/util.py in wrapper(*args, **kwargs)
208 return f(*args, **kwargs)
209 else:
--> 210 return cached(config._trace_context(), *args, **kwargs)
211
212 wrapper.cache_clear = cached.cache_clear
/usr/local/lib/python3.7/dist-packages/jax/_src/util.py in cached(_, *args, **kwargs)
201 @functools.lru_cache(max_size)
202 def cached(_, *args, **kwargs):
--> 203 return f(*args, **kwargs)
204
205 @functools.wraps(f)
/usr/local/lib/python3.7/dist-packages/jax/_src/dispatch.py in xla_primitive_callable(prim, *arg_specs, **params)
110 return out,
111 compiled = _xla_callable_uncached(lu.wrap_init(prim_fun), device, None,
--> 112 prim.name, donated_invars, *arg_specs)
113 if not prim.multiple_results:
114 return lambda *args, **kw: compiled(*args, **kw)[0]
/usr/local/lib/python3.7/dist-packages/jax/_src/dispatch.py in _xla_callable_uncached(fun, device, backend, name, donated_invars, *arg_specs)
168 donated_invars, *arg_specs):
169 return lower_xla_callable(fun, device, backend, name, donated_invars,
--> 170 *arg_specs).compile().unsafe_call
171
172 _xla_callable = lu.cache(_xla_callable_uncached)
/usr/local/lib/python3.7/dist-packages/jax/_src/profiler.py in wrapper(*args, **kwargs)
204 def wrapper(*args, **kwargs):
205 with TraceAnnotation(name, **decorator_kwargs):
--> 206 return func(*args, **kwargs)
207 return wrapper
208 return wrapper
/usr/local/lib/python3.7/dist-packages/jax/_src/dispatch.py in lower_xla_callable(fun, device, backend, name, donated_invars, *arg_specs)
258 module = mlir.lower_jaxpr_to_module(
259 module_name, closed_jaxpr, backend.platform,
--> 260 mlir.ReplicaAxisContext(axis_env), name_stack, donated_invars)
261 else:
262 module = xla.lower_jaxpr_to_xla_module(
/usr/local/lib/python3.7/dist-packages/jax/interpreters/mlir.py in lower_jaxpr_to_module(module_name, jaxpr, platform, axis_context, name_stack, donated_args, replicated_args, arg_shardings, result_shardings)
492 replace_tokens_with_dummy=True, replicated_args=replicated_args,
493 arg_shardings=arg_shardings, result_shardings=result_shardings,
--> 494 input_output_aliases=input_output_aliases)
495
496 ctx.module.operation.verify()
/usr/local/lib/python3.7/dist-packages/jax/interpreters/mlir.py in lower_jaxpr_to_fun(ctx, name, jaxpr, public, replace_units_with_dummy, replace_tokens_with_dummy, replicated_args, arg_shardings, result_shardings, input_output_aliases)
635 out_vals = jaxpr_subcomp(ctx.replace(name_stack=callee_name_stack),
636 jaxpr.jaxpr, map(ir_constants, jaxpr.consts),
--> 637 *args)
638 outs = []
639 for aval, out in zip(jaxpr.out_avals, out_vals):
/usr/local/lib/python3.7/dist-packages/jax/interpreters/mlir.py in jaxpr_subcomp(ctx, jaxpr, consts, *args)
721 avals_in=map(aval, eqn.invars), avals_out=map(aval, eqn.outvars))
722 ans = rule(rule_ctx, *map(_unwrap_singleton_ir_values, in_nodes),
--> 723 **eqn.params)
724
725 try:
/usr/local/lib/python3.7/dist-packages/jax/interpreters/mlir.py in cached_lowering(ctx, *args, **params)
934 return f(ctx, *args, **params)
935 if func is None:
--> 936 func = _emit_lowering_rule_as_fun(partial(f, **params), ctx)
937 ctx.module_context.cached_primitive_lowerings[key] = func
938
/usr/local/lib/python3.7/dist-packages/jax/interpreters/mlir.py in _emit_lowering_rule_as_fun(lowering_rule, ctx)
664 unflattened_args = util.unflatten(entry_block.arguments,
665 map(len, input_types))
--> 666 outs = lowering_rule(ctx, *_unwrap_singleton_ir_values(unflattened_args))
667 func_dialect.ReturnOp(util.flatten(map(wrap_singleton_ir_values, outs)))
668 return func_op
/usr/local/lib/python3.7/dist-packages/jax/interpreters/mlir.py in fallback(ctx, *args, **params)
951 module_ctx = ctx.module_context
952 xla_computation = xla.primitive_subcomputation(
--> 953 module_ctx.platform, module_ctx.axis_env, prim, *ctx.avals_in, **params)
954 submodule_str = xc._xla.mlir.xla_computation_to_mlir_module(xla_computation)
955 submodule = ir.Module.parse(submodule_str)
/usr/local/lib/python3.7/dist-packages/jax/interpreters/xla.py in primitive_subcomputation(platform, axis_env, prim, *avals, **params)
444 ctx = TranslationContext(builder=c, platform=platform, axis_env=axis_env,
445 name_stack=new_name_stack())
--> 446 ans = f(ctx.replace(builder=c), avals, None, *xla_args, **params)
447 if prim.multiple_results:
448 ans = xops.Tuple(c, ans)
/usr/local/lib/python3.7/dist-packages/jax/interpreters/xla.py in f_new(ctx, avals_in, avals_out, *xla_args, **params)
1034 jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, avals_in)
1035 return jaxpr_subcomp(ctx, jaxpr, _xla_consts(ctx.builder, consts),
-> 1036 *xla_args)
1037 return f_new
1038
/usr/local/lib/python3.7/dist-packages/jax/interpreters/xla.py in jaxpr_subcomp(ctx, jaxpr, consts, *args)
610 config.jax_experimental_name_stack else ctx)
611 ans = rule(eqn_ctx, map(aval, eqn.invars), map(aval, eqn.outvars),
--> 612 *in_nodes, **eqn.params)
613
614 assert isinstance(ans, collections.abc.Sequence), (ans, eqn)
/usr/local/lib/python3.7/dist-packages/jax/experimental/host_callback.py in _outside_call_translation_rule(ctx, avals_in, avals_out, has_token, identity, flat_results_aval, *args_op, **params)
996 flat_results_aval=flat_results_aval,
997 **params))
--> 998 next_token = _callback_handler_data.receiver.add_outfeed(comp, current_token,
999 callback_id,
1000 args_to_outfeed)
AttributeError: 'NoneType' object has no attribute 'add_outfeed'
@zhangqiaorjc I ran your linked example notebook and got the exact same error I pasted above.
Edit: My coding partner also got the error. I'm not sure why you don't get it when you run it.
@evanatyourservice @zhangqiaorjc Any updates on this? This is essential for debugging TPU models.
Hi @NightMachinery,
jax.experimental.host_callback is deprecated in JAX 0.4.26, and jax.debug.print is the suggested replacement for id_print (see #20385). I tested this issue using jax.debug.print() on a Colab TPU v2 with JAX 0.4.26, and it works without any error.
```
import jax.numpy as jnp
import jax
import jaxlib
print(jax.version.__version__)
print(jaxlib.version.__version__)
print(jax.devices())

@jax.jit
def t1():
    jax.debug.print("{}", jnp.arange(4))

t1()
```
Output:
0.4.26
0.4.26
[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1), TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0), TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1), TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0), TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1), TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0), TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]
[0 1 2 3]
Attaching the gist for reference.
Thank you.
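One difference to keep in mind when migrating: `jax.debug.print` returns `None`, whereas `id_print` returned its argument. Code that threaded values through `id_print` therefore needs a small change. Below is a minimal sketch of a hypothetical helper, `debug_id_print` (not part of JAX), that restores the identity behaviour on top of `jax.debug.print`.

```python
import jax
import jax.numpy as jnp

def debug_id_print(x, label="x"):
    # Hypothetical helper (not a JAX API): prints x from inside traced code
    # via jax.debug.print, then returns x unchanged, like id_print did.
    # label is spliced into the format string, so it must not contain braces.
    jax.debug.print(label + " = {}", x)
    return x

@jax.jit
def t2(y):
    doubled = debug_id_print(y * 2, label="doubled")
    return doubled + 1

out = t2(jnp.arange(4))
print(out)
```

The value is passed to `jax.debug.print` as a positional argument rather than formatted in Python. This matters because, under `jit`, `x` is a tracer whose concrete value only exists at run time.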
Thanks for following up!
Please:
(I am using Colab.)
This same code snippet works fine on my laptop, though I don't know if it's because of the TPU or the different jaxlib version: