pytorch / rl

A modular, primitive-first, python-first PyTorch library for Reinforcement Learning.
https://pytorch.org/rl

[BUG] Safe policy modules that call _project on a smaller test env crash #1018

Closed · matteobettini closed this issue 9 months ago

matteobettini commented 1 year ago

Let's say we have 2 environments:

- a batched training env, `env`, whose `action_spec` has shape `[500,3,2]` (see below)
- a smaller `test_env` used for evaluation

We create a safe policy:

policy_module = ProbabilisticActor(
    module=policy_module,
    spec=env.action_spec,  # shape `[500,3,2]`
    in_keys=["loc", "scale"],
    distribution_class=IndependentNormal,
    safe=True,
)

and collect with it in the env.

Now, if we run `test_env.rollout(policy_module)`, we get an error when the policy tries to project its action to `test_env.action_spec`, because the module was built with `env.action_spec`.
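
For concreteness, a sketch of what goes wrong (the test_env shape here is hypothetical; only the `[500,3,2]` spec above is given):

td = test_env.rollout(10, policy_module)
# with safe=True, the wrapper calls spec.project(action) on out-of-bounds actions;
# the stored spec has shape [500, 3, 2], while the action coming from test_env
# has a different batch shape, so the projection crashes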

What is the cleanest way to solve this use case?

Proposed solution

Make the `_project` methods more flexible? Or is the solution to keep one policy for training and one for testing, with the two separate specs?

matteobettini commented 1 year ago

The problem is broader than I thought, and it is not limited to testing policies. It affects not only safe distribution policies, but also unsafe ones that use mins or maxes: these store the min and max at construction time and use them whenever forward is called.

For example, TanhNormal computes `loc = loc + (self.max - self.min) / 2 + self.min` when the distribution is rebuilt in the loss function.

If we are using a replay buffer where we reshaped the data in order to sample it, like `replay_buffer.extend(tensordict_data.reshape(-1))`,

then TanhNormal crashes with:

  File "/Users/Matteo/PycharmProjects/tensordict/tensordict/nn/probabilistic.py", line 419, in build_dist_from_params
    return self.module[-1].get_dist(tensordict)
  File "/Users/Matteo/PycharmProjects/tensordict/tensordict/nn/probabilistic.py", line 247, in get_dist
    dist = self.distribution_class(**dist_kwargs, **self.distribution_kwargs)
  File "/Users/Matteo/PycharmProjects/torchrl/torchrl/modules/distributions/continuous.py", line 370, in __init__
    self.update(loc, scale)
  File "/Users/Matteo/PycharmProjects/torchrl/torchrl/modules/distributions/continuous.py", line 376, in update
    loc = loc + (self.max - self.min) / 2 + self.min
RuntimeError: The size of tensor a (4096) must match the size of tensor b (640) at non-singleton dimension 0

TL;DR: TanhNormal and other classes store the min and max of the action_spec at init time and then, when called (for example during the loss), expect an input that matches that shape. This is very often not the case, as data will be reshaped in order to be sampled from the replay buffer.
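
A minimal sketch of the shape clash with plain tensors (the batch numbers are taken from the traceback above; the trailing feature dim of 1 is assumed):

import torch

# min/max are stored at init time with the shape of the batched action_spec
low = torch.full((640, 1), -1.0)
high = torch.full((640, 1), 1.0)

# after replay_buffer.extend(td.reshape(-1)), a sampled loc has 4096 rows
loc = torch.zeros(4096, 1)

# the same arithmetic as TanhNormal's update:
loc = loc + (high - low) / 2 + low
# RuntimeError: The size of tensor a (4096) must match the size of tensor b (640)
# at non-singleton dimension 0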

vmoens commented 1 year ago

I usually solve this by passing the spec of a non-batched env, would that work here?
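
i.e., something like this, reusing the snippet from above (a sketch; it assumes the non-batched instance `base_env` has an action spec of shape `[3,2]`, with `[500]` being the batch dim):

policy_module = ProbabilisticActor(
    module=policy_module,
    spec=base_env.action_spec,  # shape [3, 2]: no leading batch dimension
    in_keys=["loc", "scale"],
    distribution_class=IndependentNormal,
    safe=True,
)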

matteobettini commented 1 year ago

For the first issue, yes: we could use test_env.input_spec for the batched env too. But this is not super clear IMO.

For the second issue, that is not applicable, I think.

vmoens commented 1 year ago

I'm not sure what your feature dimension is, but what I meant is not having a spec of shape [1, 3, 2] but rather [3, 2] or [2], depending on what the features are. Happy to consider other solutions if that does not fit. We could implement a `__getitem__` in the specs to be able to reduce it easily:

policy = Actor(..., spec=my_spec[0, 0])

and document that a spec given to a safe module should be non-batched (not necessarily empty, but its shape should only cover feature dimensions).

matteobettini commented 1 year ago

Yes, I think we need to demarcate the batch_size from the feature_size in specs, especially thinking about when we will have to keep track of the multi-agent dimension in the specs.

To reproduce the bug I am referring to, which cannot be fixed at actor-creation time, add the following lines to the PPO tutorial at line 320:

temp_env = SerialEnv(3, lambda: env)
temp_env.start()
env = temp_env
env.action_spec.space.minimum[:] = -0.5

You will see the bug I referred to in TanhNormal.

Without the last line, this script works; the last line (which modifies the min) takes us out of a special case in TanhNormal.

vmoens commented 1 year ago

I ran this

import torch
from torchrl.envs import SerialEnv

temp_env = SerialEnv(3, lambda: env)
temp_env.start()
env = temp_env
env.action_spec.space.minimum[:] = -0.5
with torch.no_grad():
    env.rollout(3, policy_module)

with policy_module created as in the script, and got no error; I'm not sure what is supposed to go wrong. Do you have a self-contained example?

matteobettini commented 1 year ago

import torch
from torch import nn
from tensordict.nn import NormalParamExtractor, TensorDictModule
from torchrl.envs import Compose, DoubleToFloat, SerialEnv, TransformedEnv
from torchrl.envs.libs.gym import GymEnv
from torchrl.modules import ProbabilisticActor, TanhNormal

base_env = GymEnv("InvertedDoublePendulum-v4")
base_env = TransformedEnv(
    base_env,
    Compose(
        DoubleToFloat(
            in_keys=["observation"],
        ),
    ),
)
env = SerialEnv(3, lambda: base_env)
env.start()
env.action_spec.space.minimum[:] = -0.5  # modifying the min takes us out of TanhNormal's special case

actor_net = nn.Sequential(
    nn.LazyLinear(2 * env.action_spec.shape[-1]),
    NormalParamExtractor(),
)
policy_module = TensorDictModule(
    actor_net, in_keys=["observation"], out_keys=["loc", "scale"]
)
policy_module = ProbabilisticActor(
    module=policy_module,
    spec=env.action_spec,
    in_keys=["loc", "scale"],
    distribution_class=TanhNormal,
    distribution_kwargs={
        "min": env.action_spec.space.minimum,  # shape [3, 1]: stored with the batch dim
        "max": env.action_spec.space.maximum,
    },
    return_log_prob=True,
    # we'll need the log-prob for the numerator of the importance weights
)

policy_module(env.reset())  # initialize the lazy layer

td = env.rollout(10)  # batch_size [3, 10]

data_view = td.reshape(-1)  # batch_size [30]

dist = policy_module.get_dist(data_view)  # crashes: loc has shape [30, 1], stored min has shape [3, 1]

matteobettini commented 1 year ago

The penultimate line (the reshape) is the one that causes the bug.

And we need that reshape to have randomness in the replay buffer.

vmoens commented 1 year ago

You can do

    distribution_kwargs={
        "min": base_env.action_spec.space.minimum,
        "max": base_env.action_spec.space.maximum,
    },

which solves the issue. Giving the module the specs of the batched env and then passing a tensordict of a different shape is expected to fail; that is not a legitimate thing to do IMO.

And we need that reshape to have randomness in the replay buffer

What does that mean?

matteobettini commented 1 year ago

What does that mean?

I thought that if you pass a td of shape [B, T] to the replay buffer, you are not mixing up trajectories. So I thought that the reason we do data_view = td.reshape(-1) in the PPO tutorial is to mix trajectories.

    distribution_kwargs={
        "min": base_env.action_spec.space.minimum,
        "max": base_env.action_spec.space.maximum,
    },

Environments which have a batch_size in their default version, like vmas (which is vectorized), cannot do this. I would have to do `"min": base_env.action_spec.space.minimum[0]`.

vmoens commented 1 year ago

I'm not sure I get it all; I would need more details. You have an environment with a batch size of [B]. It has specs of batch size [B, *R]. This returns tensordicts of batch size [B, T] or possibly [B, T, *R].

The issue is that you provide a spec of size [B, *R] and then squash the dims into [B * T * R]; that is not expected to work. If you could squash dims only into [B * T, R] and pass specs of shape [*R], it would be better. In general, I thought our convention was that all batch dims that go before the time dimension are irrelevant to MARL/VMAS, no?

To take your example, maybe instead of td.view(-1) just do td.flatten(0, 1) and keep the last dim.
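
To illustrate the difference, a sketch on a stand-in multi-agent rollout with tensordict batch size [B, T, n_agents] = [64, 32, 4] (the shapes and the 2-action feature dim are assumptions):

import torch
from tensordict import TensorDict

td = TensorDict(
    {"action": torch.zeros(64, 32, 4, 2)},
    batch_size=[64, 32, 4],
)

flat_all = td.reshape(-1)   # batch_size [8192]: the agent dim is folded in, so a
                            # spec of shape [4, 2] no longer lines up with the data
flat_bt = td.flatten(0, 1)  # batch_size [2048, 4]: the agent dim survives, so a
                            # spec of shape [4, 2] still broadcasts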

If you have another example without SerialEnv where things break, I'd be happy to give it a look.

matteobettini commented 1 year ago

This returns tensordicts of batch size [B, T] or possibly [B, T, *R]

[B,T]

If you could squash dims only in [B*T, R] and pass specs of shape [*R] it would be better.

Yep, that is what I am doing. The only thing is that if you have a spec of shape [B, R], there is no function in the specs to get R: you have to get the batch_size B from outside the spec (e.g. from the env) to then know what *R is.

I can just manually get [*R] from the specs; that fixes the issue.
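
The manual version looks something like this (a sketch, reusing env from the repro above; note that the batch size indeed has to come from the env, not from the spec):

n_batch = len(env.batch_size)                        # 1 for the SerialEnv(3) repro
feature_shape = env.action_spec.shape[n_batch:]      # [*R] = [1]
low = env.action_spec.space.minimum[(0,) * n_batch]  # tensor of shape [*R]
high = env.action_spec.space.maximum[(0,) * n_batch] # tensor of shape [*R]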

In general, I thought our convention was that all batches that go before the time dimension are irrelevant to MARL/VMAS no?

Yes, they are irrelevant to MARL. vmas has one dimension before the time one (in the batch), which is the number of envs (like in brax).

For the example think of brax (in the following pseudocode)

env = BraxEnv("ant", batch_size=(64, 32))
env.action_spec.shape  # (64, 32, n_actions)
env.action_spec.space.minimum[:] = -0.5

distribution_kwargs={
    "min": env.action_spec.space.minimum,  # this will not work, I first need to remove the batch size
    "max": env.action_spec.space.maximum[0, 0],  # this will work, the batch size is gone from the tensor
},

Does this make sense?

vmoens commented 1 year ago

Yeah, I think that here the best thing is to implement a `__getitem__` for the specs, such that you can provide the spec with the minimal shape to the module / dist kwargs. Other than that, I'm not sure how to deal with this.

matteobettini commented 1 year ago

Having a separation in specs between batch_dim and feature_dim is not possible, right? Something like the distributions in torch.

vmoens commented 1 year ago

Why not just `__getitem__`? This would solve the problem as well, without adding an extra level of complexity. If we do introduce 2 types of shapes (which always confused me with the distributions), we're asking users to learn more than they need for most use cases. Then we'll eventually need to port that to the envs because it's the natural thing to do, and then someone will ask for it in tensordict too...

I'm a bit afraid that this would be a considerable amount of work that will never end.

matteobettini commented 1 year ago

Yeah, I just wanted to know what you thought about it, and this makes sense. Although I would say that tensordict has this a bit already: we use batch_size to limit reshaping operations and the like to those dimensions.

But I agree, let's not add complexity.

matteobettini commented 1 year ago

Update on this:

with spec indexing we would be able to do

with no_batch_size():
    env.action_spec

using a function like

def _get_spec_without_batch(self, spec_name):
    spec = getattr(self, spec_name)
    idx = (0,) * len(self.batch_size)
    return spec[idx]

We could then feed these unbatched specs to safe modules.
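
For instance (a sketch, reusing the repro above together with the helper proposed here):

spec = env._get_spec_without_batch("action_spec")  # shape [1] instead of [3, 1]
policy_module = ProbabilisticActor(
    module=policy_module,
    spec=spec,
    in_keys=["loc", "scale"],
    distribution_class=TanhNormal,
    distribution_kwargs={"min": spec.space.minimum, "max": spec.space.maximum},
    safe=True,
)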

matteobettini commented 1 year ago

Spec indexing introduced in #1105

vmoens commented 9 months ago

Closed by #1105