google-deepmind / reverb

Reverb is an efficient and easy-to-use data storage and transport system designed for machine learning research
Apache License 2.0
704 stars 92 forks

Reverb queue not waiting for data #70

Closed DriesSmit closed 3 years ago

DriesSmit commented 3 years ago

Hey.

I am currently using a reverb.Table.queue for a PPO implementation. It seems that when I use @tf.function, the queue allows unlimited pulls from the dataset iterator. I am guessing @tf.function somehow bypasses the waiting condition? Let me know if you need further information or if this is a known problem. Thanks!

sabelaraga commented 3 years ago

Hey Dries, could you provide the code to reproduce the issue? Thanks!

DriesSmit commented 3 years ago

Hey Sabela. Please find reproducible code below. However, maybe this is more of an Acme problem? Should I raise this issue with them instead? Thanks!

from typing import Iterable
from acme import datasets
from acme import specs as acme_specs
from acme.adders import reverb as adders
import numpy as np
import reverb
import tensorflow as tf
import tree
import time

def spec_like_to_tensor_spec(
    paths: Iterable[str], spec: acme_specs.Array
) -> tf.TypeSpec:
    """Convert spec like object to tensorspec.

    Args:
        paths (Iterable[str]): Spec like path.
        spec (acme_specs.Array): Spec to use.

    Returns:
        tf.TypeSpec: Returned tensorspec.
    """
    return tf.TensorSpec.from_spec(spec, name="/".join(str(p) for p in paths))

# Remove batch dimensions.
batch_size = 32

obs_min = np.array([0.0] * 1000, dtype=np.float32)
obs_max = np.array([1.0] * 1000, dtype=np.float32)
ma_obs_spec = acme_specs.BoundedArray(
    shape=obs_min.shape,
    dtype="float32",
    name="observation",
    minimum=obs_min,
    maximum=obs_max,
)

queue_spec = tree.map_structure_with_path(
    spec_like_to_tensor_spec, {"agent": ma_obs_spec}
)
queue = reverb.Table.queue(
    name=adders.DEFAULT_PRIORITY_TABLE,
    max_size=10000,
    signature=queue_spec,
)
server = reverb.Server([queue], port=None)
can_sample = lambda: queue.can_sample(batch_size)
address = f"localhost:{server.port}"

# The dataset object to learn from.
prefetch_size = 100
dataset_iter = iter(
    reverb.TrajectoryDataset.from_table_signature(
        server_address=address,
        table=adders.DEFAULT_PRIORITY_TABLE,
        max_in_flight_samples_per_worker=1,
    )
    .batch(batch_size, drop_remainder=True)
    .prefetch(prefetch_size)
)

@tf.function  # Add and remove this tf.function to see the different behaviour.
def sample_iterator(dataset_iter):
    while True:
        print("Before next statement.")
        data = next(dataset_iter)

        print("data: ", data)
        time.sleep(0.5)

sample_iterator(dataset_iter)

sabelaraga commented 3 years ago

Thanks for the repro!

I think it's a problem with the tracing that tf.function does: the iterator gets treated as plain Python, so the next() call is only executed during tracing rather than as part of the graph.

You can see a similar behaviour in this example. Without tf.function it fails with an out-of-bounds error. With tf.function it iterates indefinitely.

import tensorflow as tf
import time

ds = tf.data.Dataset.from_tensor_slices([1, 2, 3])

ds_iter = iter(ds)
@tf.function
def sample_iterator(ds_iter):
  while True:
    print("Before next statement.")
    data = next(ds_iter)
    print("data: ", data)
    time.sleep(0.5)

sample_iterator(ds_iter)
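One way around this (a hedged sketch, not something stated in the thread) is to iterate the tf.data.Dataset object itself inside the tf.function instead of calling next() on a Python iterator. AutoGraph then compiles the for-loop into a graph-level dataset iteration, which terminates when the dataset is exhausted rather than looping forever:

```python
import tensorflow as tf

ds = tf.data.Dataset.from_tensor_slices([1, 2, 3])

@tf.function
def sum_elements(dataset):
    total = tf.constant(0)
    # AutoGraph converts this into a graph-level loop over the dataset,
    # so it stops when the data runs out instead of iterating indefinitely.
    for elem in dataset:
        total += elem
    return total

print(sum_elements(ds))  # tf.Tensor(6, shape=(), dtype=int32)
```

The same pattern should apply to a reverb.TrajectoryDataset, since it is an ordinary tf.data.Dataset, though whether it matches the original training-loop structure depends on the surrounding Acme code.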
DriesSmit commented 3 years ago

Thanks for the clarification! :) I think I will then reach out to Acme about how to solve this problem when using Launchpad as well. Somehow the learner probably needs access to the can_sample = lambda: queue.can_sample(batch_size) function inside the replay process.
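For reference, one way the learner could gate pulls on queue capacity is to poll the can_sample callable in ordinary Python before each next() call. This is a hypothetical sketch; sample_when_ready and its parameters are illustrative names, not part of Reverb or Acme:

```python
import time

def sample_when_ready(dataset_iter, can_sample, poll_interval=0.1):
    """Block until can_sample() reports enough items, then pull one batch.

    dataset_iter is any Python iterator, and can_sample is a zero-argument
    callable (e.g. lambda: queue.can_sample(batch_size)); both names are
    illustrative.
    """
    # Poll in plain Python, outside any tf.function, so the waiting
    # condition is actually evaluated at runtime rather than at trace time.
    while not can_sample():
        time.sleep(poll_interval)
    return next(dataset_iter)
```

Because the waiting happens in eager Python, this sidesteps the tracing issue entirely: a tf.function (if one is used at all) only ever sees batches that were already available.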

sabelaraga commented 3 years ago

Sounds good, I'll close this one then!