pytorch/rl

A modular, primitive-first, python-first PyTorch library for Reinforcement Learning.
https://pytorch.org/rl
MIT License

[BUG] Collectors crash with environments with non-empty `batch_size` #807

Closed matteobettini closed 1 year ago

matteobettini commented 1 year ago

Describe the bug

Collectors seem to have been designed to work with environments with an empty batch_size.

Examples of torchrl environments where collectors do not work are Brax (with a batch_size representing vectorized environments) and Vmas (with a batch_size representing vectorized environments and the number of agents).

For now, these are the issues I have identified:

1. self.n_env = self.env.numel()

Problem

The collectors assume that the number of environments is the numel of the batch size, which is not true in the multi-agent case.

Solution

frames_per_batch should be independent of the batch size: it should mean "frames of dim batch_size per batch". Concretely, the change would be:

self.frames_per_batch = -(-frames_per_batch // self.n_env)  ->  self.frames_per_batch = frames_per_batch
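For concreteness, a hypothetical illustration of the miscount (numbers are made up; env.batch_size = (4, 32) mimics 4 agents in 32 vectorized envs):

import math

frames_per_batch = 200
n_env = 4 * 32  # self.env.numel() == 128, but only 32 envs are actually stepping

# current behaviour: ceil-divide by numel, treating every batch entry as an env
current = math.ceil(frames_per_batch / n_env)  # 2 frames per env per batch
# proposed behaviour: keep frames_per_batch as "frames of dim batch_size"
proposed = frames_per_batch                    # each frame has shape (4, 32)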

2. Rollout ignores env.batch_size

Problem


    def rollout(self) -> TensorDictBase:
        ...
        n = self.env.batch_size[0] if len(self.env.batch_size) else 1
        self._tensordict.set("traj_ids", torch.arange(n).view(self.env.batch_size[:1]))

The last line respects only the first dimension of the batch size and ignores the others.
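A minimal sketch of a batch-size-aware alternative (the shape logic only, not the library's actual fix):

import torch

batch_size = torch.Size([4, 32])  # hypothetical multi-dimensional batch size

# current logic: only the first dim gets distinct trajectory ids
n = batch_size[0] if len(batch_size) else 1
traj_ids_current = torch.arange(n).view(batch_size[:1])            # shape (4,)

# batch-size-aware alternative: one id per entry of the full batch
traj_ids_full = torch.arange(batch_size.numel()).view(batch_size)  # shape (4, 32)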

3. split_trajectories operates on existing dimensions of the batch_size

If trajectories are not split, the tensordict output by the collector has the same size as the one from the env, with a last dimension of size frames_per_batch appended to it (adopting the solution from problem 1).

If trajectories are split, on the other hand, the last dim is forced to be equal to max_frames_per_traj and the remainder of that dim is added to dimension 0, without even knowing what dimension 0 represents.

My opinion

In my opinion, we should get rid of the max_frames_per_traj parameter. This parameter is effectively acting as a horizon. Its logic can be implemented in the environment, either naturally or via a transform (return done after x steps). This parameter has no meaning when envs are multi-dimensional: if you are rolling out an env with batch_size (32, 26) for max_frames_per_traj steps, you are not rolling out a single trajectory, but a batch_size of trajectories which may terminate and be reset at arbitrary indexes of the batch_size.

If we do get rid of this parameter (by moving it into the done flag), we can get rid of split_trajectories, padding, and the corresponding mask.

At every iteration of the collector we would then only collect frames_per_batch samples of size env.batch_size (samples need to have the dim of the env), and this would just add a final dimension to the output tensordict, which would have batch_size (*env.batch_size, frames_per_batch).

This would change the current logic: if you have a normal env and frames_per_batch=6 you will get 6 samples; if you have a parallel env with 2 workers and frames_per_batch=6 you will get 2 samples of shape (6,) (so effectively 12). But this has to be done, because if we don't know what the batch_size is encoding (parallel workers, agents, or vectorized envs), we can never know how to count samples.
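A sketch of the resulting sample counting under this convention (hypothetical sizes):

import torch

frames_per_batch = 6

# normal env: batch_size == () -> output batch_size (6,), 6 entries
# parallel env with 2 workers: batch_size == (2,) -> output batch_size (2, 6)
env_batch_size = torch.Size([2])
out_batch_size = torch.Size((*env_batch_size, frames_per_batch))  # (2, 6)
total_entries = out_batch_size.numel()  # 12, however the batch dims are interpreted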

The collector rollout will reset only the dimensions which are done.

We could also do all of the aforementioned and keep max_frames_per_traj (horizon). Then, if this param is set and split_traj=True, the tensordict batch size would become (*env.batch_size, frames_per_batch / max_frames_per_traj, max_frames_per_traj), and stay (*env.batch_size, frames_per_batch) if split_traj=False. I suggest against this solution because (as I stated before) a multi-dimensional trajectory loses the conventional meaning, and thus max_frames_per_traj should not be applied to it.

PR #808 is a proof of concept for these changes. I just took the SyncDataCollector and removed the features which I think we cannot have if collectors want to be agnostic of the env.batch_size.

To Reproduce

You can reproduce with Vmas or Brax:

from torchrl.collectors import SyncDataCollector
from torchrl.envs.libs.brax import BraxEnv
from torchrl.envs.libs.vmas import VmasEnv

env_maker_vmas = lambda: VmasEnv("flocking", num_envs=32)  # env.batch_size = (4, 32) = (n_agents, n_vectorized_envs)

def env_maker_brax():
    # _get_envs() is a helper from the reporter's setup returning Brax env names
    env = BraxEnv(_get_envs()[0], batch_size=(4, 32))  # env.batch_size = (4, 32), both dims are vectorized envs
    env.set_seed(1)
    return env

collector = SyncDataCollector(
    create_env_fn=env_maker_vmas,
    total_frames=2000,
    max_frames_per_traj=-1,
    frames_per_batch=200,
    init_random_frames=-1,
    reset_at_each_iter=False,
    device="cpu",
    passing_device="cpu",
    seed=1,
)
for i, data in enumerate(collector):
    if i == 2:
        print(data)
        break
vmoens commented 1 year ago

frames_per_batch should be independent of the batch size: it should mean "frames of dim batch_size per batch"

Not sure I fully agree with that statement. In many cases, you say you want X frames in each batch, regardless of the number of envs you are running in parallel. That allows you to easily recycle code across machines (e.g. if you have one machine with 1 GPU and another with 8, or one machine with 16 CPUs and another with 96, you will likely put a different number of envs in your parallel env and a different number of parallel envs in your collector). That being said, I see your point. Why not multiply the number of frames by the number of envs in your MARL task? Essentially, you'd be saying "I want X*A frames per batch, where A is the number of agents and X the 'real' number of frames I want". I agree that it's less than intuitive.

In my opinion, we should get rid of the max_frames_per_traj parameter. This parameter is effectively acting as a horizon. Its logic can be implemented in the environment, either naturally or via a transform (return done after x steps).

Totally agree. We now have a transform to do that.
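A minimal sketch of a horizon-as-transform, assuming torchrl's StepCounter transform (name and availability may vary across versions):

from torchrl.envs import GymEnv, TransformedEnv
from torchrl.envs.transforms import StepCounter

# impose the horizon inside the env instead of via max_frames_per_traj:
# StepCounter tracks a per-env step count and flags the end of the
# trajectory once max_steps is reached
env = TransformedEnv(GymEnv("Pendulum-v1"), StepCounter(max_steps=100))
rollout = env.rollout(1_000)  # stops early: the step count triggers done at 100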

split_trajectories operates on existing dimensions of the batch_size and we can get rid of split_trajectories

Note: you can deactivate it if it crashes; simply pass split_traj=False.

My take on all this

Batched envs and multi-agent envs share a common feature: one env contains multiple envs that act somewhat independently. What is important to me with batched envs is that, in a script, it should be easy to set a frame budget for an experiment. For instance, you often have tasks like "solve task X in less than Y frames". Collecting the frames in the collector should take into account how many envs are being run in parallel and possibly (although we don't do it yet in the collector) the frame_skip.

The question is how to make this compatible with the multi-agent setting.

One option is to indicate how many of the first dimensions of the env are to be considered as a batch of indistinguishable envs. That would leave space for multi-agent people to customize this, and for others to collect frames as they should. Something like:

collector = SyncDataCollector(env_constructor, policy, leading_batch_dim=1)  # the first dimension should be considered as the "batch", the rest as the "env" itself

Regarding the bugs and comments: wdyt?

matteobettini commented 1 year ago

Thanks for the answer; it clarified some things. Here are some of my thoughts:

env_batch_size_mask=(1, 0, 1)

This means that only part of the env batch size is to be considered as the batch, and a frame will be a single entry of those dimensions (so, in the case where the dimension left out is the number of agents, a frame will be a tensor of size (num_agents,)).

The collector would then collect the `frames_per_batch` according to this definition.
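A sketch of the frame-counting rule such a mask could encode (env_batch_size_mask is the proposal above, not an existing argument):

import math

def frames_per_step(batch_size, env_batch_size_mask):
    # only masked-in dims count as independent envs; masked-out dims
    # (e.g. the agent dim) belong to a single frame
    return math.prod(d for d, m in zip(batch_size, env_batch_size_mask) if m)

# env.batch_size == (n_env_workers, n_agents, n_envs) == (2, 4, 32)
frames_per_step((2, 4, 32), (1, 1, 1))  # 256: every entry is a frame
frames_per_step((2, 4, 32), (1, 0, 1))  # 64: the agent dim is not counted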
vmoens commented 1 year ago

If we get rid of max_frames_per_traj, what's the use of split_traj, traj_ids and step_count? These were only used to track trajectories and to pad them to max_frames_per_traj

Not quite. If you have a terminating env (say, Pong), you can have one trajectory that ends after 99 steps and another after 101. Since we're running envs on different procs and resetting when needed, we then deliver the trajs along a dim using split_traj. In this case max_frames_per_traj is not set; still, we need split_traj. traj_ids and step_counts are useful artifacts of that feature.

Also, there seems to be some confusion here: saying you want 4 items in each batch with a max_frames_per_traj of 1000 is perfectly valid. What will happen is that your trajectory will spread across multiple batches returned by the collector. And split_traj will not create a tensordict with the last dim equal to max_frames_per_traj, but one with the last dim equal to the max rollout length in that batch.
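A toy illustration of that padding behaviour (not the collector's actual code; pad_sequence simply mimics padding to the longest trajectory in the batch):

import torch
from torch.nn.utils.rnn import pad_sequence

# three trajectories of different lengths collected in one batch
trajs = [torch.ones(99), torch.ones(101), torch.ones(7)]

# split_traj-style output: padded to the longest rollout in *this* batch (101),
# not to max_frames_per_traj (e.g. 1000)
padded = pad_sequence(trajs, batch_first=True)  # shape (3, 101)
mask = pad_sequence([torch.ones_like(t, dtype=torch.bool) for t in trajs],
                    batch_first=True)           # True on valid (non-padded) steps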

In the new collector without max_frames_per_traj the shape of data will be (*env.batch_size, frames_per_batch)

Again, not necessarily (given what I said above). It is important we keep trajs separated for naive use cases like

data = next(collector)
GAE(data)

GAE operates along the time dimension, not across trajectories, and we don't want that to change. The fact that the trajs are split in the output tensordict is handy. In the case I presented above (Pong), the size of data in this example is not fixed: it depends on the max length of the trajs in the batch.
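For reference, a self-contained sketch of the recursion GAE runs backwards along the last (time) dimension, which is why trajectories must not be mixed along that dim (textbook formula, not torchrl's implementation):

import torch

def gae(rewards, values, next_values, dones, gamma=0.99, lmbda=0.95):
    # delta_t = r_t + gamma * V(s_{t+1}) * (1 - done_t) - V(s_t)
    deltas = rewards + gamma * next_values * (~dones).float() - values
    advantages = torch.zeros_like(deltas)
    running = torch.zeros_like(deltas[..., -1])
    for t in reversed(range(deltas.shape[-1])):
        # the recursion restarts at trajectory boundaries (done_t == True)
        running = deltas[..., t] + gamma * lmbda * (~dones[..., t]).float() * running
        advantages[..., t] = running
    return advantages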

vmoens commented 1 year ago

What is a frame in multi-agent? If we have the number of agents in the batch size, frames become a per-agent concept. I tend to consider a frame to be something shared by all agents (see last bullet point)

I agree, hence my suggestion of having an arg in the collector to tell apart the dims that should be counted as frames consumed from your budget of frames available for an experiment, and the dims that should just be regarded as inherently part of your env config.

parallel_vmas_env.batch_size  # (n_parallel_envs, n_agents, n_vec_envs)

I assumed that the vec envs would come before the agents, but I agree that it's not that clear. The env_batch_size_mask could be useful.

matteobettini commented 1 year ago

parallel_vmas_env.batch_size  # (n_parallel_envs, n_agents, n_vec_envs)

I assumed that the vec envs would come before the agents, but I agree that it's not that clear.

They cannot, because NestedTensors force heterogeneous agents to be dim 0; or at least that was what allowed calling get_nestedtensors.

matteobettini commented 1 year ago

@vmoens Does something like this make sense to you?

env = ParallelEnv(n_env_workers, VmasEnv("flocking", n_envs, n_agents))
env.batch_size # (n_env_workers, n_agents, n_envs)

# Case 1
collector = MultiSyncDataCollector([env] * n_collector_workers, policy, env_batch_size_mask=(1,1,1), frames_per_batch=frames_per_batch)
for i, data in enumerate(collector):
    data.batch_size  # (n_env_workers, n_agents, n_envs, frames_per_batch / (n_collector_workers * n_env_workers * n_agents * n_envs))

# Case 2
collector = MultiSyncDataCollector([env] * n_collector_workers, policy, env_batch_size_mask=(1,0,1), frames_per_batch=frames_per_batch)
for i, data in enumerate(collector):
    data.batch_size  # (n_env_workers, n_agents, n_envs, frames_per_batch / (n_collector_workers * n_env_workers * n_envs))

We could then find a way to still do split_trajs on top of this, which would add an extra dim at the beginning and force the last dim to be the padded trajectory length.
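For concreteness, the shape arithmetic of the two cases under hypothetical sizes (all numbers made up):

n_collector_workers, n_env_workers, n_agents, n_envs = 2, 2, 4, 8
frames_per_batch = 1280

# Case 1: mask (1, 1, 1) -> every batch entry counts towards the frame budget
t1 = frames_per_batch // (n_collector_workers * n_env_workers * n_agents * n_envs)
# t1 == 10, so data.batch_size == (2, 4, 8, 10)

# Case 2: mask (1, 0, 1) -> the agent dim is not counted
t2 = frames_per_batch // (n_collector_workers * n_env_workers * n_envs)
# t2 == 40, so data.batch_size == (2, 4, 8, 40)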

vmoens commented 1 year ago

That's what I had in mind! Looks great

matteobettini commented 1 year ago

Perfect!