microsoft / FQF

FQF (Fully parameterized Quantile Function for distributional reinforcement learning) is a general reinforcement learning framework for Atari games. It learns to play them automatically by predicting the return distribution in the form of a fully parameterized quantile function.

tf.gather_nd error #5

Open julio-cmdr opened 2 years ago

julio-cmdr commented 2 years ago

Hi!

I'm trying to run FQF using the script run-fqf.sh, but I'm getting an error that I couldn't resolve. It only happens when the agent starts training.

I'm running the code on CPU, not GPU. Could that be the problem?

Thanks for your attention!

File "train.py", line 65, in <module>
    app.run(main)
[elided 14 identical lines from previous traceback]
File "../../dopamine/agents/dqn/dqn_agent.py", line 205, in __init__
    self._train_op = self._build_train_op()
File "../../dopamine/agents/fqf/fqf_agent.py", line 377, in _build_train_op
    chosen_action_L_tau = tf.gather_nd(self._replay_net_outputs.L_tau, reshaped_actions)
File "/home/julio/.local/lib/python3.7/site-packages/tensorflow/python/ops/gen_array_ops.py", line 3647, in gather_nd
    "GatherNd", params=params, indices=indices, name=name)
File "/home/julio/.local/lib/python3.7/site-packages/tensorflow/python/framework/op_def_library.py", line 788, in _apply_op_helper
    op_def=op_def)
File "/home/julio/.local/lib/python3.7/site-packages/tensorflow/python/util/deprecation.py", line 507, in new_func
    return func(*args, **kwargs)
File "/home/julio/.local/lib/python3.7/site-packages/tensorflow/python/framework/ops.py", line 3300, in create_op
    op_def=op_def)
File "/home/julio/.local/lib/python3.7/site-packages/tensorflow/python/framework/ops.py", line 1801, in __init__
    self._traceback = tf_stack.extract_stack()

InvalidArgumentError (see above for traceback): indices[31] = [31, 1] does not index into shape [31,32,9]
     [[node gradients_2/GatherNd_3_grad/ScatterNd (defined at ../../dopamine/agents/fqf/fqf_agent.py:410) ### ]]
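
Reading the final error line: each index row has two components, so it addresses the first two axes of a tensor of shape [31, 32, 9]. The first component of row 31 is 31, but an axis of size 31 only accepts 0 through 30. The TensorFlow documentation for `tf.gather_nd` and `tf.scatter_nd` notes that an out-of-bounds index raises an error on CPU, while on GPU it is not reported (a 0 is stored, or the index is ignored). That may be why the failure only shows up on a CPU run, though it has not been confirmed for this code. The sketch below is plain Python and does not use the repository's code. The helper `out_of_range_rows` is hypothetical, and the index list is an assumption built to match the one row quoted in the error message.

```python
from typing import List, Sequence


def out_of_range_rows(shape: Sequence[int],
                      indices: Sequence[Sequence[int]]) -> List[int]:
    """Return the positions of index rows that do not fit into `shape`.

    Each row addresses the leading len(row) axes of `shape`, which is how
    gather_nd-style indexing interprets its indices.
    """
    bad = []
    for pos, row in enumerate(indices):
        if len(row) > len(shape):
            raise ValueError("index row has more components than shape has axes")
        # A component is valid only if 0 <= component < size of its axis.
        if any(i < 0 or i >= dim for i, dim in zip(row, shape)):
            bad.append(pos)
    return bad


# Shape taken from the error message.
params_shape = [31, 32, 9]

# ASSUMPTION: 32 rows of the form [row_number, 1]. Only row 31 = [31, 1]
# is actually known from the error message; the rest are illustrative.
indices = [[b, 1] for b in range(32)]

print(out_of_range_rows(params_shape, indices))  # [31]
```

If the first axis had size 32 instead of 31, the same indices would all be valid, so comparing the shape of `self._replay_net_outputs.L_tau` with the way `reshaped_actions` is built is a reasonable first check.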