google-deepmind / acme

A library of reinforcement learning components and agents
Apache License 2.0

Converting Tensorflow Dataset to `iterator` does not sync well with client #293

Open joeryjoery opened 1 year ago

joeryjoery commented 1 year ago

Hi, while playing around with the Acme and Reverb APIs in the tutorial notebook, I stumbled upon a problem that causes odd synchronization behaviour between sampling from a Reverb table and updating its priorities. A related symptom I encountered is that the very first transition is consistently repeated until some hidden TensorFlow buffer is flushed.

What I found is that when I mutate the priorities in a Reverb table with client.mutate_priorities(table_name, my_dict) and draw samples from an iterator created from the tf.data.Dataset object, the sampled priorities only reflect the update after a large number of samples have been flushed. In contrast, if I don't convert the tf.data.Dataset to an iterator and use the dataset.batch(n) / dataset.take(n) interface instead, the samples immediately reflect the new priorities.
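To make the suspected mechanism concrete, here is a toy, pure-Python model of it. It assumes (this is not verified against Reverb's or TensorFlow's internals) that a long-lived iterator keeps a fixed number of samples drawn ahead of consumption, for example through prefetching or through samples already requested from the server. `ToyTable` and `PrefetchingIterator` are hypothetical names, not Acme, Reverb or TensorFlow classes.

```python
from collections import deque


class ToyTable:
    """Hypothetical stand-in for a Reverb table holding a single item."""

    def __init__(self, priority: float = 1.0) -> None:
        self.priority = priority

    def sample(self) -> float:
        # A sample records the item's priority at the moment it is drawn.
        return self.priority


class PrefetchingIterator:
    """Hypothetical iterator that keeps `buffer_size` samples drawn ahead of time."""

    def __init__(self, table: ToyTable, buffer_size: int) -> None:
        self._table = table
        # The buffer is filled eagerly when the iterator is created.
        self._buffer = deque(table.sample() for _ in range(buffer_size))

    def __iter__(self) -> "PrefetchingIterator":
        return self

    def __next__(self) -> float:
        # Hand out the oldest buffered sample, then top the buffer up with a
        # new sample that reflects the table's current state.
        value = self._buffer.popleft()
        self._buffer.append(self._table.sample())
        return value


table = ToyTable(priority=1.0)
long_lived = PrefetchingIterator(table, buffer_size=4)
next(long_lived)

# Analogous to client.mutate_priorities(...): halve the item's priority.
table.priority = 0.5

# The long-lived iterator keeps returning stale priorities until its buffer is flushed.
stale = [next(long_lived) for _ in range(5)]

# A freshly created iterator (like the one `for s in dataset.take(1)` builds) sees the update at once.
fresh = next(PrefetchingIterator(table, buffer_size=4))
print(stale, fresh)  # [1.0, 1.0, 1.0, 1.0, 0.5] 0.5
```

In this model the lag is exactly `buffer_size` draws, whereas the flush-step in the real output below varies between repeats, so the actual buffering is presumably less regular; treat the model only as an illustration of the hypothesis.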

It seems to me that the problem lies in the implementation of __iter__ in tf.data.Dataset, but I am posting this issue here because the Colab calls as_numpy_iterator() on the dataset object, and the JAX D4PG agent does the same. Because this bug is silent and obscure, it effectively rules out switching the baseline D4PG agent to Prioritized Experience Replay.

Minimal reproducible example:

import warnings
warnings.filterwarnings('ignore')

import acme

from acme import wrappers
from acme.datasets import reverb as datasets
from acme.adders.reverb import sequence
from acme.jax import utils

import tree
import reverb
import jax

import numpy as np

from dm_control import suite

# Create a dummy environment with short (5-step) episodes so that samples are easy to tell apart
env = suite.load('cartpole', 'balance')
env = wrappers.step_limit.StepLimitWrapper(env, step_limit=5)
spec = acme.make_environment_spec(env)

# Warning: re-running this cell (creating the reverb.Table a second time) crashes the kernel
table = reverb.Table(
    name='priority_table',
    sampler=reverb.selectors.Prioritized(priority_exponent=0.8),
    remover=reverb.selectors.Fifo(),
    max_size=10_000,
    rate_limiter=reverb.rate_limiters.MinSize(1),
    signature=sequence.SequenceAdder.signature(spec)
)

server = reverb.Server([table], port=None)
client = reverb.Client(f'localhost:{server.port}')

# Construct adder such that only 1 sample is added to table after an episode.
adder = sequence.SequenceAdder(client, sequence_length=6, period=5)

def new_dataset():
    # Clear old data
    client.reset(table.name)
    return datasets.make_reverb_dataset(
        table=table.name, server_address=client.server_address, batch_size=3
    )

def fill_dataset():
    step = env.reset()
    adder.add_first(step)

    action = env.action_spec().generate_value()
    i = 0
    while (not step.last()) and i < 10:
        step = env.step(action)
        adder.add(action, step)
        i += 1

    env.close()
    adder.reset()

### Example of expected behaviour
dataset = new_dataset()
fill_dataset()

print('before mutation')
for s in dataset.take(1):
    k, p = s.info.key.numpy().ravel(), s.info.priority.numpy().ravel()

    print(s.data.action.numpy().reshape(3, -1))  # (B, T, 1) -> (B, T)
    print('sample priority:', p)

    # Iteratively halve the priorities
    new_priorities = dict(zip(k, p * 0.5))
    client.mutate_priorities(table.name, new_priorities)

print()

print('after mutation')
for s in dataset.take(1):
    # Priorities have been updated --> all probabilities should now be adjusted.

    print(s.data.action.numpy().reshape(3, -1))  # (B, T, 1) -> (B, T)
    print('sample priority:', s.info.priority.numpy())

### Test-cases

print('\nUsing dataset.take')
dataset = new_dataset()
fill_dataset()

# This runs fine
for repeat in range(5):
    for i in range(30): # Flush count guess
        for s in dataset.take(1):
            k, p = s.info.key.numpy().ravel(), s.info.priority.numpy().ravel()

            # Exponentially decay the priorities
            new_priorities = dict(zip(k, p * 0.999))
            client.mutate_priorities(table.name, new_priorities)

        for s in dataset.take(1):
            new_p = s.info.priority.numpy().ravel()
            assert not np.isclose(new_p, p).any(), "priorities did not update!"
    else:
        # No break in for loop
        print('No errors!')

print('\nUsing next on iter(dataset) - Problems start here.')
dataset = new_dataset()
fill_dataset()
it = iter(dataset)

# Repeat the test-loop as behaviour strangely changes periodically
for repeat in range(5):

    for i in range(30): # Flush count guess
        s = next(it)
        k, p = s.info.key.numpy().ravel(), s.info.priority.numpy().ravel()

        # Exponentially decay the priorities (factor 0.999)
        new_priorities = dict(zip(k, p * 0.999))
        client.mutate_priorities(table.name, new_priorities)

        s = next(it)
        new_p = s.info.priority.numpy().ravel()

        # Priority mutations now sync extremely slowly
        if not np.isclose(p, new_p).all():
            print(f'Priorities updated at flush-step {i}')
            break
    else:
        # No break in the for loop (not reached)
        print('No errors!')

Output:

before mutation
[[-1. -1. -1. -1. -1.  0.]
 [-1. -1. -1. -1. -1.  0.]
 [-1. -1. -1. -1. -1.  0.]]
sample priority: [1. 1. 1.]

after mutation
[[-1. -1. -1. -1. -1.  0.]
 [-1. -1. -1. -1. -1.  0.]
 [-1. -1. -1. -1. -1.  0.]]
sample priority: [0.5 0.5 0.5]

Using dataset.take
No errors!
No errors!
No errors!
No errors!
No errors!

Using next on iter(dataset) - Problems start here.
Priorities updated at flush-step 24
Priorities updated at flush-step 5
Priorities updated at flush-step 18
Priorities updated at flush-step 5
Priorities updated at flush-step 18

Proposed Solution

The problem is solved immediately if iter(dataset) is called anew before every call to next. Because of this, I wasn't sure whether to post this issue here or on the TensorFlow GitHub, since the problem lies with tf.data.Dataset. Personally, I would suggest creating a wrapper around tf.data.Dataset that either uses the take and batch API or reinitializes the iterator on every call. Because of how Reverb implements sampling, reinitializing the dataset iterator should have no side effects beyond the cost of creating a new iterator each time.

Example solution:


print('\nReinitializing iter on every next call - Problem Solved.')
dataset = new_dataset()
fill_dataset()
it = iter(dataset)  # Ignore this iterator

# Repeat the test-loop as behaviour strangely changes periodically
for repeat in range(5):

    for i in range(30): # Flush count guess
        s = next(iter(dataset))  # CHANGE: call iter(dataset) every time `next` is called
        k, p = s.info.key.numpy().ravel(), s.info.priority.numpy().ravel()

        # Exponentially decay the priorities (factor 0.999)
        new_priorities = dict(zip(k, p * 0.999))
        client.mutate_priorities(table.name, new_priorities)

        s = next(iter(dataset))  # CHANGE: call iter(dataset) every time `next` is called
        new_p = s.info.priority.numpy().ravel()

        # Priority mutations are now visible on the very next sample
        if not np.isclose(p, new_p).all():
            print(f'Priorities updated at flush-step {i}')
            break
    else:
        # No break in for loop : not reached
        print('No errors!')

Output (priorities are updated after every call, as expected):

Reinitializing iter on every next call - Problem Solved.
Priorities updated at flush-step 0
Priorities updated at flush-step 0
Priorities updated at flush-step 0
Priorities updated at flush-step 0
Priorities updated at flush-step 0

joeryjoery commented 1 year ago

Another quick fix I use is to wrap the Reverb dataset in the following class:

class RefreshIterator:
    """Work around slow Reverb client synchronization in tf.data.Dataset.

    Wraps a Reverb dataset and creates a fresh iterator on every `next` call.
    """

    __slots__ = ["_iterable"]

    def __init__(self, iterable):
        self._iterable = iterable

    def __iter__(self):
        return self

    def __next__(self):
        return next(iter(self._iterable))

    def next(self):
        return self.__next__()
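A caveat worth spelling out: `RefreshIterator` only refreshes anything if `iter(self._iterable)` returns a new iterator on every call. That holds for a `tf.data.Dataset`, but not for an iterator such as the one returned by `as_numpy_iterator()`, because calling `iter()` on an iterator returns the iterator itself. The pure-Python sketch below, which needs neither TensorFlow nor Reverb, checks this; `ResamplingIterable` is a hypothetical stand-in for a Reverb dataset whose fresh iterators always read the current table state.

```python
import itertools


class RefreshIterator:
    """Copy of the wrapper above: builds a fresh iterator on every `next` call."""

    __slots__ = ["_iterable"]

    def __init__(self, iterable):
        self._iterable = iterable

    def __iter__(self):
        return self

    def __next__(self):
        return next(iter(self._iterable))


class ResamplingIterable:
    """Hypothetical Reverb-like source: every new iterator reads the current state."""

    def __init__(self):
        self.priority = 1.0

    def __iter__(self):
        # Infinite generator; `self.priority` is read at draw time, not creation time.
        return (self.priority for _ in itertools.count())


# Intended use: wrapping an iterable whose __iter__ builds a new iterator.
source = ResamplingIterable()
refreshed = RefreshIterator(source)
first = next(refreshed)
source.priority = 0.5  # analogous to client.mutate_priorities(...)
second = next(refreshed)

# Pitfall 1: an iterator (e.g. what as_numpy_iterator() returns) is never refreshed,
# because iter() on an iterator returns that same iterator.
numbers = iter([1, 2, 3])
assert iter(numbers) is numbers
wrapped_iterator = RefreshIterator(numbers)
from_iterator = [next(wrapped_iterator) for _ in range(3)]

# Pitfall 2: a finite, deterministic iterable just repeats its first element.
wrapped_list = RefreshIterator([1, 2, 3])
from_list = [next(wrapped_list) for _ in range(3)]

print(first, second, from_iterator, from_list)  # 1.0 0.5 [1, 2, 3] [1, 1, 1]
```

This is why, in the usage that follows, `RefreshIterator` wraps the dataset itself and `_NumpyIterator` is applied outside it. It also means the wrapper only makes sense for sources that draw new samples on every fresh iterator, as a Reverb dataset does.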

Use:

dataset = datasets.make_reverb_dataset(
    table=my_table.name, server_address=reverb_client.server_address, batch_size=..., ...
)

jax_dataset = utils.multi_device_put(_NumpyIterator(RefreshIterator(dataset)), ...)

Unfortunately, _NumpyIterator is a private class in TensorFlow's dataset_ops module.
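If depending on a private class is a concern, something equivalent can likely be built from public API alone. The sketch below relies on `tf.data.Dataset.as_numpy_iterator()` starting a new iteration over the dataset each time it is called, so taking one element from a fresh NumPy iterator per step should mirror `_NumpyIterator(RefreshIterator(dataset))`. `numpy_refresh_iterator` is a hypothetical helper, and `FakeDataset` is a stand-in that only mimics that one method so the sketch runs without TensorFlow.

```python
from typing import Any, Iterator


def numpy_refresh_iterator(dataset: Any) -> Iterator[Any]:
    """Hypothetical helper: yield one element per step from a fresh NumPy iterator."""
    while True:
        # Each call builds a new iterator, so every step starts from the
        # current table state instead of a long-lived, prefetched one.
        yield next(dataset.as_numpy_iterator())


class FakeDataset:
    """Hypothetical stand-in exposing only `as_numpy_iterator()`."""

    def __init__(self) -> None:
        self.calls = 0

    def as_numpy_iterator(self) -> Iterator[int]:
        # Count how many fresh iterators have been requested.
        self.calls += 1
        return iter([self.calls])


fake = FakeDataset()
stream = numpy_refresh_iterator(fake)
values = [next(stream) for _ in range(3)]
print(values, fake.calls)  # [1, 2, 3] 3
```

With a real Reverb dataset this would be used as `jax_dataset = utils.multi_device_put(numpy_refresh_iterator(dataset), ...)`. Like `RefreshIterator`, it pays the cost of building a new iterator on every step.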