jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

RuntimeError: Internal: Non-root tuple types are not handled. #5773

Closed · jwnys closed this 3 years ago

jwnys commented 3 years ago

I'm currently having an issue with JAX (jaxlib 0.1.61) when using a GPU with CUDA 10.1 (I don't get this error when using the CPU only). I can't produce an MWE at the moment, and I can't find any info on this error. Can anyone give me some pointers on which kind of behaviour might generate it?
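For context, the failing call path in the traceback below boils down to a vmapped `jnp.lexsort`: lexsort lowers to a `lax.sort` over a tuple of key arrays, and the batching rule rebinds that tuple-operand sort before XLA compiles it. Here is a minimal sketch of that pattern (hypothetical function name and shapes, distilled from the traceback; I haven't verified that this exact snippet triggers the crash):

```python
import jax
import jax.numpy as jnp

def smallest_lexicographically(x):
    # jnp.lexsort lowers to lax.sort over a tuple of key arrays
    # (one key per row of x.T), matching the frames in the traceback;
    # we keep the lexicographically smallest row of x.
    return x[jnp.lexsort(x.T)[0]]

# vmap adds a batch dimension, so the tuple-operand sort goes through
# the sort batching rule before being compiled for the GPU backend.
vsmallest = jax.vmap(smallest_lexicographically)
out = vsmallest(jnp.zeros((8, 64, 64)))  # per-example shape mirrors the (64, 64) layer input
```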

Traceback (most recent call last):
  File "/project/clebschnet/Examples/hparams.py", line 549, in <module>
    run_opt(args)
  File "/project/clebschnet/Examples/hparams.py", line 348, in run_opt
    ma = JaxClebschTreeTrax(hi, D, 1/2, depth,
  File "../clebschnet/machine.py", line 234, in JaxClebschTreeTrax
    return Trax(
  File "/opt/conda/lib/python3.8/site-packages/netket/machine/jax.py", line 628, in __init__
    self.jax_init_parameters(rescale=self.rescale)
  File "/opt/conda/lib/python3.8/site-packages/netket/machine/jax.py", line 657, in jax_init_parameters
    self.rescale_weights()
  File "/opt/conda/lib/python3.8/site-packages/netket/machine/jax.py", line 682, in rescale_weights
    self.eval_in_batches(tree_unflatten(flat_ptree, flat_params), self._state, states)
  File "/opt/conda/lib/python3.8/site-packages/jax/_src/profiler.py", line 115, in wrapper
    return func(*args, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/netket/machine/jax.py", line 839, in eval_in_batches
    ch_out, new_state = self._forward_fn_with_state(params, state, ch, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/netket/machine/jax.py", line 528, in <lambda>
    self._forward_fn_with_state = jax.jit(lambda pars, state, x: self._forward_fn_t(pars, state, x))
  File "/opt/conda/lib/python3.8/site-packages/netket/machine/jax.py", line 517, in <lambda>
    self._forward_fn_t = lambda pars, state, x: net.pure_fn(x, pars, state, None) # returns (Tensor, state)
  File "/opt/conda/lib/python3.8/site-packages/trax/layers/base.py", line 548, in pure_fn
    raise LayerError(name, 'pure_fn',
jax._src.traceback_util.FilteredStackTrace: trax.layers.base.LayerError: Exception passing through layer Serial (in pure_fn):
  layer created in file [...]/../clebschnet/layers_trax_tree.py, line 193
  layer input shapes: ShapeDtype{shape:(64, 64), dtype:float64}

  File [...]/trax/layers/combinators.py, line 88, in forward
    outputs, s = layer.pure_fn(inputs, w, s, rng, use_cache=True)

LayerError: Exception passing through layer InputTransform (in pure_fn):
  layer created in file [...]/../clebschnet/layers_trax_tree.py, line 406
  layer input shapes: ShapeDtype{shape:(64, 64), dtype:float64}

  File [...]/../clebschnet/layers_trax_tree.py, line 433, in forward
    x = self.vsmallest_lexicographically(tempx, n_sites).reshape((-1,1) + self.og_lattice)

  File [...]/jax/_src/traceback_util.py, line 139, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)

  File [...]/site-packages/jax/api.py, line 1222, in batched_fun
    out_flat = batching.batch(

  File [...]/site-packages/jax/linear_util.py, line 166, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))

  File [...]/../clebschnet/layers_trax_tree.py, line 461, in smallest_lexicographically
    return x[jnp.lexsort(sorted_x)[0]]

  File [...]/_src/numpy/lax_numpy.py, line 3932, in lexsort
    return lax.sort((*keys[::-1], iota), dimension=axis, num_keys=len(keys))[-1]

  File [...]/_src/lax/lax.py, line 1414, in sort
    return tuple(sort_p.bind(*operand, dimension=dimension,

  File [...]/site-packages/jax/core.py, line 282, in bind
    out = top_trace.process_primitive(self, tracers, params)

  File [...]/jax/interpreters/batching.py, line 151, in process_primitive
    val_out, dim_out = batched_primitive(vals_in, dims_in, **params)

  File [...]/_src/lax/lax.py, line 5782, in _sort_batch_rule
    return (sort_p.bind(*new_args, dimension=new_dimension, is_stable=is_stable, num_keys=num_keys),

  File [...]/site-packages/jax/core.py, line 282, in bind
    out = top_trace.process_primitive(self, tracers, params)

  File [...]/site-packages/jax/core.py, line 628, in process_primitive
    return primitive.impl(*tracers, **params)

  File [...]/jax/interpreters/xla.py, line 238, in apply_primitive
    compiled_fun = xla_primitive_callable(prim, *unsafe_map(arg_spec, args), **params)

  File [...]/jax/_src/util.py, line 198, in wrapper
    return cached(bool(FLAGS.jax_enable_x64), *args, **kwargs)

  File [...]/jax/_src/util.py, line 191, in cached
    return f(*args, **kwargs)

  File [...]/jax/interpreters/xla.py, line 288, in xla_primitive_callable
    compiled = backend_compile(backend, built_c, options)

  File [...]/jax/interpreters/xla.py, line 352, in backend_compile
    return backend.compile(built_c, compile_options=options)

RuntimeError: Internal: Non-root tuple types are not handled.

The stack trace above excludes JAX-internal frames.
The following is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/project/clebschnet/Examples/hparams.py", line 549, in <module>
    run_opt(args)
  File "/project/clebschnet/Examples/hparams.py", line 348, in run_opt
    ma = JaxClebschTreeTrax(hi, D, 1/2, depth,
  File "../clebschnet/machine.py", line 234, in JaxClebschTreeTrax
    return Trax(
  File "/opt/conda/lib/python3.8/site-packages/netket/machine/jax.py", line 628, in __init__
    self.jax_init_parameters(rescale=self.rescale)
  File "/opt/conda/lib/python3.8/site-packages/netket/machine/jax.py", line 657, in jax_init_parameters
    self.rescale_weights()
  File "/opt/conda/lib/python3.8/site-packages/netket/machine/jax.py", line 682, in rescale_weights
    self.eval_in_batches(tree_unflatten(flat_ptree, flat_params), self._state, states)
  File "/opt/conda/lib/python3.8/site-packages/jax/_src/profiler.py", line 115, in wrapper
    return func(*args, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/netket/machine/jax.py", line 839, in eval_in_batches
    ch_out, new_state = self._forward_fn_with_state(params, state, ch, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 139, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/jax/api.py", line 396, in f_jitted
    return cpp_jitted_f(*args, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/netket/machine/jax.py", line 528, in <lambda>
    self._forward_fn_with_state = jax.jit(lambda pars, state, x: self._forward_fn_t(pars, state, x))
  File "/opt/conda/lib/python3.8/site-packages/netket/machine/jax.py", line 517, in <lambda>
    self._forward_fn_t = lambda pars, state, x: net.pure_fn(x, pars, state, None) # returns (Tensor, state)
  File "/opt/conda/lib/python3.8/site-packages/trax/layers/base.py", line 548, in pure_fn
    raise LayerError(name, 'pure_fn',
trax.layers.base.LayerError: Exception passing through layer Serial (in pure_fn):
  layer created in file [...]/../clebschnet/layers_trax_tree.py, line 193
  layer input shapes: ShapeDtype{shape:(64, 64), dtype:float64}

  File [...]/trax/layers/combinators.py, line 88, in forward
    outputs, s = layer.pure_fn(inputs, w, s, rng, use_cache=True)

LayerError: Exception passing through layer InputTransform (in pure_fn):
  layer created in file [...]/../clebschnet/layers_trax_tree.py, line 406
  layer input shapes: ShapeDtype{shape:(64, 64), dtype:float64}

  File [...]/../clebschnet/layers_trax_tree.py, line 433, in forward
    x = self.vsmallest_lexicographically(tempx, n_sites).reshape((-1,1) + self.og_lattice)

  File [...]/jax/_src/traceback_util.py, line 139, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)

  File [...]/site-packages/jax/api.py, line 1222, in batched_fun
    out_flat = batching.batch(

  File [...]/site-packages/jax/linear_util.py, line 166, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))

  File [...]/../clebschnet/layers_trax_tree.py, line 461, in smallest_lexicographically
    return x[jnp.lexsort(sorted_x)[0]]

  File [...]/_src/numpy/lax_numpy.py, line 3932, in lexsort
    return lax.sort((*keys[::-1], iota), dimension=axis, num_keys=len(keys))[-1]

  File [...]/_src/lax/lax.py, line 1414, in sort
    return tuple(sort_p.bind(*operand, dimension=dimension,

  File [...]/site-packages/jax/core.py, line 282, in bind
    out = top_trace.process_primitive(self, tracers, params)

  File [...]/jax/interpreters/batching.py, line 151, in process_primitive
    val_out, dim_out = batched_primitive(vals_in, dims_in, **params)

  File [...]/_src/lax/lax.py, line 5782, in _sort_batch_rule
    return (sort_p.bind(*new_args, dimension=new_dimension, is_stable=is_stable, num_keys=num_keys),

  File [...]/site-packages/jax/core.py, line 282, in bind
    out = top_trace.process_primitive(self, tracers, params)

  File [...]/site-packages/jax/core.py, line 628, in process_primitive
    return primitive.impl(*tracers, **params)

  File [...]/jax/interpreters/xla.py, line 238, in apply_primitive
    compiled_fun = xla_primitive_callable(prim, *unsafe_map(arg_spec, args), **params)

  File [...]/jax/_src/util.py, line 198, in wrapper
    return cached(bool(FLAGS.jax_enable_x64), *args, **kwargs)

  File [...]/jax/_src/util.py, line 191, in cached
    return f(*args, **kwargs)

  File [...]/jax/interpreters/xla.py, line 288, in xla_primitive_callable
    compiled = backend_compile(backend, built_c, options)

  File [...]/jax/interpreters/xla.py, line 352, in backend_compile
    return backend.compile(built_c, compile_options=options)

RuntimeError: Internal: Non-root tuple types are not handled.
hawkinsp commented 3 years ago

This error originates inside XLA, but I'm actually not sure how to reproduce it. Can you share either a reproduction or an HLO dump?

You can get an HLO dump by running JAX with the environment variable XLA_FLAGS=--xla_dump_to=/tmp/somewhere and zipping up the files you get out. (Note that this essentially shares your model's code on GitHub, if that matters.)
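For example (hypothetical dump directory; the variable has to be set before JAX initializes its backends, e.g. before the first import):

```python
import os

# Must be set before `import jax` so the flag reaches XLA at startup.
os.environ["XLA_FLAGS"] = "--xla_dump_to=/tmp/xla_dump"

import jax  # noqa: E402
# ... run the failing computation, then zip up /tmp/xla_dump ...
```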

hawkinsp commented 3 years ago

Any updates? Is there a way for me to reproduce this?

hawkinsp commented 3 years ago

Closing. I suspect the bug has already been fixed in newer jaxlibs, and without a way to reproduce it there's nothing we can do. Please reopen if there's a way for us to reproduce the issue at head!