Closed DriesSmit closed 3 years ago
Hey Dries, could you provide the code to reproduce the issue? Thanks!
Hey Sabela. Please find the reproducible code below. However, maybe this is more of an Acme problem? Should I raise this issue with them instead? Thanks!
from typing import Iterable
from acme import datasets
from acme import specs as acme_specs
from acme.adders import reverb as adders
import numpy as np
import reverb
import tensorflow as tf
import tree
import time
def spec_like_to_tensor_spec(
    paths: Iterable[str], spec: acme_specs.Array
) -> tf.TypeSpec:
    """Convert a spec-like object to a TensorSpec.

    Args:
        paths (Iterable[str]): Spec-like path.
        spec (acme_specs.Array): Spec to use.

    Returns:
        tf.TypeSpec: The converted TensorSpec.
    """
    return tf.TensorSpec.from_spec(spec, name="/".join(str(p) for p in paths))
# Remove batch dimensions.
batch_size = 32
obs_min = np.array([0.0] * 1000, dtype=np.float32)
obs_max = np.array([1.0] * 1000, dtype=np.float32)
ma_obs_spec = acme_specs.BoundedArray(
    shape=obs_min.shape,
    dtype="float32",
    name="observation",
    minimum=obs_min,
    maximum=obs_max,
)
queue_spec = tree.map_structure_with_path(
    spec_like_to_tensor_spec, {"agent": ma_obs_spec}
)
queue = reverb.Table.queue(
    name=adders.DEFAULT_PRIORITY_TABLE,
    max_size=10000,
    signature=queue_spec,
)
server = reverb.Server([queue], port=None)
can_sample = lambda: queue.can_sample(batch_size)
address = f'localhost:{server.port}'
# The dataset object to learn from.
prefetch_size = 100
dataset_iter = iter(
    reverb.TrajectoryDataset.from_table_signature(
        server_address=address,
        table=adders.DEFAULT_PRIORITY_TABLE,
        max_in_flight_samples_per_worker=1,
    )
    .batch(batch_size, drop_remainder=True)
    .prefetch(prefetch_size)
)

@tf.function  # Add and remove this tf.function to see different behaviour.
def sample_iterator(dataset_iter):
    while True:
        print("Before next statement.")
        data = next(dataset_iter)
        print("data: ", data)
        time.sleep(0.5)

sample_iterator(dataset_iter)
Thanks for the repro!
I think it's a problem with the tracing that tf.function
does: the iterator gets treated as a Python object at trace time,
so its blocking and stopping behaviour is not respected at run time.
You can see similar behaviour in this example. Without tf.function
it fails with an out-of-range error. With tf.function
it iterates indefinitely.
import tensorflow as tf
import time

ds = tf.data.Dataset.from_tensor_slices([1, 2, 3])
ds_iter = iter(ds)

@tf.function
def sample_iterator(ds_iter):
    while True:
        print("Before next statement.")
        data = next(ds_iter)
        print("data: ", data)
        time.sleep(0.5)

sample_iterator(ds_iter)
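The underlying mechanism is that tf.function traces the Python body once and then replays the resulting graph on later calls, so Python-level side effects only happen at trace time. A minimal sketch of that trace-once behaviour (trace_count is just an illustrative counter, not anything from the repro):

```python
import tensorflow as tf

trace_count = {"n": 0}

@tf.function
def add_one(x):
    # This Python statement runs only while tf.function traces the body;
    # later calls with the same input signature replay the compiled graph.
    trace_count["n"] += 1
    return x + 1

add_one(tf.constant(1))
add_one(tf.constant(2))
print(trace_count["n"])  # the body was traced only once
```

This is why the `while True` loop above behaves differently under tf.function: the Python loop and `next()` call are captured during tracing rather than executed eagerly on each call.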
Thanks for the clarification! :) I think I will then reach out to the Acme team about how to solve this problem when using Launchpad as well. Somehow the learner probably needs to get access to the can_sample = lambda: queue.can_sample(batch_size)
check inside the replay process.
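For reference, one way to gate sampling outside of tf.function is a plain Python polling loop around a can_sample-style predicate. A minimal sketch; wait_until and its timeout/poll parameters are made up for illustration and are not part of Reverb's API:

```python
import time

def wait_until(can_sample, timeout=5.0, poll=0.01):
    """Poll can_sample() until it returns True, or raise after `timeout` seconds."""
    deadline = time.monotonic() + timeout
    while not can_sample():
        if time.monotonic() >= deadline:
            raise TimeoutError("no samples became available in time")
        time.sleep(poll)
    return True

# Example with a stub predicate that becomes true on the third check;
# in the real setup this would be lambda: queue.can_sample(batch_size).
calls = {"n": 0}

def stub_can_sample():
    calls["n"] += 1
    return calls["n"] >= 3

print(wait_until(stub_can_sample, timeout=1.0))  # True
```

The key point is that the wait happens in eager Python, before entering any traced function, so tf.function never gets a chance to bypass it.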
Sounds good, I'll close this one then!
Hey.
I am currently using a
reverb.Table.queue
for a PPO implementation. It seems that when I use @tf.function,
the queue allows unlimited pulls from the dataset iterator. I am guessing @tf.function
somehow bypasses the waiting condition? Let me know if you need further information or if this is a known problem. Thanks!