openai / baselines

OpenAI Baselines: high-quality implementations of reinforcement learning algorithms
MIT License

Pretrained Atari agent fails with TensorFlow error #200

Open · smc77 opened this issue 6 years ago

smc77 commented 6 years ago

I get a TensorFlow error when trying to load any of the saved models, following the example in the DQN README:

python -m baselines.deepq.experiments.atari.download_model --blob model-atari-duel-pong-1 --model-dir /tmp/models
python -m baselines.deepq.experiments.atari.enjoy --model-dir /tmp/models/model-atari-duel-pong-1 --env Pong --dueling

To get the example to run at all, I first had to change references to a deprecated function (commit from PR #199): https://github.com/openai/baselines/pull/199/commits/c5c4b2847f2ba64e1ae840ad70edcca916055f96

Running it then fails with: Assign requires shapes of both tensors to match. lhs shape= [8,8,1,32] rhs shape= [8,8,4,32]
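Reading the shapes: TensorFlow lays out `conv2d` kernels as `[filter_height, filter_width, in_channels, out_channels]`. In the failing `Assign` node, `lhs` is the variable in the freshly built graph (`deepq/q_func/convnet/Conv/weights`) and `rhs` is the tensor read from the checkpoint, so the graph's first conv layer expects 1 input channel while the pretrained weights were trained with 4. The sketch below is a minimal pre-restore check, assuming you have already collected `{name: shape}` maps for both sides (newer TensorFlow releases expose `tf.train.list_variables` for the checkpoint side). The helper `find_shape_mismatches` is hypothetical and not part of baselines.

```python
from typing import Dict, List, Sequence, Tuple

# TensorFlow's conv2d kernel layout: [filter_height, filter_width, in_channels, out_channels]
CONV2D_KERNEL_AXES = ("filter_height", "filter_width", "in_channels", "out_channels")


def find_shape_mismatches(
    graph_shapes: Dict[str, Sequence[int]],
    checkpoint_shapes: Dict[str, Sequence[int]],
) -> List[Tuple[str, str, int, int]]:
    """Return (variable, axis, graph_dim, checkpoint_dim) for every differing axis.

    Hypothetical helper: only variables present on both sides are compared.
    """
    mismatches = []
    for name in sorted(graph_shapes.keys() & checkpoint_shapes.keys()):
        g, c = list(graph_shapes[name]), list(checkpoint_shapes[name])
        if g == c:
            continue
        if len(g) != len(c):
            # Different rank: report the ranks instead of comparing axis by axis.
            mismatches.append((name, "rank", len(g), len(c)))
            continue
        # Name the axes only for 4-D tensors, assuming those are conv2d kernels.
        if len(g) == 4:
            axes = CONV2D_KERNEL_AXES
        else:
            axes = tuple(f"axis{i}" for i in range(len(g)))
        for axis, gd, cd in zip(axes, g, c):
            if gd != cd:
                mismatches.append((name, axis, gd, cd))
    return mismatches


# Shapes copied from the error message in this issue.
print(find_shape_mismatches(
    {"deepq/q_func/convnet/Conv/weights": [8, 8, 1, 32]},
    {"deepq/q_func/convnet/Conv/weights": [8, 8, 4, 32]},
))
# [('deepq/q_func/convnet/Conv/weights', 'in_channels', 1, 4)]
```

Running a check like this before `saver.restore` turns the opaque `Assign` failure into a direct statement that the observation fed to the network has the wrong number of channels.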

$ python -m baselines.deepq.experiments.atari.enjoy --model-dir /tmp/models/model-atari-duel-pong-1 --env Pong --dueling
Traceback (most recent call last):
  File "/Users/smc77/anaconda/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1022, in _do_call
    return fn(*args)
  File "/Users/smc77/anaconda/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1004, in _run_fn
    status, run_metadata)
  File "/Users/smc77/anaconda/lib/python3.6/contextlib.py", line 89, in __exit__
    next(self.gen)
  File "/Users/smc77/anaconda/lib/python3.6/site-packages/tensorflow/python/framework/errors_impl.py", line 469, in raise_exception_on_not_ok_status
    pywrap_tensorflow.TF_GetCode(status))
tensorflow.python.framework.errors_impl.InvalidArgumentError: Assign requires shapes of both tensors to match. lhs shape= [8,8,1,32] rhs shape= [8,8,4,32]
     [[Node: save/Assign_6 = Assign[T=DT_FLOAT, _class=["loc:@deepq/q_func/convnet/Conv/weights"], use_locking=true, validate_shape=true, _device="/job:localhost/replica:0/task:0/cpu:0"](deepq/q_func/convnet/Conv/weights, save/RestoreV2_6)]]

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/Users/smc77/anaconda/lib/python3.6/runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "/Users/smc77/anaconda/lib/python3.6/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/Users/smc77/src/openai/baselines/baselines/deepq/experiments/atari/enjoy.py", line 69, in <module>
    U.load_state(os.path.join(args.model_dir, "saved"))
  File "/Users/smc77/src/openai/baselines/baselines/common/tf_util.py", line 241, in load_state
    saver.restore(get_session(), fname)
  File "/Users/smc77/anaconda/lib/python3.6/site-packages/tensorflow/python/training/saver.py", line 1439, in restore
    {self.saver_def.filename_tensor_name: save_path})
  File "/Users/smc77/anaconda/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 767, in run
    run_metadata_ptr)
  File "/Users/smc77/anaconda/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 965, in _run
    feed_dict_string, options, run_metadata)
  File "/Users/smc77/anaconda/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1015, in _do_run
    target_list, options, run_metadata)
  File "/Users/smc77/anaconda/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1035, in _do_call
    raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.InvalidArgumentError: Assign requires shapes of both tensors to match. lhs shape= [8,8,1,32] rhs shape= [8,8,4,32]
     [[Node: save/Assign_6 = Assign[T=DT_FLOAT, _class=["loc:@deepq/q_func/convnet/Conv/weights"], use_locking=true, validate_shape=true, _device="/job:localhost/replica:0/task:0/cpu:0"](deepq/q_func/convnet/Conv/weights, save/RestoreV2_6)]]

Caused by op 'save/Assign_6', defined at:
  File "/Users/smc77/anaconda/lib/python3.6/runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "/Users/smc77/anaconda/lib/python3.6/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/Users/smc77/src/openai/baselines/baselines/deepq/experiments/atari/enjoy.py", line 69, in <module>
    U.load_state(os.path.join(args.model_dir, "saved"))
  File "/Users/smc77/src/openai/baselines/baselines/common/tf_util.py", line 240, in load_state
    saver = tf.train.Saver()
  File "/Users/smc77/anaconda/lib/python3.6/site-packages/tensorflow/python/training/saver.py", line 1051, in __init__
    self.build()
  File "/Users/smc77/anaconda/lib/python3.6/site-packages/tensorflow/python/training/saver.py", line 1081, in build
    restore_sequentially=self._restore_sequentially)
  File "/Users/smc77/anaconda/lib/python3.6/site-packages/tensorflow/python/training/saver.py", line 675, in build
    restore_sequentially, reshape)
  File "/Users/smc77/anaconda/lib/python3.6/site-packages/tensorflow/python/training/saver.py", line 414, in _AddRestoreOps
    assign_ops.append(saveable.restore(tensors, shapes))
  File "/Users/smc77/anaconda/lib/python3.6/site-packages/tensorflow/python/training/saver.py", line 155, in restore
    self.op.get_shape().is_fully_defined())
  File "/Users/smc77/anaconda/lib/python3.6/site-packages/tensorflow/python/ops/gen_state_ops.py", line 47, in assign
    use_locking=use_locking, name=name)
  File "/Users/smc77/anaconda/lib/python3.6/site-packages/tensorflow/python/framework/op_def_library.py", line 763, in apply_op
    op_def=op_def)
  File "/Users/smc77/anaconda/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 2395, in create_op
    original_op=self._default_original_op, op_def=op_def)
  File "/Users/smc77/anaconda/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 1264, in __init__
    self._traceback = _extract_stack()

InvalidArgumentError (see above for traceback): Assign requires shapes of both tensors to match. lhs shape= [8,8,1,32] rhs shape= [8,8,4,32]
     [[Node: save/Assign_6 = Assign[T=DT_FLOAT, _class=["loc:@deepq/q_func/convnet/Conv/weights"], use_locking=true, validate_shape=true, _device="/job:localhost/replica:0/task:0/cpu:0"](deepq/q_func/convnet/Conv/weights, save/RestoreV2_6)]]
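A first conv layer with 4 input channels (8x8 kernels, 32 filters) matches the standard DQN Atari setup, where the last 4 grayscale 84x84 frames are stacked along the channel axis. A 1-channel graph therefore suggests that the environment built by `enjoy.py` is not stacking frames, possibly as a side effect of the deprecated-function change described above, though that is an unconfirmed guess. The following is a standalone illustration of frame stacking with NumPy, not baselines' own wrapper; the class name `FrameStacker` and the 84x84x1 frame shape are assumptions.

```python
from collections import deque

import numpy as np


class FrameStacker:
    """Keep the last k frames and expose them as one (H, W, k) array."""

    def __init__(self, k: int = 4):
        self.k = k
        self.frames = deque(maxlen=k)

    def reset(self, first_frame: np.ndarray) -> np.ndarray:
        # Fill the buffer with copies of the first frame so the stack is full from step 0.
        for _ in range(self.k):
            self.frames.append(first_frame)
        return self.observation()

    def step(self, frame: np.ndarray) -> np.ndarray:
        # deque(maxlen=k) drops the oldest frame automatically.
        self.frames.append(frame)
        return self.observation()

    def observation(self) -> np.ndarray:
        # Each frame is (H, W, 1); concatenating on the last axis gives (H, W, k),
        # which is what a first conv kernel of shape [8, 8, k, 32] expects.
        return np.concatenate(list(self.frames), axis=-1)


stacker = FrameStacker(k=4)
obs = stacker.reset(np.zeros((84, 84, 1), dtype=np.uint8))
print(obs.shape)  # (84, 84, 4)
obs = stacker.step(np.ones((84, 84, 1), dtype=np.uint8))
print(obs[..., -1].max(), obs[..., 0].max())  # 1 0
```

Feeding single (84, 84, 1) frames instead would build the `[8, 8, 1, 32]` kernel seen on the `lhs` side of the error.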

BNSneha commented 6 years ago

Same issue here.