tensorflow / agents

TF-Agents: A reliable, scalable and easy to use TensorFlow library for Contextual Bandits and Reinforcement Learning.
Apache License 2.0

Correct setup for param max_sequence_length in ReverbAddEpisodeObserver #668

Open JCMiles opened 2 years ago

JCMiles commented 2 years ago

Hi, could you please clarify the description of the parameter max_sequence_length in ReverbAddEpisodeObserver? The description is a little bit messy.

max_sequence_length: An integer. max_sequence_length used to write to the replay buffer tables. This defines the size of the internal buffer controlling the upper limit of the number of timesteps which can be referenced in a single prioritized item. Note that this is the maximum number of trajectories across all the cached episodes that you are writing into the replay buffer (e.g. number_of_episodes). max_sequence_length is not a limit of how many timesteps or items that can be inserted into the replay buffer. Note that, since max_sequence_length controls the size of internal buffer, it is suggested not to set this value to a very large number. If the number of steps in an episode is more than max_sequence_length, only items up to max_sequence_length is written into the table.

In my case, episodes have variable length, with a maximum of 20 steps:

            max_ep_length = 20
            replay_buffer_capacity = 1000000

1) agent_server.py

            server = reverb.Server(
                tables=[
                    reverb.Table( 
                        name=reverb_replay_buffer.DEFAULT_TABLE,
                        sampler=sampler,
                        remover=reverb.selectors.Fifo(),
                        rate_limiter=replay_buffer_rate_limiter,
                        max_size=replay_buffer_capacity,
                        max_times_sampled=0,
                        signature=replay_buffer_signature
                    ),
                    reverb.Table( 
                        name=reverb_variable_container.DEFAULT_TABLE,
                        sampler=reverb.selectors.Uniform(),
                        remover=reverb.selectors.Fifo(),
                        rate_limiter=variable_container_rate_limiter,
                        max_size=1,
                        max_times_sampled=0,
                        signature=variable_container_signature
                    ),
                ],
                port=reverb_port)

2) agent_train.py

            ReverbReplayBuffer(
                    agent.collect_data_spec,
                    table_name=reverb_replay_buffer.DEFAULT_TABLE,
                    sequence_length=max_ep_length)

3A) agent_collect.py

            observer = reverb_utils.ReverbAddEpisodeObserver(
                py_client=reverb.Client(),
                table_name=reverb_replay_buffer.DEFAULT_TABLE,
                max_sequence_length=max_ep_length,
                priority=1,
                bypass_partial_episodes=False
            )

With setup 3A I get this error in agent_collect.py when adding the trajectories:

            The number of trajectories within the same episode exceeds `max_sequence_length`. Consider increasing the `max_sequence_length` or set `bypass_partial_episodes` to true to bypass the episodes with length more than `max_sequence_length`.

3B) agent_collect.py (with a larger max_sequence_length, e.g. 100)

            observer = reverb_utils.ReverbAddEpisodeObserver(
                py_client=reverb.Client(),
                table_name=reverb_replay_buffer.DEFAULT_TABLE,
                max_sequence_length=100,
                priority=1,
                bypass_partial_episodes=False
            )

With setup 3B (any number >= max_ep_length + 1), data collection runs fine, but the experience sampled in agent_train.py has the wrong shape, (batch_size, 21) instead of (batch_size, 20), and I get this error:

            File "agent_train.py", line 166, in train
                loss_info = shaper.run(iterations=config.epochs)
              File "/opt/conda/lib/python3.7/site-packages/tf_agents/train/learner.py", line 246, in run
                loss_info = self._train(iterations, iterator, parallel_iterations)
              File "/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py", line 862, in __call__
                return self._python_function(*args, **kwds)
              File "/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 3985, in bound_method_wrapper
                return wrapped_fn(weak_instance(), *args, **kwargs)
              File "/opt/conda/lib/python3.7/site-packages/tf_agents/train/learner.py", line 263, in _train
                loss_info = self.single_train_step(iterator)
              File "/opt/conda/lib/python3.7/site-packages/tf_agents/train/learner.py", line 288, in single_train_step
                (experience, sample_info) = next(iterator)
              File "/opt/conda/lib/python3.7/site-packages/tensorflow/python/data/ops/iterator_ops.py", line 761, in __next__
                return self._next_internal()
              File "/opt/conda/lib/python3.7/site-packages/tensorflow/python/data/ops/iterator_ops.py", line 747, in _next_internal
                output_shapes=self._flat_output_shapes)
              File "/opt/conda/lib/python3.7/site-packages/tensorflow/python/ops/gen_dataset_ops.py", line 2728, in iterator_get_next
                _ops.raise_from_not_ok_status(e, name)
              File "/opt/conda/lib/python3.7/site-packages/tensorflow/python/framework/ops.py", line 6941, in raise_from_not_ok_status
                six.raise_from(core._status_to_exception(e.code, message), None)
              File "<string>", line 3, in raise_from
            tensorflow.python.framework.errors_impl.InvalidArgumentError: Incompatible shapes at component 0: expected [?,20] but got [512,21]. [Op:IteratorGetNext]

I previously tested my training pipeline in a non-distributed way with a regular tf_agents StatefulEpisodicReplayBuffer and it worked as expected, so my guess is that something is wrong with the setup of max_sequence_length, or internally with something related to trajectory.is_boundary(), because it seems I get a trajectory of length max_sequence_length + 1.
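One way to reason about the extra step is to count trajectories rather than actions. The plain-Python sketch below does not use tf_agents. It assumes (this thread does not confirm it) that the collect driver emits one trajectory per time step, including a final boundary trajectory that pairs the LAST step of an episode with the FIRST step of the next one. The `Trajectory` class and helper names here are hypothetical stand-ins for illustration only.

```python
from dataclasses import dataclass
from typing import List

FIRST, MID, LAST = "FIRST", "MID", "LAST"


@dataclass
class Trajectory:
    """Toy stand-in: only the step types of a (time_step, next_time_step) pair."""
    step_type: str
    next_step_type: str

    def is_boundary(self) -> bool:
        # Assumption: a boundary trajectory starts at the final step of an episode.
        return self.step_type == LAST


def episode_step_types(num_actions: int) -> List[str]:
    """Time steps of one episode: FIRST, then one per action, ending in LAST."""
    if num_actions < 1:
        raise ValueError("an episode needs at least one action")
    return [FIRST] + [MID] * (num_actions - 1) + [LAST]


def episode_trajectories(num_actions: int) -> List[Trajectory]:
    """Pair each time step with its successor, assuming auto-reset after LAST."""
    steps = episode_step_types(num_actions)
    next_steps = steps[1:] + [FIRST]  # after LAST, the next episode begins
    return [Trajectory(s, n) for s, n in zip(steps, next_steps)]


trajectories = episode_trajectories(20)
print(len(trajectories))                            # 21
print(sum(t.is_boundary() for t in trajectories))   # 1
```

Under that assumption, an episode of 20 actions produces 21 trajectories, which would match both the error in setup 3A and the sampled shape (batch_size, 21) in setup 3B.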

JCMiles commented 2 years ago

Any update on this?

ebrevdo commented 2 years ago

Apologies; just saw this. What do you want to happen if your episode length goes over 20?

JCMiles commented 2 years ago

By environment design, an episode cannot be longer than 20 steps, but it can be shorter if the agent reaches its goal before step 20.

ebrevdo commented 2 years ago

Is it possible this is a one-off error? Try setting the max_sequence_length to 21 exactly?

Alternatively, if there's a bug in your env that causes it to sometimes go over 20 steps, one way to enforce the limit is to use a TimeLimit wrapper.
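To illustrate what such a wrapper does, here is a minimal plain-Python sketch. It is a toy stand-in, not the TF-Agents TimeLimit wrapper: `CountingEnv`, `ToyTimeLimit`, `run_episode`, and the `(observation, done)` return convention are all hypothetical and chosen only to show the idea of forcing an episode to end after a fixed number of steps.

```python
class CountingEnv:
    """Hypothetical toy environment that never terminates on its own."""

    def reset(self):
        self._t = 0
        return self._t

    def step(self, action):
        self._t += 1
        return self._t, False  # (observation, done)


class ToyTimeLimit:
    """Ends the episode once `duration` steps have been taken since reset."""

    def __init__(self, env, duration):
        self._env = env
        self._duration = duration
        self._steps = 0

    def reset(self):
        self._steps = 0
        return self._env.reset()

    def step(self, action):
        observation, done = self._env.step(action)
        self._steps += 1
        if self._steps >= self._duration:
            done = True  # force termination at the step limit
        return observation, done


def run_episode(env):
    """Runs one episode with a dummy action and returns the number of steps."""
    env.reset()
    steps, done = 0, False
    while not done:
        _, done = env.step(action=0)
        steps += 1
    return steps


print(run_episode(ToyTimeLimit(CountingEnv(), duration=20)))  # 20
```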

JCMiles commented 2 years ago

No, this is not a one-off error. I've also built a custom rendering script for the environment to have full control over its internal parameters, so I can confirm that every episode is at most 20 steps, and I'm already using TimeLimit to enforce it. I'm still a bit confused about the parameter description I quoted above. In the end, does 'max_sequence_length' in a ReverbAddEpisodeObserver represent the env steps, or "the number of trajectories across all the cached episodes that you are writing into the replay buffer (e.g. number_of_episodes)", as mentioned in the description? To me those are completely different things. And last but not least, if it represents the episode steps, why do I have to set it to 20+1 to get things working? Thanks for your time.

ebrevdo commented 2 years ago

I suppose the question I have is whether it's a bug on our end, which would be strange because we have plenty of environments that work just fine.

max_sequence_length is the length of the internal buffer used to send data to Reverb, and has nothing to do with the environment.

So the question is: does setting it to 21 get things working or not? This can help us debug, though more helpful would be a small, self-contained example that causes the failure. That would make it much easier to debug!