google-deepmind / acme

A library of reinforcement learning components and agents
Apache License 2.0

Use of tf.data.Dataset within tf.function #298

Closed. San-Holo closed this issue 1 year ago.

San-Holo commented 1 year ago

Hey,

First of all, thanks for this repo, which I find to be an amazing contribution to the RL community!

Recently, my code started raising an error that it never raised before. The agent I use is based on the provided distributed MPO agent example, modified to work with sequences and the corresponding sequence adder.

The environment I use is relatively slow (about 1 s per iteration and 5 s for a complete reset), so I rely on Reverb's wait mechanism to make sure the learner blocks until enough experience is available in the buffer. I intended this to be handled by the self._min_replay_size value passed as min_size_to_sample to the SampleToInsertRatio limiter I use:

limiter = reverb.rate_limiters.SampleToInsertRatio(
    min_size_to_sample=self._min_replay_size,
    samples_per_insert=self._samples_per_insert,
    error_buffer=error_buffer)
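To make the intended gating concrete, here is a simplified, pure-Python model of a samples-per-insert rate limiter. It is a toy, not Reverb's implementation: it ignores deletions and the server's real blocking machinery, and ToySampleToInsertRatio is a hypothetical name. Until min_size_to_sample items have been inserted, sampling is not allowed; afterwards, samples may run ahead of or behind samples_per_insert per insert only within error_buffer. In Reverb, a sample request that the limiter does not allow waits rather than failing, which is the behaviour relied on above.

```python
class ToySampleToInsertRatio:
    """Toy model of a samples-per-insert rate limiter (not Reverb's code)."""

    def __init__(self, min_size_to_sample: int, samples_per_insert: float,
                 error_buffer: float):
        self.min_size_to_sample = min_size_to_sample
        self.samples_per_insert = samples_per_insert
        # Allowed band for (inserts * samples_per_insert - samples), centred on
        # the value reached right after the first min_size_to_sample inserts.
        offset = samples_per_insert * min_size_to_sample
        self.min_diff = offset - error_buffer
        self.max_diff = offset + error_buffer
        self.inserts = 0
        self.samples = 0

    def _diff(self) -> float:
        return self.inserts * self.samples_per_insert - self.samples

    def can_sample(self) -> bool:
        # No sampling at all until the warm-up size has been reached.
        if self.inserts < self.min_size_to_sample:
            return False
        # Taking one more sample must keep the difference inside the band.
        return self._diff() - 1 >= self.min_diff

    def can_insert(self) -> bool:
        # Inserts are always allowed during warm-up.
        if self.inserts < self.min_size_to_sample:
            return True
        # One more insert must keep the difference inside the band.
        return self._diff() + self.samples_per_insert <= self.max_diff

    def sample(self) -> None:
        # A real server would make the caller wait here; the toy raises instead.
        if not self.can_sample():
            raise RuntimeError("sample would block")
        self.samples += 1

    def insert(self) -> None:
        if not self.can_insert():
            raise RuntimeError("insert would block")
        self.inserts += 1
```

With min_size_to_sample=3, samples_per_insert=2 and error_buffer=4, can_sample() is False on an empty table, becomes True after three inserts, and then allows four samples before another insert is needed.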

make_reverb_dataset(...) is then used to create the Dataset consumed by the learner, reading from the Reverb server that hosts the table with this limiter. In my code it is written as follows, with self._batch_size and self._prefetch_size set in __init__():

dataset = datasets.make_reverb_dataset(
    server_address=replay.server_address,
    batch_size=self._batch_size,
    prefetch_size=self._prefetch_size,
    num_parallel_calls=tf.data.AUTOTUNE)

Once in graph mode, on subsequent calls to the learner's _step() function, fetching the next element from the dataset iterator fails with the following traceback:

Detected at node 'IteratorGetNext' defined at (most recent call last):
    File "/Applications/software/anaconda/python37/lib/python3.7/threading.py", line 890, in _bootstrap
      self._bootstrap_inner()
    File "/Applications/software/anaconda/python37/lib/python3.7/threading.py", line 926, in _bootstrap_inner
      self.run()
    File "/Applications/software/anaconda/python37/lib/python3.7/threading.py", line 870, in run
      self._target(*self._args, **self._kwargs)
    File "/home/SK268679/PhD/venvs/vs/lib/python3.7/site-packages/launchpad/launch/worker_manager.py", line 238, in run_inner
      future.set_result(f())
    File "/home/SK268679/PhD/venvs/vs/lib/python3.7/site-packages/launchpad/nodes/python/node.py", line 75, in _construct_function
      return functools.partial(self._function, *args, **kwargs)()
    File "/home/SK268679/PhD/venvs/vs/lib/python3.7/site-packages/launchpad/nodes/courier/node.py", line 130, in run
      instance.run()
    File "/home/SK268679/PhD/venvs/vs/lib/python3.7/site-packages/acme/core.py", line 161, in run
      self.step()
    File "../agents/learning.py", line 688, in step
      fetches = self._step()
    File "../agents/learning.py", line 287, in _step
      inputs = iter(self._dataset).get_next()
Node: 'IteratorGetNext'
2 root error(s) found.
  (0) OUT_OF_RANGE:  End of sequence
     [[{{node IteratorGetNext}}]]
     [[IteratorGetNext/_34]]
  (1) OUT_OF_RANGE:  End of sequence
     [[{{node IteratorGetNext}}]]
0 successful operations.
0 derived errors ignored. [Op:__inference__step_14999]

Following the advice in #152, https://github.com/deepmind/reverb/issues/70 and #293 (even though it concerns a JAX setup, while mine uses TensorFlow), I still get the same error. The traceback already reflects what those issues recommend: the tf.data.Dataset object itself is passed to the tf.function, instead of the tf.data.Iterator obtained from iter() as in the MPO template example.

I checked whether data is being sent to the replay buffer, and everything seems fine on that front. Not enough elements are available to the learner when next() is called, which seems logical since the environment needs time to fill in even a single sequence (at most 5 iterations have run when the failing line is reached). Hence, I suspect the error comes either from the tf.function not honoring the wait operation, or from a misunderstanding of the Reverb table setup on my part.
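The two hypotheses hinge on whether the consumer waits indefinitely or gives up. A toy, stdlib-only model makes the distinction concrete (this is not Reverb code; ToyBlockingSource and timeout_s are hypothetical names). With an unbounded wait, next() simply blocks until a producer inserts something; with a bounded wait that expires, the stream ends, which is the pure-Python counterpart of an "End of sequence" error. Reverb's dataset classes accept a rate_limiter_timeout_ms argument that plays a similar role, but this thread does not establish whether it is involved here.

```python
import threading
from collections import deque


class ToyBlockingSource:
    """Toy stand-in for a replay table that a consumer iterates over."""

    def __init__(self, timeout_s=None):
        # timeout_s=None: wait forever for data (unbounded wait).
        # timeout_s=<float>: give up after that many seconds (bounded wait).
        self._items = deque()
        self._cond = threading.Condition()
        self._timeout_s = timeout_s

    def insert(self, item):
        # Producer side: add an item and wake up a waiting consumer.
        with self._cond:
            self._items.append(item)
            self._cond.notify()

    def __iter__(self):
        return self

    def __next__(self):
        with self._cond:
            # wait_for returns the predicate's last value; False means it timed out.
            ready = self._cond.wait_for(lambda: len(self._items) > 0,
                                        timeout=self._timeout_s)
            if not ready:
                # An expired bounded wait ends the stream (analogous to End of sequence).
                raise StopIteration
            return self._items.popleft()
```

For example, ToyBlockingSource(timeout_s=0.05) raises StopIteration on an empty source, whereas ToyBlockingSource() returns the first item as soon as another thread inserts it.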

Did I miss something regarding buffer control and the min_replay_size parameter? Or is it indeed related to the tf.function?

Thanks in advance!

San-Holo commented 1 year ago

Hey,

I found a solution by empirically searching for suitable parameters when creating the TensorFlow Dataset, and by checking for data availability before calling next(). Even though it does not rely on min_replay_size, this quick fix works like a charm!
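A minimal sketch of "checking for data availability before calling next()", assuming the check runs in plain Python, outside the tf.function, before each learner step. Here get_size is a hypothetical callable that returns the current number of items in the replay table; with Reverb it could plausibly be backed by reverb.Client(address).server_info()[table_name].current_size, but treat that wiring as an assumption to verify against your Reverb version. wait_for_min_size, timeout_s and poll_interval_s are illustrative names, not Acme or Reverb API.

```python
import time
from typing import Callable


def wait_for_min_size(get_size: Callable[[], int], min_size: int,
                      timeout_s: float = 60.0,
                      poll_interval_s: float = 0.5) -> bool:
    """Poll get_size() until it reports at least min_size items.

    Returns True once enough items are available, or False if timeout_s
    elapses first, so the caller decides whether to retry or abort.
    """
    deadline = time.monotonic() + timeout_s
    while True:
        if get_size() >= min_size:
            return True
        if time.monotonic() >= deadline:
            return False
        # Sleep between polls to avoid hammering the replay server.
        time.sleep(poll_interval_s)
```

In a learner loop this would be called right before self._step(), raising or retrying when it returns False. Polling keeps all waiting on the Python side, so the compiled step only runs once the table is known to hold enough items; the cost is one extra query to the server per check.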