jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

`id_print` throws `AttributeError: 'NoneType' object has no attribute 'add_outfeed'` #9053

Closed NightMachinery closed 4 months ago

NightMachinery commented 2 years ago

(I am using Colab.)

import jax.numpy as jnp
import jax
import jaxlib

print(jax.version.__version__)
print(jaxlib.version.__version__)

from jax.experimental.host_callback import id_print

@jax.jit
def t1():
  id_print(jnp.arange(4))

t1()

# Simply using `id_print(0)` is also enough to throw this exception!
0.2.25
0.1.71
---------------------------------------------------------------------------
JaxStackTraceBeforeTransformation         Traceback (most recent call last)
/usr/lib/python3.7/runpy.py in _run_module_as_main(***failed resolving arguments***)
    192     return _run_code(code, main_globals, None,
--> 193                      "__main__", mod_spec)
    194 

41 frames
JaxStackTraceBeforeTransformation: AttributeError: 'NoneType' object has no attribute 'add_outfeed'

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:

UnfilteredStackTrace                      Traceback (most recent call last)
UnfilteredStackTrace: AttributeError: 'NoneType' object has no attribute 'add_outfeed'

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

AttributeError                            Traceback (most recent call last)
/usr/local/lib/python3.7/dist-packages/jax/experimental/host_callback.py in _outside_call_translation_rule(ctx, avals_in, avals_out, has_token, identity, flat_results_aval, *args_op, **params)
    957             flat_results_aval=flat_results_aval,
    958             **params))
--> 959     next_token = _callback_handler_data.receiver.add_outfeed(comp, current_token,
    960                                                         callback_id,
    961                                                         args_to_outfeed)

AttributeError: 'NoneType' object has no attribute 'add_outfeed'

This same code snippet works fine on my laptop, though I don't know if it's because of the TPU or the different jaxlib version:

0.2.25
0.1.74
INFO:absl:Unable to initialize backend 'tpu_driver': NOT_FOUND: Unable to find driver in registry given worker: 
INFO:absl:Unable to initialize backend 'gpu': NOT_FOUND: Could not find registered platform with name: "cuda". Available platform names are: Interpreter Host
INFO:absl:Unable to initialize backend 'tpu': INVALID_ARGUMENT: TpuPlatform is not available.
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
[0 1 2 3]
NightMachinery commented 2 years ago

This bug does not occur on Colab when using the latest versions (jax 0.2.26, jaxlib 0.1.75). Feel free to close the issue, though I think Colab's preinstalled JAX version should be upgraded.

NightMachinery commented 2 years ago

The latest JAX versions don't seem to work with Colab's TPU though:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-25-05dddd30ebb4> in <module>()
----> 1 books_bytes = ([jnp.array(b['bytes']) for b in books.values()])

13 frames
/usr/local/lib/python3.7/dist-packages/jax/_src/lib/xla_bridge.py in _get_backend_uncached(platform)
    285     if backend is None:
    286       if platform in _backends_errors:
--> 287         raise RuntimeError(f"Backend '{platform}' failed to initialize: "
    288                            f"{_backends_errors[platform]}")
    289       raise RuntimeError(f"Unknown backend {platform}")

RuntimeError: Backend 'tpu_driver' failed to initialize: NOT_FOUND: Unable to find driver in registry given worker:
NightMachinery commented 2 years ago

The problem occurs with the latest versions on Colab's TPUv2.

!pip install "jax[tpu]>=0.2.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
import jax.numpy as jnp
import jax
import jaxlib

print(jax.version.__version__)
print(jaxlib.version.__version__)

from jax.experimental.host_callback import id_print

@jax.jit
def t1():
  id_print(jnp.arange(4))

t1()
0.2.26
0.1.75
---------------------------------------------------------------------------
JaxStackTraceBeforeTransformation         Traceback (most recent call last)
/usr/lib/python3.7/runpy.py in _run_module_as_main(***failed resolving arguments***)
    192     return _run_code(code, main_globals, None,
--> 193                      "__main__", mod_spec)
    194 

44 frames
JaxStackTraceBeforeTransformation: AttributeError: 'NoneType' object has no attribute 'add_outfeed'

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:

UnfilteredStackTrace                      Traceback (most recent call last)
UnfilteredStackTrace: AttributeError: 'NoneType' object has no attribute 'add_outfeed'

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

AttributeError                            Traceback (most recent call last)
/usr/local/lib/python3.7/dist-packages/jax/experimental/host_callback.py in _outside_call_translation_rule(ctx, avals_in, avals_out, has_token, identity, flat_results_aval, *args_op, **params)
    922             flat_results_aval=flat_results_aval,
    923             **params))
--> 924     next_token = _callback_handler_data.receiver.add_outfeed(comp, current_token,
    925                                                         callback_id,
    926                                                         args_to_outfeed)

AttributeError: 'NoneType' object has no attribute 'add_outfeed'
zhangqiaorjc commented 2 years ago

I couldn't reproduce the error. Here's my Colab notebook: https://colab.research.google.com/drive/1I61dvqk0TbRyQwFdc0CceW9ll9M7kklV?usp=sharing

I ran this in the top cell to set up the Colab TPU:

import jax.tools.colab_tpu
jax.tools.colab_tpu.setup_tpu()
NightMachinery commented 2 years ago

@zhangqiaorjc I ran your Colab notebook, and I got the error.

https://colab.research.google.com/gist/batbone/4231c11e5b31269d25bb2040508ef9da/untitled0.ipynb

I also noticed that in your successful run, you seemed to have used the CPU:

WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
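
A quick way to confirm which backend a run actually used, rather than inferring it from log warnings, is to query JAX directly. This is only a minimal sketch: it assumes a JAX version that provides `jax.default_backend()`, and the variable names are illustrative.

```python
import jax

# Platform name of the default backend, e.g. "cpu", "gpu", or "tpu".
backend = jax.default_backend()
# Devices visible to the default backend.
devices = jax.devices()

print(f"default backend: {backend}")
print(f"{len(devices)} device(s): {devices}")

if backend == "cpu":
    # On Colab this usually means the TPU/GPU backend was not initialized.
    print("Running on CPU; no GPU/TPU backend is in use.")
```

Running this in the same cell as the reproduction makes it easier to compare two runs of the "same" notebook.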
zhangqiaorjc commented 2 years ago

@NightMachinery, in my run I used TPUs. See:

jax.devices() [TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1), TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0), TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1), TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0), TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1), TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0), TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]

I can't reproduce your error. It might be an external Colab issue.

Maybe @jakevdp and @skye know why the exact same code behaves differently for you.

NightMachinery commented 2 years ago

@zhangqiaorjc @jakevdp @skye Any updates? I ran that notebook with another gmail account of mine, and the error was the same. Without id_print, debugging JAX is close to impossible.

dylanhmorris commented 2 years ago

A note for any fellow NumPyro users who find this issue by googling: this error arises when running NumPyro in Colab with parallelized MCMC chains (i.e. num_chains > 1 and parallelization enabled).

An apparent workaround is to set progress_bar=False when instantiating the MCMC object, as it seems that NumPyro's MCMC runners only call id_print to display progress bars for parallel chains.

evanatyourservice commented 2 years ago

I'm getting this error when using host_callback.call in Colab. It works with Colab CPU or GPU, and elsewhere outside of Colab, so I think it's a Colab TPU-specific problem. That's annoying, since Colab is nice for testing or even training on TPUs, and call is an important part of my program. To reproduce:

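import jax.numpy as jnp
from jax.experimental.host_callback import call

def ptest(x):
    # Host-side callback: print the value, then return it incremented.
    print(x)
    #print(device)
    #print(device.id)
    return x + 1

def test(x):
    # result_shape=x declares that the result has the same shape and dtype as x.
    return call(ptest, x, result_shape=x, call_with_device=False)

x = jnp.array([1,2,3], dtype=jnp.float32)
test(x)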

Full stack trace:

---------------------------------------------------------------------------
JaxStackTraceBeforeTransformation         Traceback (most recent call last)
[/usr/lib/python3.7/runpy.py](https://localhost:8080/#) in _run_module_as_main(***failed resolving arguments***)
    192     return _run_code(code, main_globals, None,
--> 193                      "__main__", mod_spec)
    194 

51 frames
[/usr/lib/python3.7/runpy.py](https://localhost:8080/#) in _run_code(***failed resolving arguments***)
     84                        __spec__ = mod_spec)
---> 85     exec(code, run_globals)
     86     return run_globals

[/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py](https://localhost:8080/#) in <module>()
     15     from ipykernel import kernelapp as app
---> 16     app.launch_new_instance()

[/usr/local/lib/python3.7/dist-packages/traitlets/config/application.py](https://localhost:8080/#) in launch_instance(***failed resolving arguments***)
    845         app.initialize(argv)
--> 846         app.start()
    847 

[/usr/local/lib/python3.7/dist-packages/ipykernel/kernelapp.py](https://localhost:8080/#) in start(***failed resolving arguments***)
    498         try:
--> 499             self.io_loop.start()
    500         except KeyboardInterrupt:

[/usr/local/lib/python3.7/dist-packages/tornado/platform/asyncio.py](https://localhost:8080/#) in start(***failed resolving arguments***)
    131             asyncio.set_event_loop(self.asyncio_loop)
--> 132             self.asyncio_loop.run_forever()
    133         finally:

[/usr/lib/python3.7/asyncio/base_events.py](https://localhost:8080/#) in run_forever(***failed resolving arguments***)
    540             while True:
--> 541                 self._run_once()
    542                 if self._stopping:

[/usr/lib/python3.7/asyncio/base_events.py](https://localhost:8080/#) in _run_once(***failed resolving arguments***)
   1785             else:
-> 1786                 handle._run()
   1787         handle = None  # Needed to break cycles when an exception occurs.

[/usr/lib/python3.7/asyncio/events.py](https://localhost:8080/#) in _run(***failed resolving arguments***)
     87         try:
---> 88             self._context.run(self._callback, *self._args)
     89         except Exception as exc:

[/usr/local/lib/python3.7/dist-packages/tornado/platform/asyncio.py](https://localhost:8080/#) in _handle_events(***failed resolving arguments***)
    121         fileobj, handler_func = self.handlers[fd]
--> 122         handler_func(fileobj, events)
    123 

[/usr/local/lib/python3.7/dist-packages/tornado/stack_context.py](https://localhost:8080/#) in null_wrapper(***failed resolving arguments***)
    299                 _state.contexts = cap_contexts[0]
--> 300                 return fn(*args, **kwargs)
    301             finally:

[/usr/local/lib/python3.7/dist-packages/zmq/eventloop/zmqstream.py](https://localhost:8080/#) in _handle_events(***failed resolving arguments***)
    451             if zmq_events & zmq.POLLIN and self.receiving():
--> 452                 self._handle_recv()
    453                 if not self.socket:

[/usr/local/lib/python3.7/dist-packages/zmq/eventloop/zmqstream.py](https://localhost:8080/#) in _handle_recv(***failed resolving arguments***)
    480                 callback = self._recv_callback
--> 481                 self._run_callback(callback, msg)
    482 

[/usr/local/lib/python3.7/dist-packages/zmq/eventloop/zmqstream.py](https://localhost:8080/#) in _run_callback(***failed resolving arguments***)
    430             # inside our blanket exception handler rather than outside.
--> 431             callback(*args, **kwargs)
    432         except Exception:

[/usr/local/lib/python3.7/dist-packages/tornado/stack_context.py](https://localhost:8080/#) in null_wrapper(***failed resolving arguments***)
    299                 _state.contexts = cap_contexts[0]
--> 300                 return fn(*args, **kwargs)
    301             finally:

[/usr/local/lib/python3.7/dist-packages/ipykernel/kernelbase.py](https://localhost:8080/#) in dispatcher(***failed resolving arguments***)
    282             def dispatcher(msg):
--> 283                 return self.dispatch_shell(stream, msg)
    284             return dispatcher

[/usr/local/lib/python3.7/dist-packages/ipykernel/kernelbase.py](https://localhost:8080/#) in dispatch_shell(***failed resolving arguments***)
    232             try:
--> 233                 handler(stream, idents, msg)
    234             except Exception:

[/usr/local/lib/python3.7/dist-packages/ipykernel/kernelbase.py](https://localhost:8080/#) in execute_request(***failed resolving arguments***)
    398         reply_content = self.do_execute(code, silent, store_history,
--> 399                                         user_expressions, allow_stdin)
    400 

[/usr/local/lib/python3.7/dist-packages/ipykernel/ipkernel.py](https://localhost:8080/#) in do_execute(***failed resolving arguments***)
    207         try:
--> 208             res = shell.run_cell(code, store_history=store_history, silent=silent)
    209         finally:

[/usr/local/lib/python3.7/dist-packages/ipykernel/zmqshell.py](https://localhost:8080/#) in run_cell(***failed resolving arguments***)
    536         self._last_traceback = None
--> 537         return super(ZMQInteractiveShell, self).run_cell(*args, **kwargs)
    538 

[/usr/local/lib/python3.7/dist-packages/IPython/core/interactiveshell.py](https://localhost:8080/#) in run_cell(***failed resolving arguments***)
   2717                 has_raised = self.run_ast_nodes(code_ast.body, cell_name,
-> 2718                    interactivity=interactivity, compiler=compiler, result=result)
   2719 

[/usr/local/lib/python3.7/dist-packages/IPython/core/interactiveshell.py](https://localhost:8080/#) in run_ast_nodes(***failed resolving arguments***)
   2827                 code = compiler(mod, cell_name, "single")
-> 2828                 if self.run_code(code, result):
   2829                     return True

[/usr/local/lib/python3.7/dist-packages/IPython/core/interactiveshell.py](https://localhost:8080/#) in run_code(***failed resolving arguments***)
   2881                 #rprint('Running code', repr(code_obj)) # dbg
-> 2882                 exec(code_obj, self.user_global_ns, self.user_ns)
   2883             finally:

[<ipython-input-44-eaef6fa9c89c>](https://localhost:8080/#) in <module>()
     12 x = jnp.array([1,2,3], dtype=jnp.float32)
---> 13 test(x)

[<ipython-input-44-eaef6fa9c89c>](https://localhost:8080/#) in test(***failed resolving arguments***)
      8 def test(x):
----> 9     return call(ptest, x, result_shape=x, call_with_device=False)
     10 

[/usr/local/lib/python3.7/dist-packages/jax/experimental/host_callback.py](https://localhost:8080/#) in call(***failed resolving arguments***)
    689   return _call(callback_func, arg, result_shape=result_shape,
--> 690                call_with_device=call_with_device, identity=False)
    691 

[/usr/local/lib/python3.7/dist-packages/jax/experimental/host_callback.py](https://localhost:8080/#) in _call(***failed resolving arguments***)
    738     params["flat_results_aval"] = tuple(flat_results_aval)
--> 739   flat_results = outside_call_p.bind(*flat_args, **params)
    740   return result_treedef.unflatten(flat_results) if not identity else arg_treedef.unflatten(flat_results)

[/usr/local/lib/python3.7/dist-packages/jax/experimental/host_callback.py](https://localhost:8080/#) in _outside_call_impl(***failed resolving arguments***)
    959     # different threads.
--> 960     return dispatch.apply_primitive(outside_call_p, *args, **params)
    961 

JaxStackTraceBeforeTransformation: AttributeError: 'NoneType' object has no attribute 'add_outfeed'

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:

AttributeError                            Traceback (most recent call last)
[<ipython-input-44-eaef6fa9c89c>](https://localhost:8080/#) in <module>()
     11 
     12 x = jnp.array([1,2,3], dtype=jnp.float32)
---> 13 test(x)

[<ipython-input-44-eaef6fa9c89c>](https://localhost:8080/#) in test(x)
      7 
      8 def test(x):
----> 9     return call(ptest, x, result_shape=x, call_with_device=False)
     10 
     11 

[/usr/local/lib/python3.7/dist-packages/jax/experimental/host_callback.py](https://localhost:8080/#) in call(callback_func, arg, result_shape, call_with_device)
    688   """
    689   return _call(callback_func, arg, result_shape=result_shape,
--> 690                call_with_device=call_with_device, identity=False)
    691 
    692 

[/usr/local/lib/python3.7/dist-packages/jax/experimental/host_callback.py](https://localhost:8080/#) in _call(callback_func, arg, result_shape, call_with_device, identity)
    737     params["result_treedef"] = result_treedef
    738     params["flat_results_aval"] = tuple(flat_results_aval)
--> 739   flat_results = outside_call_p.bind(*flat_args, **params)
    740   return result_treedef.unflatten(flat_results) if not identity else arg_treedef.unflatten(flat_results)
    741 

[/usr/local/lib/python3.7/dist-packages/jax/core.py](https://localhost:8080/#) in bind(self, *args, **params)
    284     assert (not config.jax_enable_checks or
    285             all(isinstance(arg, Tracer) or valid_jaxtype(arg) for arg in args)), args
--> 286     return self.bind_with_trace(find_top_trace(args), args, params)
    287 
    288   def bind_with_trace(self, trace, args, params):

[/usr/local/lib/python3.7/dist-packages/jax/core.py](https://localhost:8080/#) in bind_with_trace(self, trace, args, params)
    287 
    288   def bind_with_trace(self, trace, args, params):
--> 289     out = trace.process_primitive(self, map(trace.full_raise, args), params)
    290     return map(full_lower, out) if self.multiple_results else full_lower(out)
    291 

[/usr/local/lib/python3.7/dist-packages/jax/core.py](https://localhost:8080/#) in process_primitive(self, primitive, tracers, params)
    609 
    610   def process_primitive(self, primitive, tracers, params):
--> 611     return primitive.impl(*tracers, **params)
    612 
    613   def process_call(self, primitive, f, tracers, params):

[/usr/local/lib/python3.7/dist-packages/jax/experimental/host_callback.py](https://localhost:8080/#) in _outside_call_impl(*args, **params)
    958     # It would be confusing to process a sequence "id_tap; while" in two
    959     # different threads.
--> 960     return dispatch.apply_primitive(outside_call_p, *args, **params)
    961 
    962 

[/usr/local/lib/python3.7/dist-packages/jax/_src/dispatch.py](https://localhost:8080/#) in apply_primitive(prim, *args, **params)
     91   """Impl rule that compiles and runs a single primitive 'prim' using XLA."""
     92   compiled_fun = xla_primitive_callable(prim, *unsafe_map(arg_spec, args),
---> 93                                         **params)
     94   return compiled_fun(*args)
     95 

[/usr/local/lib/python3.7/dist-packages/jax/_src/util.py](https://localhost:8080/#) in wrapper(*args, **kwargs)
    208         return f(*args, **kwargs)
    209       else:
--> 210         return cached(config._trace_context(), *args, **kwargs)
    211 
    212     wrapper.cache_clear = cached.cache_clear

[/usr/local/lib/python3.7/dist-packages/jax/_src/util.py](https://localhost:8080/#) in cached(_, *args, **kwargs)
    201     @functools.lru_cache(max_size)
    202     def cached(_, *args, **kwargs):
--> 203       return f(*args, **kwargs)
    204 
    205     @functools.wraps(f)

[/usr/local/lib/python3.7/dist-packages/jax/_src/dispatch.py](https://localhost:8080/#) in xla_primitive_callable(prim, *arg_specs, **params)
    110       return out,
    111   compiled = _xla_callable_uncached(lu.wrap_init(prim_fun), device, None,
--> 112                                     prim.name, donated_invars, *arg_specs)
    113   if not prim.multiple_results:
    114     return lambda *args, **kw: compiled(*args, **kw)[0]

[/usr/local/lib/python3.7/dist-packages/jax/_src/dispatch.py](https://localhost:8080/#) in _xla_callable_uncached(fun, device, backend, name, donated_invars, *arg_specs)
    168                            donated_invars, *arg_specs):
    169   return lower_xla_callable(fun, device, backend, name, donated_invars,
--> 170                             *arg_specs).compile().unsafe_call
    171 
    172 _xla_callable = lu.cache(_xla_callable_uncached)

[/usr/local/lib/python3.7/dist-packages/jax/_src/profiler.py](https://localhost:8080/#) in wrapper(*args, **kwargs)
    204   def wrapper(*args, **kwargs):
    205     with TraceAnnotation(name, **decorator_kwargs):
--> 206       return func(*args, **kwargs)
    207     return wrapper
    208   return wrapper

[/usr/local/lib/python3.7/dist-packages/jax/_src/dispatch.py](https://localhost:8080/#) in lower_xla_callable(fun, device, backend, name, donated_invars, *arg_specs)
    258     module = mlir.lower_jaxpr_to_module(
    259         module_name, closed_jaxpr, backend.platform,
--> 260         mlir.ReplicaAxisContext(axis_env), name_stack, donated_invars)
    261   else:
    262     module = xla.lower_jaxpr_to_xla_module(

[/usr/local/lib/python3.7/dist-packages/jax/interpreters/mlir.py](https://localhost:8080/#) in lower_jaxpr_to_module(module_name, jaxpr, platform, axis_context, name_stack, donated_args, replicated_args, arg_shardings, result_shardings)
    492         replace_tokens_with_dummy=True, replicated_args=replicated_args,
    493         arg_shardings=arg_shardings, result_shardings=result_shardings,
--> 494         input_output_aliases=input_output_aliases)
    495 
    496   ctx.module.operation.verify()

[/usr/local/lib/python3.7/dist-packages/jax/interpreters/mlir.py](https://localhost:8080/#) in lower_jaxpr_to_fun(ctx, name, jaxpr, public, replace_units_with_dummy, replace_tokens_with_dummy, replicated_args, arg_shardings, result_shardings, input_output_aliases)
    635     out_vals = jaxpr_subcomp(ctx.replace(name_stack=callee_name_stack),
    636                              jaxpr.jaxpr, map(ir_constants, jaxpr.consts),
--> 637                              *args)
    638     outs = []
    639     for aval, out in zip(jaxpr.out_avals, out_vals):

[/usr/local/lib/python3.7/dist-packages/jax/interpreters/mlir.py](https://localhost:8080/#) in jaxpr_subcomp(ctx, jaxpr, consts, *args)
    721           avals_in=map(aval, eqn.invars), avals_out=map(aval, eqn.outvars))
    722       ans = rule(rule_ctx, *map(_unwrap_singleton_ir_values, in_nodes),
--> 723                  **eqn.params)
    724 
    725     try:

[/usr/local/lib/python3.7/dist-packages/jax/interpreters/mlir.py](https://localhost:8080/#) in cached_lowering(ctx, *args, **params)
    934       return f(ctx, *args, **params)
    935     if func is None:
--> 936       func = _emit_lowering_rule_as_fun(partial(f, **params), ctx)
    937       ctx.module_context.cached_primitive_lowerings[key] = func
    938 

[/usr/local/lib/python3.7/dist-packages/jax/interpreters/mlir.py](https://localhost:8080/#) in _emit_lowering_rule_as_fun(lowering_rule, ctx)
    664     unflattened_args = util.unflatten(entry_block.arguments,
    665                                       map(len, input_types))
--> 666     outs = lowering_rule(ctx, *_unwrap_singleton_ir_values(unflattened_args))
    667     func_dialect.ReturnOp(util.flatten(map(wrap_singleton_ir_values, outs)))
    668   return func_op

[/usr/local/lib/python3.7/dist-packages/jax/interpreters/mlir.py](https://localhost:8080/#) in fallback(ctx, *args, **params)
    951     module_ctx = ctx.module_context
    952     xla_computation = xla.primitive_subcomputation(
--> 953         module_ctx.platform, module_ctx.axis_env, prim, *ctx.avals_in, **params)
    954     submodule_str = xc._xla.mlir.xla_computation_to_mlir_module(xla_computation)
    955     submodule = ir.Module.parse(submodule_str)

[/usr/local/lib/python3.7/dist-packages/jax/interpreters/xla.py](https://localhost:8080/#) in primitive_subcomputation(platform, axis_env, prim, *avals, **params)
    444   ctx = TranslationContext(builder=c, platform=platform, axis_env=axis_env,
    445                            name_stack=new_name_stack())
--> 446   ans = f(ctx.replace(builder=c), avals, None, *xla_args, **params)
    447   if prim.multiple_results:
    448     ans = xops.Tuple(c, ans)

[/usr/local/lib/python3.7/dist-packages/jax/interpreters/xla.py](https://localhost:8080/#) in f_new(ctx, avals_in, avals_out, *xla_args, **params)
   1034         jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, avals_in)
   1035       return jaxpr_subcomp(ctx, jaxpr, _xla_consts(ctx.builder, consts),
-> 1036                            *xla_args)
   1037     return f_new
   1038 

[/usr/local/lib/python3.7/dist-packages/jax/interpreters/xla.py](https://localhost:8080/#) in jaxpr_subcomp(ctx, jaxpr, consts, *args)
    610           config.jax_experimental_name_stack else ctx)
    611       ans = rule(eqn_ctx, map(aval, eqn.invars), map(aval, eqn.outvars),
--> 612                  *in_nodes, **eqn.params)
    613 
    614     assert isinstance(ans, collections.abc.Sequence), (ans, eqn)

[/usr/local/lib/python3.7/dist-packages/jax/experimental/host_callback.py](https://localhost:8080/#) in _outside_call_translation_rule(ctx, avals_in, avals_out, has_token, identity, flat_results_aval, *args_op, **params)
    996             flat_results_aval=flat_results_aval,
    997             **params))
--> 998     next_token = _callback_handler_data.receiver.add_outfeed(comp, current_token,
    999                                                         callback_id,
   1000                                                         args_to_outfeed)

AttributeError: 'NoneType' object has no attribute 'add_outfeed'
evanatyourservice commented 2 years ago

@zhangqiaorjc I ran your linked example notebook and got the exact same error I pasted above.

Edit: my coding partner also got the error. Not sure why you don't get it when you run it.

NightMachinery commented 2 years ago

@evanatyourservice @zhangqiaorjc Any updates on this? This is quite essential in debugging a TPU model.

rajasekharporeddy commented 4 months ago

Hi @NightMachinery

jax.experimental.host_callback is deprecated as of JAX 0.4.26, and jax.debug.print is the suggested replacement for id_print (see #20385). I tested this issue with jax.debug.print() on a Colab TPU v2 using JAX 0.4.26, and it works without any error.

import jax.numpy as jnp
import jax
import jaxlib

print(jax.version.__version__)
print(jaxlib.version.__version__)
print(jax.devices())

@jax.jit
def t1():
  jax.debug.print("{}", jnp.arange(4))

t1()

Output:

0.4.26
0.4.26
[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1), TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0), TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1), TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0), TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1), TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0), TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]
[0 1 2 3]

Attaching the gist for reference.

Thank you.
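
For the host_callback.call use case reported earlier in this thread (a host function that returns a value), a rough equivalent with the newer callback APIs might look like the sketch below. It assumes a JAX version that provides jax.pure_callback and jax.debug.print. The names host_add_one and add_one_via_host are hypothetical, and this has not been verified on a Colab TPU.

```python
import jax
import jax.numpy as jnp
import numpy as np

def host_add_one(x):
    # Hypothetical host-side function: receives an array on the host and
    # must return a result matching the declared shape and dtype.
    return np.asarray(x) + 1

@jax.jit
def add_one_via_host(x):
    # jax.debug.print takes the place of id_print for inspecting values.
    jax.debug.print("x = {}", x)
    # Declare the result's shape/dtype up front, much like result_shape in
    # host_callback.call.
    result_spec = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.pure_callback(host_add_one, result_spec, x)

x = jnp.array([1, 2, 3], dtype=jnp.float32)
y = add_one_via_host(x)
print(y)
```

jax.pure_callback expects the host function to be free of side effects, so the printing is done with jax.debug.print here. For callbacks that exist only for their side effects, jax.debug.callback is likely the closer fit.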

jakevdp commented 4 months ago

Thanks for following up!