google / neural-tangents

Fast and Easy Infinite Neural Networks in Python
https://iclr.cc/virtual_2020/poster_SklD9yrFPS.html
Apache License 2.0

RuntimeError: CUDA operation failed: cudaGetErrorString symbol not found #85

Closed: kventinel closed this issue 3 years ago

kventinel commented 4 years ago

https://colab.research.google.com/drive/1VHzY55vHtMPsvXR302WoYYAJTj74jy1S?usp=sharing

from neural_tangents import stax
from jax import random
init_fn, apply_fn, kernel_fn = stax.Dense(out_dim=1)
INPUT_SHAPE = (-1, 3072)
key = random.PRNGKey(0)
_, params = init_fn(key, INPUT_SHAPE)

After that, I get the following error:

---------------------------------------------------------------------------

RuntimeError                              Traceback (most recent call last)

<ipython-input-3-7bb0b75087b5> in <module>()
      4 INPUT_SHAPE = (-1, 3072)
      5 key = random.PRNGKey(0)
----> 6 _, params = init_fn(key, INPUT_SHAPE)

9 frames

/usr/local/lib/python3.6/dist-packages/neural_tangents/stax.py in ntk_init_fn(rng, input_shape)
    746     output_shape = (input_shape[:_channel_axis] + (out_dim,)
    747                     + input_shape[_channel_axis + 1:])
--> 748     rng1, rng2 = random.split(rng)
    749     W = random.normal(rng1, (input_shape[_channel_axis], out_dim))
    750 

/usr/local/lib/python3.6/dist-packages/jax/random.py in split(key, num)
    284     An array with shape (num, 2) and dtype uint32 representing `num` new keys.
    285   """
--> 286   return _split(key, int(num))  # type: ignore
    287 
    288 @partial(jit, static_argnums=(1,))

/usr/local/lib/python3.6/dist-packages/jax/api.py in f_jitted(*args, **kwargs)
    369         return cache_miss(*args, **kwargs)[0]  # probably won't return
    370     else:
--> 371       return cpp_jitted_f(*args, **kwargs)
    372   f_jitted._cpp_jitted_f = cpp_jitted_f
    373 

/usr/local/lib/python3.6/dist-packages/jax/api.py in cache_miss(*args, **kwargs)
    282         backend=backend,
    283         name=flat_fun.__name__,
--> 284         donated_invars=donated_invars)
    285     out_pytree_def = out_tree()
    286     out = tree_unflatten(out_pytree_def, out_flat)

/usr/local/lib/python3.6/dist-packages/jax/core.py in bind(self, fun, *args, **params)
   1187 
   1188   def bind(self, fun, *args, **params):
-> 1189     return call_bind(self, fun, *args, **params)
   1190 
   1191   def process(self, trace, fun, tracers, params):

/usr/local/lib/python3.6/dist-packages/jax/core.py in call_bind(primitive, fun, *args, **params)
   1178   tracers = map(top_trace.full_raise, args)
   1179   with maybe_new_sublevel(top_trace):
-> 1180     outs = primitive.process(top_trace, fun, tracers, params)
   1181   return map(full_lower, apply_todos(env_trace_todo(), outs))
   1182 

/usr/local/lib/python3.6/dist-packages/jax/core.py in process(self, trace, fun, tracers, params)
   1190 
   1191   def process(self, trace, fun, tracers, params):
-> 1192     return trace.process_call(self, fun, tracers, params)
   1193 
   1194   def post_process(self, trace, out_tracers, params):

/usr/local/lib/python3.6/dist-packages/jax/core.py in process_call(self, primitive, f, tracers, params)
    581 
    582   def process_call(self, primitive, f, tracers, params):
--> 583     return primitive.impl(f, *tracers, **params)
    584   process_map = process_call
    585 

/usr/local/lib/python3.6/dist-packages/jax/interpreters/xla.py in _xla_call_impl(fun, device, backend, name, donated_invars, *args)
    561                                *unsafe_map(arg_spec, args))
    562   try:
--> 563     return compiled_fun(*args)
    564   except FloatingPointError:
    565     assert FLAGS.jax_debug_nans  # compiled_fun can only raise in this case

/usr/local/lib/python3.6/dist-packages/jax/interpreters/xla.py in _execute_compiled(compiled, avals, handlers, *args)
    809   device, = compiled.local_devices()
    810   input_bufs = list(it.chain.from_iterable(device_put(x, device) for x in args if x is not token))
--> 811   out_bufs = compiled.execute(input_bufs)
    812   if FLAGS.jax_debug_nans: check_nans(xla_call_p, out_bufs)
    813   return [handler(*bs) for handler, bs in zip(handlers, _partition_outputs(avals, out_bufs))]

RuntimeError: CUDA operation failed: cudaGetErrorString symbol not found.

romanngg commented 4 years ago

Thanks for the full repro Colab! This seems to be more of a JAX issue with the CUDA 11 Colab runtime; the same error reproduces if you run just

from jax import random
a = random.split(random.PRNGKey(1), 2)

So I would suggest submitting this bug to https://github.com/google/jax
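When filing that bug, it may help to include the runtime's JAX/jaxlib versions and the devices JAX sees. A minimal diagnostic sketch using only standard JAX introspection (nothing specific to this issue; note these calls may themselves fail on a broken runtime):

import jax
from jaxlib import version as jaxlib_version

# Report which jax/jaxlib builds the runtime actually has installed.
print("jax:", jax.__version__)
print("jaxlib:", jaxlib_version.__version__)

# Confirm whether the GPU backend loaded and which devices are visible.
print("backend:", jax.lib.xla_bridge.get_backend().platform)
print("devices:", jax.devices())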

In the meantime, if CUDA 10 is OK, I believe just running

!pip install -q git+https://www.github.com/google/neural-tangents

should work.

maryamag85 commented 3 years ago

I have an issue running jax.random.PRNGKey(0), and I cannot downgrade CUDA to 10 because it is incompatible with my NVIDIA driver. What should I do in this case? Is there any workaround?

romanngg commented 3 years ago

Could you be using an older JAX version by any chance? JAX seems to support CUDA 11: https://github.com/google/jax#pip-installation
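For reference, the CUDA 11 wheels of that era were installed by pinning a jaxlib build from the jax-releases index. The version below is purely illustrative; the pip-installation link above is authoritative:

# Illustrative only: replace 0.1.57 / cuda110 with the jaxlib build matching
# your CUDA version, per https://github.com/google/jax#pip-installation.
!pip install --upgrade jax jaxlib==0.1.57+cuda110 -f https://storage.googleapis.com/jax-releases/jax_releases.html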

romanngg commented 3 years ago

This should no longer be an issue since Colab now has CUDA 11; see this working example: https://colab.research.google.com/gist/romanngg/6ad518bfafeeceda1e41bb54f043fbf9/no-jax-error-with-cuda-11.ipynb
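For completeness, a sketch of the original repro on a working CUDA 11 runtime, using the same stax.Dense functions as in the report (the batch size of 4 is arbitrary):

from jax import random
from neural_tangents import stax

# Same layer as in the original report.
init_fn, apply_fn, kernel_fn = stax.Dense(out_dim=1)

key = random.PRNGKey(0)
x = random.normal(key, (4, 3072))   # small batch of flattened inputs

# Finite-width initialization and forward pass.
_, params = init_fn(key, (-1, 3072))
y = apply_fn(params, x)             # shape (4, 1)

# Infinite-width NTK between the inputs.
ntk = kernel_fn(x, x, 'ntk')        # shape (4, 4)
print(y.shape, ntk.shape)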

Please feel free to reopen if there are still other problems!