google-deepmind / dqn_zoo

DQN Zoo is a collection of reference implementations of reinforcement learning agents developed at DeepMind based on the Deep Q-Network (DQN) agent.
Apache License 2.0

CUDA operation failed: device kernel image is invalid when running with gpu #9

Closed gridpower closed 3 years ago

gridpower commented 3 years ago

I get the following error when trying to run the run.sh script with GPU enabled:

./run.sh 

Successfully built b7ab61042bfc
Successfully tagged dqn_zoo:latest
Run DQN on GPU in a container named dqn_zoo_dqn
I1130 15:33:05.441911 139654470924096 run_atari.py:80] DQN on Atari on gpu.
I1130 15:33:07.016694 139654470924096 run_atari.py:103] Environment: pong
I1130 15:33:07.017373 139654470924096 run_atari.py:104] Action spec: DiscreteArray(shape=(), dtype=int32, name=action, minimum=0, maximum=5, num_values=6)
I1130 15:33:07.018801 139654470924096 run_atari.py:105] Observation spec: (Array(shape=(210, 160, 3), dtype=dtype('uint8'), name='rgb'), Array(shape=(), dtype=dtype('int32'), name='lives'))
Traceback (most recent call last):
  File "/usr/lib/python3.6/runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "/usr/lib/python3.6/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/workspace/dqn_zoo/dqn/run_atari.py", line 255, in <module>
    app.run(main)
  File "/usr/local/lib/python3.6/dist-packages/absl/app.py", line 299, in run
    _run_main(main, args)
  File "/usr/local/lib/python3.6/dist-packages/absl/app.py", line 250, in _run_main
    sys.exit(main(argv))
  File "/workspace/dqn_zoo/dqn/run_atari.py", line 172, in main
    train_rng_key, eval_rng_key = jax.random.split(rng_key)
  File "/usr/local/lib/python3.6/dist-packages/jax/random.py", line 267, in split
    return _split(key, num)
  File "/usr/local/lib/python3.6/dist-packages/jax/api.py", line 170, in f_jitted
    name=flat_fun.__name__, donated_invars=donated_invars)
  File "/usr/local/lib/python3.6/dist-packages/jax/core.py", line 1100, in call_bind
    outs = primitive.impl(fun, *args, **params)
  File "/usr/local/lib/python3.6/dist-packages/jax/interpreters/xla.py", line 544, in _xla_call_impl
    return compiled_fun(*args)
  File "/usr/local/lib/python3.6/dist-packages/jax/interpreters/xla.py", line 775, in _execute_compiled
    out_bufs = compiled.execute(input_bufs)
RuntimeError: CUDA operation failed: device kernel image is invalid
Removing /tmp/dqn_zoo_20201130_173250_8Zxwwh

Things I've tried:

Mon Nov 30 15:37:27 2020
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 450.80.02    Driver Version: 450.80.02    CUDA Version: 11.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  GeForce GTX 960M    Off  | 00000000:01:00.0 Off |                  N/A |
| N/A   65C    P0    N/A /  N/A |    690MiB /  4043MiB |     13%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+

+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
+-----------------------------------------------------------------------------+

- attempted running the run script and also the Rainbow agent without the Docker container (for this I installed CUDA 10.1 and cuDNN 7.0 in addition to the other requirements), and got the same error:

python3 dqn_zoo/rainbow/run_atari.py

I1130 16:52:19.897150 140224573081408 run_atari.py:97] Rainbow on Atari on gpu.
I1130 16:52:20.733672 140224573081408 run_atari.py:120] Environment: pong
I1130 16:52:20.734061 140224573081408 run_atari.py:121] Action spec: DiscreteArray(shape=(), dtype=int32, name=action, minimum=0, maximum=5, num_values=6)
I1130 16:52:20.734654 140224573081408 run_atari.py:122] Observation spec: (Array(shape=(210, 160, 3), dtype=dtype('uint8'), name='rgb'), Array(shape=(), dtype=dtype('int32'), name='lives'))
Traceback (most recent call last):
  File "dqn_zoo/rainbow/run_atari.py", line 277, in <module>
    app.run(main)
  File "/home/user/.local/lib/python3.6/site-packages/absl/app.py", line 299, in run
    _run_main(main, args)
  File "/home/user/.local/lib/python3.6/site-packages/absl/app.py", line 250, in _run_main
    sys.exit(main(argv))
  File "dqn_zoo/rainbow/run_atari.py", line 193, in main
    train_rng_key, eval_rng_key = jax.random.split(rng_key)
  File "/home/user/.local/lib/python3.6/site-packages/jax/random.py", line 267, in split
    return _split(key, num)
  File "/home/user/.local/lib/python3.6/site-packages/jax/api.py", line 170, in f_jitted
    name=flat_fun.__name__, donated_invars=donated_invars)
  File "/home/user/.local/lib/python3.6/site-packages/jax/core.py", line 1100, in call_bind
    outs = primitive.impl(fun, *args, **params)
  File "/home/user/.local/lib/python3.6/site-packages/jax/interpreters/xla.py", line 544, in _xla_call_impl
    return compiled_fun(*args)
  File "/home/user/.local/lib/python3.6/site-packages/jax/interpreters/xla.py", line 775, in _execute_compiled
    out_bufs = compiled.execute(input_bufs)
RuntimeError: CUDA operation failed: device kernel image is invalid


- confirmed that my NVIDIA driver and CUDA versions are compatible
- successfully ran a code sample from the CUDA toolkit (the deviceQuery utility sample) to confirm that the problem does not lie there

Do you have any idea why this is happening, or what I could try? Running on CPU works without a problem.

GeorgOstrovski commented 3 years ago

There seem to be similar or related issues that were reported e.g. here - in fact, in that issue the same GPU as yours (GTX 960) seems to have been affected.

I'd suggest checking whether you can run any JAX operations on GPU at all; if not, raise an issue with JAX.
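
A minimal sketch of such a check, assuming only that `jax` is installed, could look like the following. The helper name `check_jax_gpu` is made up for illustration and is not part of DQN Zoo or JAX. It repeats the `jax.random.split` call that fails in the traceback above and adds one small array operation.

```python
def check_jax_gpu():
    """Run one tiny JAX computation and report whether it succeeded.

    Hypothetical helper: returns a dict with keys "ok", "platform", "detail".
    """
    try:
        import jax
        import jax.numpy as jnp
    except ImportError as err:
        return {"ok": False, "platform": None,
                "detail": f"cannot import jax: {err}"}

    platform = None
    try:
        # Platform of the default device, e.g. "gpu" or "cpu".
        platform = jax.devices()[0].platform
        # The same call that fails in the traceback above.
        key_a, key_b = jax.random.split(jax.random.PRNGKey(0))
        # One small array op; float() forces the computation to finish.
        total = float((jnp.arange(8.0) * 2.0).sum())
    except Exception as err:  # e.g. the RuntimeError raised by XLA
        return {"ok": False, "platform": platform, "detail": repr(err)}
    return {"ok": True, "platform": platform, "detail": f"sum={total}"}


if __name__ == "__main__":
    print(check_jax_gpu())
```

If this reports a failure while the platform is the GPU, the problem most likely sits in the JAX/jaxlib installation rather than in DQN Zoo, and the output is a compact reproduction to attach to a JAX issue.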

gridpower commented 3 years ago

Indeed it seems like I can't run any JAX operations. Thanks for the help, I will check it with JAX further :)