google-deepmind / reverb

Reverb is an efficient and easy-to-use data storage and transport system designed for machine learning research
Apache License 2.0

PER sample with episode for TFClient #69

Open cmarlin opened 2 years ago

cmarlin commented 2 years ago

Hello, I can't find how to store episode data with a priority per position (step) using the TFClient. Could you provide such an example? It would be valuable for PER (prioritized experience replay) with the MuZero algorithm. Thanks a lot!

fastturtle commented 2 years ago

Have you seen the documentation for TFClient.insert()? This should be analogous to the example of Client.insert(). Let us know if this is not clear so we can update it.

Additionally, are you using the TFClient instead of the Client and TrajectoryWriter for performance reasons? It is often simpler to use the TrajectoryWriter where possible.

cmarlin commented 2 years ago

Yes, I'm trying to use TFClient for performance, as TrajectoryWriter is really slow. For performance reasons I also use batched environments and a batched agent policy, which makes things complex. Here is a simplified version of my observer:

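```python
import tensorflow as tf
from tf_agents.trajectories import trajectory


class ReverbObserver:
  """Buffers batched steps and inserts one item per step of each finished episode."""

  def __init__(self, reverb_tfclient, table_names, collect_data_spec,
               batch_size: int, max_episode_length: int):
    self._reverb_tfclient = reverb_tfclient
    self._table_names = tf.constant(table_names, dtype=tf.string)
    # TFClient.insert expects one float64 priority per table.
    self._priorities = tf.ones([len(table_names)], dtype=tf.float64)
    self._batch_size = batch_size
    flat_collect_data_spec = tf.nest.flatten(collect_data_spec)
    # One buffer per flattened field: [batch, time, *field_shape].
    self._writer = [
        tf.Variable(tf.zeros(
            [batch_size, max_episode_length] + s.shape.as_list(), s.dtype))
        for s in flat_collect_data_spec
    ]
    # Number of steps already buffered for each environment.
    self._writer_position = tf.Variable(tf.zeros([batch_size], dtype=tf.int32))

  def __call__(self, traj: trajectory.Trajectory):
    flat_traj = tf.nest.flatten(traj)
    # Append the current step of each environment at its write position.
    writer_indices2d = tf.stack(
        [tf.range(self._batch_size), self._writer_position], axis=-1)
    for writer_elt, traj_elt in zip(self._writer, flat_traj):
      writer_elt.assign(
          tf.tensor_scatter_nd_update(writer_elt, writer_indices2d, traj_elt))
    self._writer_position.assign(self._writer_position + 1)
    # For every finished episode, insert one item per step (priority 1.0 for now).
    finished = tf.cast(tf.where(traj.is_last())[:, 0], tf.int32)
    for traj_index in finished:
      # Every insert call sends the whole (padded) episode buffer.
      common_data = [episode[traj_index] for episode in self._writer]
      for step in tf.range(self._writer_position[traj_index]):
        self._reverb_tfclient.insert(
            data=common_data + [step],
            tables=self._table_names,
            priorities=self._priorities)
    # Restart the buffer of environments whose episode just ended.
    self._writer_position.assign(
        tf.where(traj.is_last(), 0, self._writer_position))
```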

Obviously, I still have issues at the moment (the sampling probability is per item, i.e. per step, not per episode, ...). But I would like to know whether `common_data` is shared internally by the client/Reverb, since every item stores the data of the whole episode. For PER I need a priority per step, so it would be helpful to add an efficient code sample for high throughput.
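To make the sampling issue concrete, here is a small pure-Python sketch of proportional prioritized sampling over per-step items, the scheme used by `reverb.selectors.Prioritized` (probability proportional to priority raised to `priority_exponent`). The helper names and the data are hypothetical. With one item per step, the chance of drawing a given episode is the sum of its steps' probabilities, so longer episodes are drawn more often even when every priority is equal.

```python
from typing import Dict, List, Tuple


def per_step_probabilities(
    episode_priorities: Dict[str, List[float]], alpha: float
) -> Dict[Tuple[str, int], float]:
  """P(item) = p ** alpha / sum_k(p_k ** alpha), over every step of every episode."""
  weights = {
      (episode, step): p ** alpha
      for episode, priorities in episode_priorities.items()
      for step, p in enumerate(priorities)
  }
  total = sum(weights.values())
  return {key: w / total for key, w in weights.items()}


def per_episode_probabilities(
    step_probs: Dict[Tuple[str, int], float]) -> Dict[str, float]:
  """Probability that a sampled item belongs to each episode."""
  totals: Dict[str, float] = {}
  for (episode, _), prob in step_probs.items():
    totals[episode] = totals.get(episode, 0.0) + prob
  return totals


# Hypothetical data: a short and a long episode, all priorities equal.
priorities = {'short': [1.0, 1.0], 'long': [1.0] * 6}
step_probs = per_step_probabilities(priorities, alpha=0.6)
print(per_episode_probabilities(step_probs))  # {'short': 0.25, 'long': 0.75}
```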

Thanks