pytorch / rl

A modular, primitive-first, python-first PyTorch library for Reinforcement Learning.
https://pytorch.org/rl
MIT License

[Feature Request] Batched specs of heterogeneous shape and related stacked tensordicts #766

Open vmoens opened 1 year ago

vmoens commented 1 year ago

Motivation

In multiagent settings, each agent's individual spec can differ. It would be nice to have a way of building heterogeneous composite specs, and to carry data using tensordict following this logic.

Solution

  1. StackedCompositeSpec: We could use a `StackedCompositeSpec` that would essentially work as a tuple of boxes in gym:

Constructor 1

```python
input_spec = StackedCompositeSpec(
    action=[NdUnboundedTensorSpec(-3, 3, shape=[3]), NdUnboundedTensorSpec(-3, 3, shape=[5])]
)
```

Constructor 2

```python
input_spec = StackedCompositeSpec(
    [
        CompositeSpec(action=NdUnboundedTensorSpec(-3, 3, shape=[3])),
        CompositeSpec(action=NdUnboundedTensorSpec(-3, 3, shape=[5])),
    ]
)
```

This would basically mean that the environment expects an action of shape Size([3]) for the first agent and Size([5]) for the second.

  2. Allowing `LazyStackedTensorDict` to host tensors of different shape across a dimension

That way we could carry data in ParallelEnv and in the collector while keeping the key with mixed attributes visible to the users. One could also access a nestedtensor, provided that no more than one LazyStackedTensorDict layer is used (as we can't currently build nested nested tensors).
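For concreteness, a minimal sketch of what point 2 could look like, assuming a `tensordict`-style `TensorDict` API and the `get_nestedtensor` accessor discussed later in this thread (neither heterogeneous lazy stacking nor that accessor existed at the time of writing):

```python
import torch
from tensordict import TensorDict  # assumption: the standalone tensordict package

# Two agents with differently shaped actions
td0 = TensorDict({"action": torch.zeros(3)}, batch_size=[])
td1 = TensorDict({"action": torch.zeros(5)}, batch_size=[])

# Proposed: lazy stacking keeps both shapes around instead of erroring out
stacked = torch.stack([td0, td1], dim=0)  # LazyStackedTensorDict, batch_size=[2]

# Proposed accessor: retrieve the mixed-shape entry as a torch nested tensor,
# as long as only one lazy-stacking level is involved
nt = stacked.get_nestedtensor("action")
```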

TensorDictSequential is already capable of handling lazy stacked tensordicts that have different keys. We could also think about allowing it to gather tensors that do not share the same shape, for instance, although this is harder to implement as not every module has a precise signature of the input tensor it expects.

cc @matteobettini

matteobettini commented 1 year ago

action: Tensor([3, ...], dtype=torch.float32)

Where did the 3 come from here? Shouldn't it be a 2 (as the first dimension comes from the fact that we stacked 2 tensors)? Or am I missing something?

Apart from this the issue looks perfect. Thank you for transcribing this.

vmoens commented 1 year ago

Edited :)

vmoens commented 1 year ago

Centralizing comments from https://github.com/pytorch/rl/issues/784#issuecomment-1370960512 and https://github.com/pytorch/rl/issues/784#issuecomment-1371026324

The API I would propose will look like this: We would introduce 2 new classes, one for keyed specs and one for plain specs.

```python
from torchrl.data import UnboundedContinuousTensorSpec, CompositeSpec, StackedCompositeSpec, StackedSpec

spec1 = UnboundedContinuousTensorSpec(shape=[3, 4])
spec2 = UnboundedContinuousTensorSpec(shape=[3, 4])

c_spec = StackedCompositeSpec(
    CompositeSpec(action=spec1),
    CompositeSpec(action=spec2)
)

spec = c_spec["action"]
assert isinstance(spec, StackedSpec)
spec.rand()  # returns a nestedtensor
ntensor = spec.zero()  # returns a nestedtensor
spec.is_in(ntensor)  # works
spec.is_in(ntensor.unbind(0))  # works

c_spec.rand()  # returns a LazyStackedTensorDict
lstack = c_spec.zero()  # returns a LazyStackedTensorDict
c_spec.is_in(lstack)

print(spec)
# StackedSpec(shape=torch.Size([2, 3, *]), device="cpu", dtype=torch.float32)
print(c_spec)
# StackedCompositeSpec(
#     "action": StackedSpec(shape=torch.Size([2, 3, *]), device="cpu", dtype=torch.float32),
# )
```

Another question to solve is that we sometimes do stuff like

random_action = env.action_spec.rand(env.batch_size)

to get as many actions as there are envs. Here, the batch size will be equal to the first dimension of the stacked spec. How do we want to deal with this @matteobettini?

Also, what do you think of this feature? If you're happy with it I can assign someone on it.

matteobettini commented 1 year ago

First of all thanks so much for recapping this! It looks good and matches what we thought.

Here are some points:

  1. If your example was:

```python
spec1 = UnboundedContinuousTensorSpec(shape=[3, 4])
spec2 = MultOneHotDiscreteTensorSpec(shape=[2, 6])

c_spec = StackedCompositeSpec(
    CompositeSpec(action=spec1),
    CompositeSpec(action=spec2)
)
```

Would this be allowed?

2. 
> ```
> random_action = env.action_spec.rand(env.batch_size)
> ```

I think the issue you brought up here is extremely important and thanks for spotting this. I'll try to give an example to see if I got it.

Environments (for example vmas) can reach very complex batch sizes. When vmas is used in `ParallelEnv` and rolled out, for example, it has a `batch_size` of

```python
(n_parallel_workers,
 n_agents,  # the dimension which can be heterogeneous
 n_vectorized_envs,
 n_rollout_samples)
```

Now, if you ask a StackedCompositeSpec to give you a sample, it will give a LazyStack with batch dim `c_spec.rand().batch_size == (n_agents,)`.

What just happened is that we leaked part of the batch_size into the spec.

If we then call what you mentioned (`env.action_spec.rand(env.batch_size)`), we have the agent dimension twice!
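To make the duplication concrete, here is a small sketch of the shape bookkeeping (the sizes are illustrative only):

```python
# Illustrative sizes for the batch_size described above
n_workers, n_agents, n_vec_envs, n_samples = 2, 3, 32, 10

# The stacked spec already carries the (heterogeneous) agent dimension:
#   c_spec.rand().batch_size == (n_agents,)
#
# so prepending env.batch_size when sampling repeats that dimension:
#   env.action_spec.rand(env.batch_size).batch_size
#   == (n_workers, n_agents, n_vec_envs, n_samples, n_agents)
#       ------------- env.batch_size --------------  -- leaked from the spec
```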

I have been thinking about this issue for some time and I can think of a few ways to tackle it, but I still haven't found a great one:

The problem here is that to get the nested tensor you have to have the heterogeneous dim as the first one, but this should be fine because you can do

```python
random_action_per_worker = random_action[0]
random_action_per_worker.stack_dim  # 0
random_action_per_worker.get_nestedtensor("action")  # Success!
```

So overall I think this might be the best solution (given that the last snippet is something feasible), but I think we are facing a spicy issue and have to act carefully.

vmoens commented 1 year ago

The more I think about this the more I'm convinced that batch_size is trying to solve two problems at the same time:

Here are two solutions, one being less disruptive than the other:

Leading shapes of specs must have the shape of env.batch_size

Right now it is assumed that the shape of the specs is unrelated to the batch size. For batched envs, we could say

env.observation_spec = stack([base_env.observation_spec for base_env in base_envs], 0)

(right now this would not work but you get the gist).

That way we would not need to do env.action_spec.rand(env.batch_size) anymore, since we would know that env.action_spec already has the right batch size.
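A rough sketch of what option 1 could look like (as noted above, stacking specs did not work at the time of writing, so this is the proposed behaviour rather than working code, and the sub-envs are stand-ins):

```python
import torch
from torchrl.envs.libs.gym import GymEnv  # assumption: gym envs as stand-in sub-envs

# Hypothetical sub-envs making up the batched env
base_envs = [GymEnv("Pendulum-v1"), GymEnv("Pendulum-v1")]

# Proposed: specs become stackable like tensors
stacked_action_spec = torch.stack(
    [base_env.action_spec for base_env in base_envs], dim=0
)

# The spec's leading dim now equals the env batch size, so sampling no longer
# needs an explicit shape argument:
random_action = stacked_action_spec.rand()  # leading dim == len(base_envs)
```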

Here are some use cases:

Creating an env.shape attribute that completes env.batch_size

In this scenario, env.shape dictates the shape of the specs: again, their leading dimensions must match the shape of the env.

We would then keep the batched envs as they are, and their batch_size need not match the leading dims of the specs.
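Sketched in terms of attributes (the names and example shapes here are hypothetical, following the proposal above and the shape split discussed further down):

```python
# Hypothetical attributes under option 2:
#   env.batch_size  -> used by ParallelEnv / collectors, e.g. (n_workers, n_rollout_samples)
#   env.shape       -> dictates the leading dims of the specs, e.g. (n_vec_envs, n_agents)
#
# Invariant the specs would have to satisfy:
#   env.action_spec.shape[: len(env.shape)] == env.shape
```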

Here are some use cases:

My opinion

IMO the second option is less clean and more disruptive than the first (also harder to understand and the boundary between the two is not super clear)

 Action items

In both scenarios, the first step would be:

 Final point about parallel execution

With envs like brax or model-based, the batch_size need not be decided upfront: you can pass an input tensordict of shape [X, Y, Z] and the env will figure out that this is your current batch size. Not sure how to make this work with parallel envs, since we create buffers of definite shapes to pass data from proc to proc. The solution may be not to use buffers at all? I still don't really see what the use case is though: one of the advantages of MB or brax is that the users will sometimes want to get the grads of the operations executed by the env. ParallelEnv cannot carry gradients. Maybe in those cases we should design some other object? e.g. If we have one env on cuda:0 and the other on cuda:1 (alternatively: cpu:0 and cpu:1) we can execute them in multithreading (not multiprocessing) and we don't need buffers anymore.

ccing some folks with expertise in MARL for advice: @xiaomengy @eugenevinitsky @PaLeroy @XuehaiPan @KornbergFresnel @PKU-YYang

matteobettini commented 1 year ago

I also have a preference for the first solution. I think it is particularly nice because the specs get even closer to tensors and tensordicts.

Am I right in thinking that the second option, with rollouts and parallel envs, would be:

multi-agent, vectorized, batched, rolled out: batch_size=[batch, n_rollouts], shape=[n_vec_envs, n_agents]

vmoens commented 1 year ago

That's up to us to decide, but it'll be confusing no matter what.

With the first we can stick to the motto "time dim at the end"

matteobettini commented 1 year ago

There is one doubt I had for a while in my mind:

In environments that have a heterogeneous dimension, that dimension has to be the first one of the batch_size in order to get the corresponding nested tensors. But when such an environment is wrapped in a ParallelEnv, the number of parallel environments is prepended to the batch_size.

I was wondering if this can be done in heterogeneous LazyStackedTensorDicts.

i.e.

```python
env = VmasEnv("simple_crypto", num_envs=32)
env.batch_size  # (n_agents, n_vec_envs)
env.rollout(10)["action"].shape      # (n_agents, n_vec_envs, 10, *)
env.rollout(10)["action"].stack_dim  # 0

env = ParallelEnv(2, lambda: VmasEnv("simple_crypto", num_envs=32))
env.batch_size  # (2, n_agents, n_vec_envs)
env.rollout(10)["action"].shape      # (2, n_agents, n_vec_envs, 10, *)
env.rollout(10)["action"].stack_dim  # 1
```

Are lazy stacked tensordicts already able to support this, i.e. prepending a dimension to their batch_size?

vmoens commented 1 year ago

You'll get a lazy stack of lazy stacks, yes. That should work
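A minimal sketch of that nesting, assuming heterogeneously shaped entries can be stacked lazily (the feature requested in this issue):

```python
import torch
from tensordict import TensorDict  # assumption: the standalone tensordict package

def make_worker_td():
    # Two agents with differently shaped actions, stacked along the agent dim
    return torch.stack(
        [TensorDict({"action": torch.zeros(3)}, []),
         TensorDict({"action": torch.zeros(5)}, [])],
        dim=0,
    )

# ParallelEnv-like wrapping: stack the per-worker lazy stacks along a new
# leading dim, giving a lazy stack of lazy stacks
batched = torch.stack([make_worker_td(), make_worker_td()], dim=0)
# batched.batch_size == (2, 2); the heterogeneous (agent) dim is now dim 1
```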

PaLeroy commented 1 year ago

Hey guys! From what I understood, I prefer the first option as it is more intuitive for me.

Also, I have a question, maybe out of context. I'm just wondering if you plan to identify the agents in the spec and elsewhere? I mean, for the step function, some MARL envs rely on an action dict {"agent1_id": action1, "agent2_id": action2} while others rely on a list of actions [action1, action2], identifying agents by their index in the list. I feel like with these solutions, if I am not wrong, you discard the idea of "custom identification". In the end, having a custom name per agent instead of an integer in a range does not change much, but it may be more convenient for some.

matteobettini commented 1 year ago

@PaLeroy

This is a very interesting point you bring up, thanks! This is done in multi-agent libraries such as PettingZoo (https://github.com/Farama-Foundation/PettingZoo).

The first and simplest solution is that an environment could keep an index_to_name list, whose length is the number of agents, which people can query to retrieve the name from an index.
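A tiny sketch of that mapping (the attribute and agent names are made up for illustration):

```python
# Hypothetical per-env mapping from agent index to agent name
index_to_name = ["agent_0", "agent_1", "adversary_0"]  # len == n_agents

name = index_to_name[1]                   # "agent_1"
idx = index_to_name.index("adversary_0")  # 2
```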

More complex things can be thought of, but they would look less nice. For example, one could think of making observation a composite spec with the entry keys being the agent names. This would bring the n_agents dimension outside the batch_size. The same thing could also be done for action and reward, but "done", for example, would have issues as it cannot be composite, so you could not define a per-agent done.

I think keeping everything in tensors (eventually lazy stacks if agents are heterogeneous) would be cleaner and lets us benefit from all the goodies of torchrl.

PaLeroy commented 1 year ago

I like this idea. I also thought of more complex stuff but things get easily much more complicated.

vmoens commented 1 year ago

Yep, +1 on having per-env dicts that link index to name. It's also something we had in nocturne and it's not the easiest to handle. We should use (and document) "lists" of envs where each has a numerical index. If this number can change, the list has a max number of agents and each env is assigned one and only one place in the list. If they have dedicated names, we store a dict name -> id or id -> name.

matteobettini commented 1 year ago

we can close