Open Axel-CH opened 4 years ago
Do you mean you run into an error? If so, can you provide error logs for more context?
Hello,
I'm using a custom "train_eval" function inspired by "https://github.com/tensorflow/agents/blob/master/tf_agents/agents/dqn/examples/v2/train_eval.py".
Here is the code that I use to initialize the envs:
tf_env = tf_py_environment.TFPyEnvironment(parallel_py_environment.ParallelPyEnvironment(parrallel_envs_train))
Code that initializes the prioritized replay buffer:
num_steps = 1000
sequence_length = 2 # TODO set automatically
priorityzed_table = reverb.Table(
    name='my_prioritized_experience_replay_buffer',
    sampler=reverb.selectors.Prioritized(0.8),
    remover=reverb.selectors.Fifo(),
    max_size=1000,
    rate_limiter=reverb.rate_limiters.MinSize(100),
)
rpbf_server = reverb.Server([priorityzed_table])
rpbf_py_client = reverb.Client('localhost:{}'.format(rpbf_server.port))
replay_buffer = reverb_replay_buffer.ReverbReplayBuffer(
    tf_agent.collect_data_spec,
    "my_prioritized_experience_replay_buffer",
    local_server=rpbf_server,
    sequence_length=sequence_length)
traj_obs = reverb_utils.ReverbAddEpisodeObserver(
    rpbf_py_client, "my_prioritized_experience_replay_buffer", max_sequence_length=sequence_length)
initial_collect_driver = dynamic_episode_driver.DynamicEpisodeDriver(
    tf_env,
    collect_policy,
    observers=[replay_buffer.add_batch] + train_metrics,
    num_episodes=collect_episodes_per_iteration)
collect_driver = dynamic_episode_driver.DynamicEpisodeDriver(
    tf_env,
    collect_policy,
    observers=[replay_buffer.add_batch] + train_metrics,
    num_episodes=collect_episodes_per_iteration
)
data_spec = tf_agent.collect_data_spec
# Before calling client.dataset ( took that from https://github.com/tensorflow/agents/issues/410 )
get_dtype = lambda x: tf.as_dtype(x.dtype)
get_shape = lambda x: (sequence_length,) + x.shape
shapes = tf.nest.map_structure(get_shape, data_spec)
dtypes = tf.nest.map_structure(get_dtype, data_spec)
# Dataset generates trajectories
dataset = replay_buffer.tf_client.dataset(
    'my_prioritized_experience_replay_buffer',
    dtypes=dtypes,
    shapes=shapes)
iterator = iter(dataset)
So, when I execute the code, I get this error:
Traceback (most recent call last):
File "prepare_best_models_hps.py", line 190, in <module>
app.run(main)
File "/home/cuser/.local/lib/python3.8/site-packages/absl/app.py", line 299, in run
_run_main(main, args)
File "/home/cuser/.local/lib/python3.8/site-packages/absl/app.py", line 250, in _run_main
sys.exit(main(argv))
File "prepare_best_models_hps.py", line 166, in main
train_eval(populated_space, save_models, generate_plots, generate_dataframes, trial_desc)
File "/home/cuser/dev/shatta/framework-dev/private/agents/tf_agents/agents/dqn/hp_search_dqn/hp_search_dqn/__init__.py", line 664, in train_eval
time_step, policy_state = collect_driver.run(
File "/home/cuser/.local/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 787, in __call__
result = self._call(*args, **kwds)
File "/home/cuser/.local/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 830, in _call
self._initialize(args, kwds, add_initializers_to=initializers)
File "/home/cuser/.local/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 702, in _initialize
self._stateful_fn._get_concrete_function_internal_garbage_collected( # pylint: disable=protected-access
File "/home/cuser/.local/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 2948, in _get_concrete_function_internal_garbage_collected
graph_function, _, _ = self._maybe_define_function(args, kwargs)
File "/home/cuser/.local/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 3319, in _maybe_define_function
graph_function = self._create_graph_function(args, kwargs)
File "/home/cuser/.local/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 3171, in _create_graph_function
func_graph_module.func_graph_from_py_func(
File "/home/cuser/.local/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py", line 987, in func_graph_from_py_func
func_outputs = python_func(*func_args, **func_kwargs)
File "/home/cuser/.local/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 613, in wrapped_fn
return weak_wrapped_fn().__wrapped__(*args, **kwds)
File "/home/cuser/dev/shatta/framework-dev/private/agents/tf_agents/drivers/dynamic_episode_driver.py", line 191, in run
return self._run_fn(
File "/home/cuser/dev/shatta/framework-dev/private/agents/tf_agents/utils/common.py", line 185, in with_check_resource_vars
return fn(*fn_args, **fn_kwargs)
File "/home/cuser/dev/shatta/framework-dev/private/agents/tf_agents/drivers/dynamic_episode_driver.py", line 218, in _run
tf.while_loop(
File "/home/cuser/.local/lib/python3.8/site-packages/tensorflow/python/util/deprecation.py", line 574, in new_func
return func(*args, **kwargs)
File "/home/cuser/.local/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_ops.py", line 2489, in while_loop_v2
return while_loop(
File "/home/cuser/.local/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_ops.py", line 2687, in while_loop
return while_v2.while_loop(
File "/home/cuser/.local/lib/python3.8/site-packages/tensorflow/python/ops/while_v2.py", line 188, in while_loop
body_graph = func_graph_module.func_graph_from_py_func(
File "/home/cuser/.local/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py", line 987, in func_graph_from_py_func
func_outputs = python_func(*func_args, **func_kwargs)
File "/home/cuser/.local/lib/python3.8/site-packages/tensorflow/python/ops/while_v2.py", line 174, in wrapped_body
outputs = body(*_pack_sequence_as(orig_loop_vars, args))
File "/home/cuser/dev/shatta/framework-dev/private/agents/tf_agents/drivers/dynamic_episode_driver.py", line 144, in loop_body
observer_ops = [observer(traj) for observer in self._observers]
File "/home/cuser/dev/shatta/framework-dev/private/agents/tf_agents/drivers/dynamic_episode_driver.py", line 144, in <listcomp>
observer_ops = [observer(traj) for observer in self._observers]
File "/home/cuser/dev/shatta/framework-dev/private/agents/tf_agents/replay_buffers/replay_buffer.py", line 83, in add_batch
return self._add_batch(items)
File "/home/cuser/dev/shatta/framework-dev/private/agents/tf_agents/replay_buffers/reverb_replay_buffer.py", line 173, in _add_batch
raise NotImplementedError(
NotImplementedError: ReverbReplayBuffer does not support `add_batch`. See `reverb_utils.ReverbObserver` for more information on how to add data to the buffer.
[reverb/cc/platform/default/server.cc:64] Shutting down replay server
Obviously the issue is coming from the way I initialize the observer with "add_batch". So I tried another way with this combination:
traj_obs = reverb_utils.ReverbAddEpisodeObserver(
    rpbf_py_client, "my_prioritized_experience_replay_buffer", max_sequence_length=sequence_length)
initial_collect_driver = dynamic_episode_driver.DynamicEpisodeDriver(
    tf_env,
    collect_policy,
    observers=traj_obs,
    num_episodes=3)
And now I'm getting this error:
Traceback (most recent call last):
File "prepare_best_models_hps.py", line 190, in <module>
app.run(main)
File "/home/cuser/.local/lib/python3.8/site-packages/absl/app.py", line 299, in run
_run_main(main, args)
File "/home/cuser/.local/lib/python3.8/site-packages/absl/app.py", line 250, in _run_main
sys.exit(main(argv))
File "prepare_best_models_hps.py", line 166, in main
train_eval(populated_space, save_models, generate_plots, generate_dataframes, trial_desc)
File "/home/cuser/dev/shatta/framework-dev/private/agents/tf_agents/agents/dqn/hp_search_dqn/hp_search_dqn/__init__.py", line 658, in train_eval
time_step, policy_state = collect_driver.run(
File "/home/cuser/.local/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 787, in __call__
result = self._call(*args, **kwds)
File "/home/cuser/.local/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 830, in _call
self._initialize(args, kwds, add_initializers_to=initializers)
File "/home/cuser/.local/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 702, in _initialize
self._stateful_fn._get_concrete_function_internal_garbage_collected( # pylint: disable=protected-access
File "/home/cuser/.local/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 2948, in _get_concrete_function_internal_garbage_collected
graph_function, _, _ = self._maybe_define_function(args, kwargs)
File "/home/cuser/.local/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 3319, in _maybe_define_function
graph_function = self._create_graph_function(args, kwargs)
File "/home/cuser/.local/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 3171, in _create_graph_function
func_graph_module.func_graph_from_py_func(
File "/home/cuser/.local/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py", line 987, in func_graph_from_py_func
func_outputs = python_func(*func_args, **func_kwargs)
File "/home/cuser/.local/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 613, in wrapped_fn
return weak_wrapped_fn().__wrapped__(*args, **kwds)
File "/home/cuser/dev/shatta/framework-dev/private/agents/tf_agents/drivers/dynamic_episode_driver.py", line 191, in run
return self._run_fn(
File "/home/cuser/dev/shatta/framework-dev/private/agents/tf_agents/utils/common.py", line 185, in with_check_resource_vars
return fn(*fn_args, **fn_kwargs)
File "/home/cuser/dev/shatta/framework-dev/private/agents/tf_agents/drivers/dynamic_episode_driver.py", line 218, in _run
tf.while_loop(
File "/home/cuser/.local/lib/python3.8/site-packages/tensorflow/python/util/deprecation.py", line 574, in new_func
return func(*args, **kwargs)
File "/home/cuser/.local/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_ops.py", line 2489, in while_loop_v2
return while_loop(
File "/home/cuser/.local/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_ops.py", line 2687, in while_loop
return while_v2.while_loop(
File "/home/cuser/.local/lib/python3.8/site-packages/tensorflow/python/ops/while_v2.py", line 188, in while_loop
body_graph = func_graph_module.func_graph_from_py_func(
File "/home/cuser/.local/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py", line 987, in func_graph_from_py_func
func_outputs = python_func(*func_args, **func_kwargs)
File "/home/cuser/.local/lib/python3.8/site-packages/tensorflow/python/ops/while_v2.py", line 174, in wrapped_body
outputs = body(*_pack_sequence_as(orig_loop_vars, args))
File "/home/cuser/dev/shatta/framework-dev/private/agents/tf_agents/drivers/dynamic_episode_driver.py", line 144, in loop_body
observer_ops = [observer(traj) for observer in self._observers]
TypeError: 'ReverbAddEpisodeObserver' object is not iterable
[reverb/cc/platform/default/server.cc:64] Shutting down replay server
Regards,
Just saw this. Taking a look.
You're close; instead of passing traj_obs to observers, pass the list [traj_obs]. Let me know if that works.
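The second traceback ends in the driver line `observer_ops = [observer(traj) for observer in self._observers]`, which iterates over whatever was passed as `observers`. Here is a minimal sketch of that behaviour in plain Python. `FakeObserver` and `run_observers` are hypothetical stand-ins written for illustration; they are not TF-Agents classes or functions.

```python
class FakeObserver:
    """Hypothetical stand-in for a callable observer such as traj_obs."""

    def __init__(self):
        self.seen = []

    def __call__(self, traj):
        # Record the trajectory; a real observer would write it to Reverb.
        self.seen.append(traj)


def run_observers(observers, traj):
    # Mirrors the list comprehension shown in the traceback: the driver
    # iterates over `observers`, so it must be a list (or another iterable).
    return [observer(traj) for observer in observers]


traj_obs = FakeObserver()

# Passing the bare observer fails because the object itself is not iterable.
try:
    run_observers(traj_obs, "step-0")
    bare_error = None
except TypeError as err:
    bare_error = str(err)

# Wrapping it in a list works: the observer is called once per trajectory.
run_observers([traj_obs], "step-0")
```

In the failing call the `TypeError` is raised before any observer runs, which matches the traceback stopping at the list comprehension.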
OK, now I'm getting a new error:
[reverb/pybind.cc:416] Tensor can't be extracted from the source represented as ndarray: Invalid argument: Provided input could not be interpreted as an ndarray
Traceback (most recent call last):
File "prepare_best_models_hps.py", line 191, in <module>
app.run(main)
File "/home/userone/.local/lib/python3.8/site-packages/absl/app.py", line 299, in run
_run_main(main, args)
File "/home/userone/.local/lib/python3.8/site-packages/absl/app.py", line 250, in _run_main
sys.exit(main(argv))
File "prepare_best_models_hps.py", line 166, in main
train_eval(populated_space, save_models, generate_plots, generate_dataframes, trial_desc)
File "/home/userone/dev/shatta/framework-dev/private/agents/tf_agents/agents/dqn/hp_search_dqn/hp_search_dqn/__init__.py", line 664, in train_eval
time_step, policy_state = collect_driver.run(
File "/home/userone/.local/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 787, in __call__
result = self._call(*args, **kwds)
File "/home/userone/.local/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 830, in _call
self._initialize(args, kwds, add_initializers_to=initializers)
File "/home/userone/.local/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 702, in _initialize
self._stateful_fn._get_concrete_function_internal_garbage_collected( # pylint: disable=protected-access
File "/home/userone/.local/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 2948, in _get_concrete_function_internal_garbage_collected
graph_function, _, _ = self._maybe_define_function(args, kwargs)
File "/home/userone/.local/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 3319, in _maybe_define_function
graph_function = self._create_graph_function(args, kwargs)
File "/home/userone/.local/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 3171, in _create_graph_function
func_graph_module.func_graph_from_py_func(
File "/home/userone/.local/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py", line 987, in func_graph_from_py_func
func_outputs = python_func(*func_args, **func_kwargs)
File "/home/userone/.local/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 613, in wrapped_fn
return weak_wrapped_fn().__wrapped__(*args, **kwds)
File "/home/userone/dev/shatta/framework-dev/private/agents/tf_agents/drivers/dynamic_episode_driver.py", line 191, in run
return self._run_fn(
File "/home/userone/dev/shatta/framework-dev/private/agents/tf_agents/utils/common.py", line 185, in with_check_resource_vars
return fn(*fn_args, **fn_kwargs)
File "/home/userone/dev/shatta/framework-dev/private/agents/tf_agents/drivers/dynamic_episode_driver.py", line 218, in _run
tf.while_loop(
File "/home/userone/.local/lib/python3.8/site-packages/tensorflow/python/util/deprecation.py", line 574, in new_func
return func(*args, **kwargs)
File "/home/userone/.local/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_ops.py", line 2489, in while_loop_v2
return while_loop(
File "/home/userone/.local/lib/python3.8/site-packages/tensorflow/python/ops/control_flow_ops.py", line 2687, in while_loop
return while_v2.while_loop(
File "/home/userone/.local/lib/python3.8/site-packages/tensorflow/python/ops/while_v2.py", line 188, in while_loop
body_graph = func_graph_module.func_graph_from_py_func(
File "/home/userone/.local/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py", line 987, in func_graph_from_py_func
func_outputs = python_func(*func_args, **func_kwargs)
File "/home/userone/.local/lib/python3.8/site-packages/tensorflow/python/ops/while_v2.py", line 174, in wrapped_body
outputs = body(*_pack_sequence_as(orig_loop_vars, args))
File "/home/userone/dev/shatta/framework-dev/private/agents/tf_agents/drivers/dynamic_episode_driver.py", line 144, in loop_body
observer_ops = [observer(traj) for observer in self._observers]
File "/home/userone/dev/shatta/framework-dev/private/agents/tf_agents/drivers/dynamic_episode_driver.py", line 144, in <listcomp>
observer_ops = [observer(traj) for observer in self._observers]
File "/home/userone/dev/shatta/framework-dev/private/agents/tf_agents/replay_buffers/reverb_utils.py", line 143, in __call__
self._writer.append(trajectory)
File "/home/userone/anaconda3/envs/ag3/lib/python3.8/site-packages/reverb/client.py", line 150, in append
self._writer.Append(tree.flatten(data))
TypeError: Append(): incompatible function arguments. The following argument types are supported:
1. (self: reverb.libpybind.Writer, arg0: List[tensorflow::Tensor]) -> Status
Invoked with: <reverb.libpybind.Writer object at 0x7f27b25166b0>, [<tf.Tensor 'driver_loop/Placeholder_1:0' shape=(7,) dtype=int32>, <tf.Tensor 'driver_loop/Placeholder_4:0' shape=(7, 16, 6) dtype=float32>, <tf.Tensor 'driver_loop/Placeholder_5:0' shape=(7, 16, 6) dtype=float32>, <tf.Tensor 'driver_loop/Placeholder_6:0' shape=(7, 16, 22) dtype=float32>, <tf.Tensor 'driver_loop/Placeholder_7:0' shape=(7, 16, 22) dtype=float32>, <tf.Tensor 'driver_loop/Placeholder_8:0' shape=(7, 16, 22) dtype=float32>, <tf.Tensor 'driver_loop/Placeholder_9:0' shape=(7, 16, 22) dtype=float32>, <tf.Tensor 'driver_loop/Placeholder_10:0' shape=(7, 16, 22) dtype=float32>, <tf.Tensor 'driver_loop/Placeholder_11:0' shape=(7, 16, 22) dtype=float32>, <tf.Tensor 'driver_loop/Placeholder_12:0' shape=(7, 16, 5) dtype=float32>, <tf.Tensor 'driver_loop/Placeholder_13:0' shape=(7, 16, 20) dtype=float32>, <tf.Tensor 'driver_loop/clip_by_value_2:0' shape=(7,) dtype=int64>, <tf.Tensor 'driver_loop/add_1:0' shape=(7,) dtype=float32>, <tf.Tensor 'driver_loop/step/step_type:0' shape=(7,) dtype=int32>, <tf.Tensor 'driver_loop/step/reward:0' shape=(7,) dtype=float32>, <tf.Tensor 'driver_loop/step/discount:0' shape=(7,) dtype=float32>]
Note that I have a dictionary of preprocessing layers built from multiple tf.keras.models.Sequential([list]) models, plus a tf.keras.models.Sequential([list]) as the preprocessing combiner. Apparently this is not best practice, according to this warning:
WARNING:tensorflow:Layers in a Sequential model should only have a single input tensor, but we receive a <class 'list'>
I don't know yet whether the issue with Reverb is related to this use of Sequential models.
I will rewrite my model using the functional API and then retry Reverb. I will keep you updated.
After a longer analysis, I don't think I can fix the "sequential format" warning at the model level. It seems to be caused at a higher level, in the agent. To be clearer, here is an update on my issue:
When I execute the train_eval function using the prioritized replay buffer, I get this warning/error combination:
WARNING:tensorflow:Layers in a Sequential model should only have a single input tensor, but we receive a
<class 'list'> input: [<tf.Tensor 'driver_loop/QRnnNetwork/EncodingNetwork/sequential_3/flatten/Reshape:0' shape=(5, 6)
dtype=float32>, <tf.Tensor 'driver_loop/QRnnNetwork/EncodingNetwork/sequential_2/flatten_1/Reshape:0' shape=(5, 6)
dtype=float32>, <tf.Tensor 'driver_loop/QRnnNetwork/EncodingNetwork/sequential_9/flatten_3/Reshape:0' shape=(5, 22)
dtype=float32>, <tf.Tensor 'driver_loop/QRnnNetwork/EncodingNetwork/sequential_8/flatten_2/Reshape:0' shape=(5, 22)
dtype=float32>, <tf.Tensor 'driver_loop/QRnnNetwork/EncodingNetwork/sequential_7/flatten_7/Reshape:0' shape=(5, 22)
dtype=float32>, <tf.Tensor 'driver_loop/QRnnNetwork/EncodingNetwork/sequential_6/flatten_6/Reshape:0' shape=(5, 22)
dtype=float32>, <tf.Tensor 'driver_loop/QRnnNetwork/EncodingNetwork/sequential_5/flatten_4/Reshape:0' shape=(5, 22)
dtype=float32>, <tf.Tensor 'driver_loop/QRnnNetwork/EncodingNetwork/sequential_4/flatten_5/Reshape:0' shape=(5, 22)
dtype=float32>, <tf.Tensor 'driver_loop/QRnnNetwork/EncodingNetwork/sequential/flatten_8/Reshape:0' shape=(5, 5)
dtype=float32>, <tf.Tensor 'driver_loop/QRnnNetwork/EncodingNetwork/sequential_1/flatten_9/Reshape:0' shape=(5, 20)
dtype=float32>]
Consider rewriting this model with the Functional API.
W0831 11:13:24.864628 140579592951616 sequential.py:362] Layers in a Sequential model should only have a single input tensor, but we receive a
<class 'list'> input: [<tf.Tensor 'driver_loop/QRnnNetwork/EncodingNetwork/sequential_3/flatten/Reshape:0' shape=(5, 6)
dtype=float32>, <tf.Tensor 'driver_loop/QRnnNetwork/EncodingNetwork/sequential_2/flatten_1/Reshape:0' shape=(5, 6)
dtype=float32>, <tf.Tensor 'driver_loop/QRnnNetwork/EncodingNetwork/sequential_9/flatten_3/Reshape:0' shape=(5, 22)
dtype=float32>, <tf.Tensor 'driver_loop/QRnnNetwork/EncodingNetwork/sequential_8/flatten_2/Reshape:0' shape=(5, 22)
dtype=float32>, <tf.Tensor 'driver_loop/QRnnNetwork/EncodingNetwork/sequential_7/flatten_7/Reshape:0' shape=(5, 22)
dtype=float32>, <tf.Tensor 'driver_loop/QRnnNetwork/EncodingNetwork/sequential_6/flatten_6/Reshape:0' shape=(5, 22)
dtype=float32>, <tf.Tensor 'driver_loop/QRnnNetwork/EncodingNetwork/sequential_5/flatten_4/Reshape:0' shape=(5, 22)
dtype=float32>, <tf.Tensor 'driver_loop/QRnnNetwork/EncodingNetwork/sequential_4/flatten_5/Reshape:0' shape=(5, 22)
dtype=float32>, <tf.Tensor 'driver_loop/QRnnNetwork/EncodingNetwork/sequential/flatten_8/Reshape:0' shape=(5, 5)
dtype=float32>, <tf.Tensor 'driver_loop/QRnnNetwork/EncodingNetwork/sequential_1/flatten_9/Reshape:0' shape=(5, 20)
dtype=float32>]
Consider rewriting this model with the Functional API.
[reverb/pybind.cc:416] Tensor can't be extracted from the source represented as ndarray: Invalid argument: Provided input could not be interpreted as an ndarray
(Note: the warning also appears without the prioritized replay buffer, but in that case everything works fine, with no error.)
I think this error is raised because I'm using a dictionary of preprocessing layers with multiple inputs, and Reverb can't extract the tensors from it.
Below is the structure of my preprocessing layers:
preprocessing_layers = {
    'a': tf.keras.models.Sequential(list_of_layers['a']),
    'b': tf.keras.models.Sequential(list_of_layers['b']),
    'c': tf.keras.models.Sequential(list_of_layers['c']),
    'd': tf.keras.models.Sequential(list_of_layers['d']),
    'e': tf.keras.models.Sequential(list_of_layers['e']),
    'f': tf.keras.models.Sequential(list_of_layers['f']),
    'g': tf.keras.models.Sequential(list_of_layers['g']),
    'h': tf.keras.models.Sequential(list_of_layers['h']),
    'i': tf.keras.models.Sequential(list_of_layers['i']),
    'j': tf.keras.models.Sequential(list_of_layers['j']),
}
Observation space:
Dict(a:Box(16, 5), b:Box(16, 20), c:Box(16, 6), d:Box(16, 6), e:Box(16, 22), f:Box(16, 22), g:Box(16, 22), h:Box(16, 22), i:Box(16, 22), j:Box(16, 22))
Observation shape:
{'a': (16, 5), 'b': (16, 20), 'c': (16, 6), 'd': (16, 6), 'e': (16, 22), 'f': (16, 22), 'g': (16, 22), 'h': (16, 22), 'i': (16, 22), 'j': (16, 22)}
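For reference, the `get_shape` lambda in the setup code prepends `sequence_length` to each per-step shape before the shapes are handed to `client.dataset`. Below is a small sketch in plain Python, using the observation shapes listed above and `sequence_length = 2` from the original snippet. It covers only the observation part of the data spec and uses plain tuples in place of spec objects, which is a simplification made for illustration.

```python
sequence_length = 2  # value used in the setup code above

# Per-step observation shapes, copied from the listing above.
observation_shapes = {
    'a': (16, 5), 'b': (16, 20), 'c': (16, 6), 'd': (16, 6),
    'e': (16, 22), 'f': (16, 22), 'g': (16, 22), 'h': (16, 22),
    'i': (16, 22), 'j': (16, 22),
}

# Same idea as get_shape: prepend the time dimension to every entry.
sequence_shapes = {
    key: (sequence_length,) + shape
    for key, shape in observation_shapes.items()
}

for key in sorted(sequence_shapes):
    print(key, sequence_shapes[key])
```

These shapes carry no batch dimension, whereas the tensors listed in the `Append()` error above all have a leading dimension of 7.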
Let me know what you think @ebrevdo
Have you tried using tf_agents.networks.Sequential instead of Keras Sequential? Try that and report back.
First of all, thank you for this awesome repo; it saved months of my life ;-)
I'm just letting you know that I'm closely following the ongoing implementation of the prioritized replay buffer with Reverb. I already tried to use it on my workload, but unfortunately I can't, because I'm using the "parallel_py_environment".
Regards, Axel