Closed: matteobettini closed this issue 1 year ago.
> frames per batch should be independent of the batch size; it should mean "frames of shape batch_size per batch"
Not sure I fully agree with that statement. In many cases, you say you want X frames in each batch, regardless of the number of envs you are running in parallel. That allows you to easily reuse code across machines (e.g. if you have one machine with 1 GPU and another with 8, or one machine with 16 CPUs and another with 96, you will likely put a different number of envs in your parallel env and a different number of parallel envs in your collector). That being said, I see your point. Why not multiply the number of frames by the number of envs in your MARL task? Essentially, you'd be saying "I want X*A frames per batch, where A is the number of agents and X the 'real' number of frames I want". I agree that it's less than intuitive.
> In my opinion, we should get rid of the max_frames_per_traj parameter. This parameter is effectively acting as a horizon. Its logic can be implemented in the environment either naturally or via a transform (return done after x steps).
Totally agree. We now have a transform to do that.
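For illustration, here is a minimal pure-Python sketch of a horizon implemented as an env wrapper rather than a collector argument, using a gym-style `step` signature (the class name and signature are made up for this example; TorchRL's actual transform for this is `StepCounter`):

```python
class HorizonWrapper:
    """Truncate episodes after max_steps (sketch of a horizon-as-transform)."""

    def __init__(self, env, max_steps):
        self.env = env
        self.max_steps = max_steps
        self._t = 0

    def reset(self):
        self._t = 0
        return self.env.reset()

    def step(self, action):
        obs, reward, done, info = self.env.step(action)
        self._t += 1
        # Signal done once the horizon is reached, regardless of env state.
        if self._t >= self.max_steps:
            done = True
        return obs, reward, done, info
```

The collector then never needs to know about the horizon: it only sees done flags.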
> split_trajectories operates on existing dimensions of the batch_size and we can get rid of split_trajectories

Note: you can deactivate it if it crashes. Simply pass split_traj=False.
Batched envs and multi-agent share a common feature: one env contains multiple envs that act somewhat independently. What matters to me with batched envs is that, in a script, it should be easy to set a frame budget for an experiment. For instance, you often have tasks like "solve task X in less than Y frames". Collecting frames in the collector should take into account how many envs are being run in parallel and possibly (although the collector does not do it yet) the frame_skip.

The question is how to make this compatible with the multi-agent setting.
One option is to indicate how many of the leading dimensions of the env are to be considered as a batch of indistinguishable envs. That would leave space for multi-agent users to customize this, and for others to collect frames as they should. Something like

```python
collector = SyncDataCollector(env_constructor, policy, leading_batch_dim=1)  # the first dimension should be considered as the "batch", the rest as the "env" itself
```
Regarding the bugs and comments, we could replace

```python
n = self.env.batch_size[0] if len(self.env.batch_size) else 1
self._tensordict.set("traj_ids", torch.arange(n).view(self.env.batch_size[:1]))
```

with

```python
n = max(1, self.env.batch_size[:self.leading_batch_dims].numel())
self._tensordict.set("traj_ids", torch.arange(n).view(self.env.batch_size[:self.leading_batch_dims]))
```
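As a standalone sanity check, the same logic in isolation (`leading_batch_dims` is the hypothetical collector argument discussed above, not an existing parameter):

```python
import torch

# Sketch of the proposed traj_ids initialization: only the first
# `leading_batch_dims` dims of env.batch_size are treated as a batch of
# interchangeable envs.
def init_traj_ids(batch_size: torch.Size, leading_batch_dims: int) -> torch.Tensor:
    batch_shape = batch_size[:leading_batch_dims]
    n = max(1, batch_shape.numel())  # at least one trajectory for scalar envs
    return torch.arange(n).view(batch_shape)

print(init_traj_ids(torch.Size([4, 3, 2]), leading_batch_dims=2).shape)  # torch.Size([4, 3])
```

An env with empty batch_size still gets a single trajectory id, matching the `max(1, ...)` guard.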
Wdyt?
Thanks for the answer, it clarified some things. Here are some of my thoughts:
If we get rid of max_frames_per_traj, what's the use of split_traj, traj_ids and step_count? These were only used to track trajectories and to pad them to max_frames_per_traj. In the new collector without max_frames_per_traj, the shape of the data will be (*env.batch_size, frames_per_batch). If we asked it for split_traj, what would you expect: a last dimension added with the length of the longest trajectory encountered, and the others padded? I think we should be agnostic to whether the user is using a StepCounter transform or not and just report the data with the respective dones.

> collector = SyncDataCollector(env_constructor, policy, leading_batch_dim=1) # the first dimension should be considered as the "batch", the rest as the "env" itself
This is a good idea, but we need it to be a mask applied to the env.batch_size to preserve generality. For example, we will have something like:

```python
parallel_vmas_env.batch_size  # (n_parallel_envs, n_agents, n_vec_envs)
collector = SyncDataCollector(parallel_vmas_env, policy, env_batch_size_mask=(1, 0, 1))  # the first and third dimensions are part of the "batch", the second is not
```
With this last parameter, then yes, we would know exactly what in env.batch_size we consider the batch and what has other meaning (like the agents). To reconnect to the previous point, the definition of a frame could then depend on this. For example:

```python
env_batch_size_mask = None  # or all 1s, the default if you do not use this param:
# the whole env batch size is considered as batch, and a frame is a single
# entry in a tensor with this batch_size, like it is now
env_batch_size_mask = (1, 0, 1)  # dimension 1 (e.g. agents) does not count toward the frame budget
```

The collector would then collect the `frames_per_batch` according to this definition.
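To make the accounting concrete, here is a sketch of how a collector could count frames per `env.step()` under that mask. Both `env_batch_size_mask` and this helper are part of the proposal being discussed, not existing TorchRL API:

```python
import torch

# Only dims flagged with 1 count toward the frame budget; masked-out dims
# (e.g. the agent dim) are regarded as inherently part of the env.
def frames_per_step(batch_size: torch.Size, env_batch_size_mask=None) -> int:
    if env_batch_size_mask is None:
        env_batch_size_mask = (1,) * len(batch_size)  # default: every dim is batch
    assert len(env_batch_size_mask) == len(batch_size)
    n = 1
    for size, counts in zip(batch_size, env_batch_size_mask):
        if counts:
            n *= size
    return n

# (n_parallel_envs, n_agents, n_vec_envs) = (32, 4, 8)
print(frames_per_step(torch.Size([32, 4, 8])))                                 # 1024
print(frames_per_step(torch.Size([32, 4, 8]), env_batch_size_mask=(1, 0, 1)))  # 256
```

With the agent dim masked out, one env step produces 32 * 8 = 256 frames instead of 1024, i.e. a frame is shared by all agents.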
> If we get rid of max_frames_per_traj, what's the use of split_traj, traj_ids and step_count? These were only used to track trajectories and to pad them to max_frames_per_traj

Not quite. If you have a terminating env (say, Pong), one instance can end after 99 steps and another after 101. Since we're running envs on different procs and resetting when needed, we then deliver the trajs along a dim using split_traj. In this case max_frames_per_traj is not set; still, we need split_traj. traj_ids and step_counts are useful artifacts of that feature.

Also there seems to be some confusion here: if you say you want 4 items in each batch and a max_frames_per_traj of 1000, it is perfectly valid. What will happen is that your trajectory will spread across multiple batches returned by the collector. And split_traj will not create a tensordict with last dim equal to max_frames_per_traj, but one with the last dim equal to the max rollout length in this batch.
> In the new collector without max_frames_per_traj the shape of data will be (*env.batch_size, frames_per_batch)

Again, not necessarily (given what I said above). It is important we keep trajs separated for naive use cases like

```python
data = next(collector)
GAE(data)
```

GAE operates along the time dimension and not across trajectories. We don't want that to change, and the fact that the trajs are split in the output tensordict is handy. In the case I presented above (Pong), the size of the data is not definitive: it depends on the max length of the traj.
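A rough sketch of the split-and-pad behaviour described here: trajectories of different lengths (e.g. Pong episodes of 99 vs 101 steps) are stacked along a new leading dim and right-padded to the longest one, with a validity mask. The function name and return convention are illustrative, not TorchRL's `split_trajectories` API:

```python
import torch
from torch.nn.functional import pad

def split_and_pad(trajs):
    # trajs: list of 1-D tensors of possibly different lengths
    lengths = torch.tensor([t.shape[-1] for t in trajs])
    max_len = int(lengths.max())
    # Right-pad each trajectory along the last (time) dim with zeros.
    padded = torch.stack([pad(t, (0, max_len - t.shape[-1])) for t in trajs])
    # True where data is valid, False where it is padding.
    mask = torch.arange(max_len) < lengths.unsqueeze(-1)
    return padded, mask

padded, mask = split_and_pad([torch.randn(99), torch.randn(101)])
print(padded.shape)   # torch.Size([2, 101])
print(mask.sum(-1))   # tensor([ 99, 101])
```

The last dim is the max rollout length in this batch, not any fixed horizon, which is the point being made above.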
> What is a frame in multi-agent? If we have the number of agents in the batch size, frames become a per-agent concept. I tend to consider a frame to be shared by all agents (see last bullet point)

I agree, hence my suggestion of having an arg in the collector to tell apart the dims that count as frames consumed from your budget of frames available for an experiment, and the dims that should just be regarded as inherently part of your env config.

> parallel_vmas_env.batch_size # (n_parallel_envs, n_agents, n_vec_envs)

I assumed that the vec envs would come before the agents, but I agree that it's not that clear. The env_batch_size_mask could be useful.
They cannot, because NestedTensors force heterogeneous agents to be dim 0 (or at least that was what allowed calling get_nestedtensors).
@vmoens Does something like this make sense to you?
```python
env = ParallelEnv(n_env_workers, VmasEnv("flocking", n_envs, n_agents))
env.batch_size  # (n_env_workers, n_agents, n_envs)

# Case 1
collector = MultiSyncDataCollector([env] * n_collector_workers, policy, env_batch_size_mask=(1, 1, 1), frames_per_batch=frames_per_batch)
for i, data in enumerate(collector):
    data.batch_size  # (n_env_workers, n_agents, n_envs, frames_per_batch / (n_collector_workers * n_env_workers * n_agents * n_envs))

# Case 2
collector = MultiSyncDataCollector([env] * n_collector_workers, policy, env_batch_size_mask=(1, 0, 1), frames_per_batch=frames_per_batch)
for i, data in enumerate(collector):
    data.batch_size  # (n_env_workers, n_agents, n_envs, frames_per_batch / (n_collector_workers * n_env_workers * n_envs))
```

We could then find a way to still do split_trajs here, which would add an extra dim at the beginning and force the last dim to be the padded traj length.
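A back-of-the-envelope check of the two cases with made-up sizes (the `env_batch_size_mask` semantics are the proposal under discussion, not existing API; frames_per_batch is chosen so the divisions are exact):

```python
n_collector_workers, n_env_workers, n_agents, n_envs = 2, 4, 3, 8
frames_per_batch = 1920

# Case 1: mask (1, 1, 1) -- every dim, agents included, consumes frames.
t1 = frames_per_batch // (n_collector_workers * n_env_workers * n_agents * n_envs)
print((n_env_workers, n_agents, n_envs, t1))  # (4, 3, 8, 10)

# Case 2: mask (1, 0, 1) -- the agent dim is masked out, so one frame is
# shared by all agents and the time dim grows accordingly.
t2 = frames_per_batch // (n_collector_workers * n_env_workers * n_envs)
print((n_env_workers, n_agents, n_envs, t2))  # (4, 3, 8, 30)
```

Masking out the agent dim triples the time dimension here, since the same frame budget is no longer divided among the 3 agents.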
That's what I had in mind! Looks great
Perfect!
Describe the bug

Collectors seem to have been designed to work with environments with an empty batch_size. Examples of torchrl environments where collectors do not work are Brax (with a batch_size representing vectorized environments) and Vmas (with a batch_size representing vectorized environments and the number of agents).

For now, these are the issues I have identified:
1. `self.n_env = self.env.numel()`

Problem

The collectors assume that the number of environments is the numel of the batch size, which is not true in the multi-agent case.

Solution

frames per batch should be independent of the batch size; it should mean "frames of shape batch_size per batch":

```python
self.frames_per_batch = -(-frames_per_batch // self.n_env)
# should become
self.frames_per_batch = frames_per_batch
```
2. Rollout ignores `env.batch_size`

Problem

The last line of the rollout respects only the first dimension of the batch size and ignores the others.
3. `split_trajectories` operates on existing dimensions of the batch_size

If trajectories are not split, the tensordict output by the collector has the same size as the one from the env, with a last dimension of size frames_per_batch added to it (adopting the solution from problem 1). If trajectories are split, on the other hand, the last dim is forced to be equal to max_frames_per_traj and the remainder of that dim is added to dimension 0, without even knowing what dimension 0 represents.

My opinion
In my opinion, we should get rid of the max_frames_per_traj parameter. This parameter is effectively acting as a horizon. Its logic can be implemented in the environment either naturally or via a transform (return done after x steps). This parameter has no meaning when envs are multi-dimensional, because if you are rolling out an env with batch_size (32, 26) for max_frames_per_traj steps, you are not rolling out a single trajectory, but a batch of trajectories which may terminate and be reset at whatever indexes of the batch_size. If we do get rid of this parameter (by having it in the done flag), we can get rid of split_trajectories, padding, and the related mask.

At every iteration of the collector we would then only collect frames_per_batch samples of size env.batch_size (samples need to have the dims of the env), and this would just add a final dimension to the output tensordict, which will have batch_size (*env.batch_size, frames_per_batch). This changes the current logic in that: if you have a normal env and frames_per_batch=6, you get 6 samples; if you have a parallel env with 2 workers and frames_per_batch=6, you get 2 samples of shape (6,) (so effectively 12). But this has to be done because, if we do not know what the batch_size is encoding (parallel workers, agents, or vector envs), we can never know how to count samples. The collector rollout will reset only the dimensions which are done.
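The shape accounting proposed in this paragraph can be sketched as follows (assumed semantics of the proposal, not current TorchRL behaviour; helper names are made up):

```python
import torch

# frames_per_batch counts steps per batch entry, so the output tensordict's
# batch_size is env.batch_size plus a trailing time dim.
def output_batch_size(env_batch_size: torch.Size, frames_per_batch: int) -> torch.Size:
    return torch.Size([*env_batch_size, frames_per_batch])

def total_frames(env_batch_size: torch.Size, frames_per_batch: int) -> int:
    # Total env steps actually collected per batch.
    return output_batch_size(env_batch_size, frames_per_batch).numel()

print(output_batch_size(torch.Size([]), 6), total_frames(torch.Size([]), 6))    # torch.Size([6]) 6
print(output_batch_size(torch.Size([2]), 6), total_frames(torch.Size([2]), 6))  # torch.Size([2, 6]) 12
```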
We could also do all the aforementioned and keep max_frames_per_traj (the horizon). Then, if this param is set and we want split_traj=True, the tensordict batch size will become (*env.batch_size, frames_per_batch / max_frames_per_traj, max_frames_per_traj) and stay (*env.batch_size, frames_per_batch) if split_traj=False. I advise against this solution since (as I stated before) a multi-dimensional trajectory loses the conventional meaning and thus max_frames_per_traj should not be applied to it.

PR #808 is a proof of concept for these changes. I just took the SyncDataCollector and removed the features which I think we cannot have if collectors want to be agnostic of the env.batch_size.

To Reproduce

You can reproduce with vmas or brax.