tensorflow / agents

TF-Agents: A reliable, scalable and easy to use TensorFlow library for Contextual Bandits and Reinforcement Learning.
Apache License 2.0

Better warnings for incorrect batch size in replay buffer #377

Open djl11 opened 4 years ago

djl11 commented 4 years ago

Here is a minimal example:

import tensorflow as tf
from tf_agents.replay_buffers import tf_uniform_replay_buffer

# Size of each batch passed to add_batch (does not match the buffer's batch_size=1).
batch_size = 8
counter = 0

data_spec = {'observations': tf.TensorSpec([3], tf.float32, 'obs'),
             'actions': tf.TensorSpec([3], tf.float32, 'actions')}

# The buffer is created with batch_size=1.
replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
    data_spec=data_spec, batch_size=1, max_length=1000)

# Keep adding batches of size 8 to the batch_size=1 buffer until the process crashes.
while True:
    print(counter)
    counter += 1
    replay_buffer.add_batch({'observations': tf.zeros((batch_size, 3)),
                             'actions': tf.zeros((batch_size, 3))})

This sometimes runs properly, and other times leads to:

Process finished with exit code 139 (interrupted by signal 11: SIGSEGV)

This is hard to debug, and the batch size mismatch could be hard to spot in larger project files. I think an error should be thrown if you try to add an incorrectly sized batch :)
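A caller-side guard is one way to catch the mismatch before it reaches the buffer. The following is a minimal sketch, not part of TF-Agents: `check_outer_batch_dim` is a hypothetical helper name, and it assumes every leaf in the batch exposes a `.shape` whose first entry is the batch dimension (true for eager `tf.Tensor` and NumPy arrays).

```python
def check_outer_batch_dim(items, expected_batch_size, path="items"):
    """Raise ValueError if any leaf's leading dimension != expected_batch_size.

    Hypothetical helper (not a TF-Agents API). `items` may be a tensor-like
    object with a `.shape`, or a nested dict/list/tuple of such objects.
    """
    if isinstance(items, dict):
        for key, value in items.items():
            check_outer_batch_dim(value, expected_batch_size, f"{path}[{key!r}]")
    elif isinstance(items, (list, tuple)):
        for index, value in enumerate(items):
            check_outer_batch_dim(value, expected_batch_size, f"{path}[{index}]")
    else:
        # Leaf: compare the leading (outer) dimension with the expected batch size.
        # A scalar (empty shape) or an unknown (None) leading dimension also fails.
        shape = tuple(items.shape)
        if len(shape) == 0 or shape[0] != expected_batch_size:
            raise ValueError(
                f"{path} has shape {shape}, but the replay buffer expects "
                f"an outer batch dimension of {expected_batch_size}.")
```

In the example above, calling `check_outer_batch_dim(batch, 1)` on the dict just before passing it to `replay_buffer.add_batch(batch)` would raise a `ValueError` naming the offending entry, because the leading dimension is 8 while the buffer was built with `batch_size=1`.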

leroidauphin commented 4 years ago

This looks like a really interesting problem, so I attempted to reproduce it to see if it could be fixed. I can reproduce this in tf-agents 0.5.0, exactly as you describe, but I cannot reproduce this in the nightly build. Are you still seeing this SIGSEGV when using the nightly build?