Open joeryjoery opened 1 year ago
Another quick fix that I use is to wrap the reverb dataset inside the following class:
```python
class RefreshIterator:
    """tf.data.Dataset fix for slow reverb client synchronization. Wrap around reverb-dataset."""

    __slots__ = ["_iterable"]

    def __init__(self, iterable):
        self._iterable = iterable

    def __iter__(self):
        return self

    def __next__(self):
        # Create a fresh iterator on every call so that priority mutations in the
        # reverb table are picked up immediately instead of being hidden by the
        # stale prefetch buffer of a long-lived iterator.
        return next(iter(self._iterable))

    def next(self):
        return self.__next__()
```
Use:
```python
dataset = datasets.make_reverb_dataset(
    table=my_table.name, server_address=reverb_client.server_address, batch_size=..., ...
)
jax_dataset = utils.multi_device_put(_NumpyIterator(RefreshIterator(dataset)), ...)
```
Unfortunately, `_NumpyIterator` is a private class in `tf.dataset_ops`.
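A possible way around the private `_NumpyIterator` (a sketch only, not part of the fix above) is to do the tensor-to-numpy conversion with public TensorFlow APIs instead:

```python
import tensorflow as tf

def as_numpy(iterator):
    """Yield numpy views of the (possibly nested) tensors produced by `iterator`."""
    for element in iterator:
        yield tf.nest.map_structure(lambda t: t.numpy(), element)

# e.g. utils.multi_device_put(as_numpy(RefreshIterator(dataset)), ...)
```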
Hi, I accidentally stumbled upon a problem within the tutorial notebook when playing around with the acme and reverb API that causes weird synchronization behaviour between sampling from the `reverb` table and updating priorities. Another artifact of this that I encountered is that the very first transition would be consistently repeated until some hidden tensorflow buffer was flushed.

What I found is that when I mutate the priorities in a `reverb` table using `client.mutate_priorities(table_name, my_dict)` and then create an iterator from the `tf.data.Dataset` object, the priorities only update after flushing a large number of samples. In contrast, if I don't convert the `tf.data.Dataset` to an iterator and use the `dataset.batch(n); dataset.take(n)` interface, it immediately syncs with the new priorities.

It seems to me that the problem lies with the implementation of `__iter__` in `tf.data.Dataset`, but I posted this issue here since the Colab makes a call to `as_numpy_iterator()` on the dataset object, and this is also the implementation of the D4PG jax agent. Since this is a silent and obscure bug, it effectively eliminates the possibility of changing the baseline D4PG agent to utilize Prioritized Experience Replay.

Minimal reproducible example:
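Roughly, the setup looks like this (a minimal sketch; the table name, sizes, and `reverb.TimestepDataset` arguments are illustrative rather than the notebook's actual values):

```python
import numpy as np
import reverb
import tensorflow as tf

TABLE = 'priority_table'  # illustrative name

server = reverb.Server(tables=[
    reverb.Table(
        name=TABLE,
        sampler=reverb.selectors.Prioritized(priority_exponent=1.0),
        remover=reverb.selectors.Fifo(),
        max_size=100,
        rate_limiter=reverb.rate_limiters.MinSize(1),
    ),
])
client = reverb.Client(f'localhost:{server.port}')

# Insert a handful of single-step items, all with equal priority.
for i in range(5):
    client.insert(np.float32(i), priorities={TABLE: 1.0})

dataset = reverb.TimestepDataset(
    server_address=f'localhost:{server.port}',
    table=TABLE,
    dtypes=tf.float32,
    shapes=tf.TensorShape([]),
    max_in_flight_samples_per_worker=100,
)

# Sample through ONE long-lived iterator, as `as_numpy_iterator()` effectively does.
iterator = iter(dataset)
first = next(iterator)
print('first key:', first.info.key.numpy())

# Give essentially all priority mass to the first sampled key ...
client.mutate_priorities(TABLE, updates={int(first.info.key.numpy()): 1e9})

# ... yet the same iterator keeps yielding items drawn under the OLD priorities,
# because many samples were already prefetched before the mutation.
for _ in range(10):
    print('sampled key:', next(iterator).info.key.numpy())
```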
Output: (the new priorities are only reflected after a large number of samples has been flushed).
Proposed Solution
The problem is immediately solved if `iter(dataset)` is called at each call to `next`. Because of this, I wasn't sure whether to post this issue here or on the tensorflow github, since the problem is with `tf.data.Dataset`. Personally I would suggest creating a wrapper around `tf.data.Dataset` that either makes use of the `take` and `batch` API, or reinitializes the `iter` at every call. Because of how `reverb` implements sampling, reinitializing the dataset iterator should have no side-effects.

Example solution:
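For instance (a minimal sketch in the spirit of the above, reusing `dataset`, `client`, and `TABLE` from the reproduction sketch; not the exact notebook code):

```python
def refreshed_samples(dataset):
    """Yield samples from `dataset`, re-creating the iterator for every sample.

    Re-initializing the iterator only drops the stale prefetch buffer; because
    reverb performs the actual sampling server-side, this has no other side-effects.
    """
    while True:
        yield next(iter(dataset))


samples = refreshed_samples(dataset)
key = int(next(samples).info.key.numpy())
client.mutate_priorities(TABLE, updates={key: 1e9})
# The very next sample already reflects the new priorities.
print('sampled key:', next(samples).info.key.numpy())
```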
Output: (priorities are updated after every call, which is what we expected).