Closed matteobettini closed 9 months ago
The problem is more extensive than I thought, and it does not only affect testing policies.
It affects not only safe distribution policies but also unsafe ones that use mins or maxes. These store min and max and use them when forward is called.
For example, TanhNormal computes loc = loc + (self.max - self.min) / 2 + self.min when it is called in the loss function.
If we are using a replay buffer where we reshaped the data in order to sample it, like replay_buffer.extend(tensordict_data.reshape(-1)), then TanhNormal crashes with
File "/Users/Matteo/PycharmProjects/tensordict/tensordict/nn/probabilistic.py", line 419, in build_dist_from_params
return self.module[-1].get_dist(tensordict)
File "/Users/Matteo/PycharmProjects/tensordict/tensordict/nn/probabilistic.py", line 247, in get_dist
dist = self.distribution_class(**dist_kwargs, **self.distribution_kwargs)
File "/Users/Matteo/PycharmProjects/torchrl/torchrl/modules/distributions/continuous.py", line 370, in __init__
self.update(loc, scale)
File "/Users/Matteo/PycharmProjects/torchrl/torchrl/modules/distributions/continuous.py", line 376, in update
loc = loc + (self.max - self.min) / 2 + self.min
RuntimeError: The size of tensor a (4096) must match the size of tensor b (640) at non-singleton dimension 0
TL;DR: TanhNormal and other classes store the min and max of the action_spec at init time and then, when called (for example during the loss), expect an input that matches that shape. This is very often not the case, as data will be reshaped in order to be sampled from the replay buffer.
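The shape mismatch in the traceback can be reproduced without any tensors, using only the usual right-aligned broadcasting rule. The sketch below is a minimal illustration, not torchrl code: the helper broadcast_shape is hypothetical, the sizes 4096 and 640 are taken from the error message, and n_actions = 1 is an assumption.

```python
from itertools import zip_longest


def broadcast_shape(a, b):
    """Return the broadcast of two shapes, or raise ValueError if incompatible.

    Uses the right-aligned rule: two sizes are compatible when they are
    equal or when one of them is 1.
    """
    out = []
    for x, y in zip_longest(reversed(a), reversed(b), fillvalue=1):
        if x != y and x != 1 and y != 1:
            raise ValueError(f"cannot broadcast {tuple(a)} with {tuple(b)}")
        out.append(y if x == 1 else x)
    return tuple(reversed(out))


n_actions = 1                        # assumed feature size
loc_shape = (4096, n_actions)        # flattened samples drawn from the buffer
batched_bounds = (640, n_actions)    # min/max copied from a batched action spec
unbatched_bounds = (n_actions,)      # min/max with feature dimensions only

try:
    broadcast_shape(loc_shape, batched_bounds)
    batched_ok = True
except ValueError:
    batched_ok = False

unbatched_result = broadcast_shape(loc_shape, unbatched_bounds)
```

Bounds that carry only feature dimensions broadcast against any leading batch shape, whereas bounds that keep the env batch size only match data with exactly that batch size.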
I usually solve this by passing the spec of a non-batched env, would that work here?
For the first issue, yes, we could use the test_env.input_spec also for the batched one. But this is not super clear IMO.
For the second issue, that is not applicable, I think.
I'm not sure what your feature dimension is, but what I meant is not having a spec of shape [1, 3, 2] but more like [3, 2] or [2], depending on what the features are.
Happy to consider other solutions if that does not fit. We could implement a __getitem__ in the specs to be able to reduce it easily:
policy = Actor(..., spec=my_spec[0, 0])
and document that a spec given to a safe module should be non-batched (its shape is not necessarily empty, but it should only contain feature dimensions).
Yes, I think we need to separate the batch_size from the feature_size in specs, especially thinking about when we will have to keep track of the multi-agent dimension in the specs.
To reproduce the bug I am referring to, which cannot be fixed at actor creation time, add the following lines in the ppo_tutorial at line 320
temp_env = SerialEnv(3,lambda: env)
temp_env.start()
env = temp_env
env.action_spec.space.minimum[:] = -0.5
You will see the bug I referred to in TanhNormal.
Without the last line, this script works, but the last line (which modifies the min) takes us out of a special case in TanhNormal.
I ran this
from torchrl.envs import SerialEnv
temp_env = SerialEnv(3,lambda: env)
temp_env.start()
env = temp_env
env.action_spec.space.minimum[:] = -0.5
with torch.no_grad():
    env.rollout(3, policy_module)
using the policy_module created in the script, and got no error. I'm not sure what is supposed to go wrong. Do you have a self-contained example?
from torch import nn
from tensordict.nn import TensorDictModule
from tensordict.nn.distributions import NormalParamExtractor
from torchrl.envs import Compose, DoubleToFloat, SerialEnv, TransformedEnv
from torchrl.envs.libs.gym import GymEnv
from torchrl.modules import ProbabilisticActor, TanhNormal

base_env = GymEnv("InvertedDoublePendulum-v4")
base_env = TransformedEnv(
    base_env,
    Compose(
        # normalize observations
        DoubleToFloat(
            in_keys=["observation"],
        ),
    ),
)
# batched env: its action spec carries a leading batch dimension of 3
env = SerialEnv(3, lambda: base_env)
env.start()
env.action_spec.space.minimum[:] = -0.5
actor_net = nn.Sequential(
    nn.LazyLinear(2 * env.action_spec.shape[-1]),
    NormalParamExtractor(),
)
policy_module = TensorDictModule(
    actor_net, in_keys=["observation"], out_keys=["loc", "scale"]
)
policy_module = ProbabilisticActor(
    module=policy_module,
    spec=env.action_spec,
    in_keys=["loc", "scale"],
    distribution_class=TanhNormal,
    distribution_kwargs={
        # batched bounds: this is what later clashes with reshaped data
        "min": env.action_spec.space.minimum,
        "max": env.action_spec.space.maximum,
    },
    return_log_prob=True,
    # we'll need the log-prob for the numerator of the importance weights
)
policy_module(env.reset())
td = env.rollout(10)
data_view = td.reshape(-1)
dist = policy_module.get_dist(td)
The penultimate line (the reshape) is the one that causes the bug.
And we need that to have randomness in the replay buffer.
You can do
distribution_kwargs={
"min": base_env.action_spec.space.minimum,
"max": base_env.action_spec.space.maximum,
},
which solves the issue. Giving the specs of the batched env and passing a td of a different shape is expected to fail, but that is not a legitimate thing to do IMO.
> And we need that to have randomness in the replaybuffer
What does that mean?
> What does that mean?
I thought that if you pass a td of shape [B,T] to the replay buffer, you are not mixing up trajectories.
So I thought that the reason why we do data_view = td.reshape(-1) in the ppo tutorial is to mix trajectories.
> distribution_kwargs={ "min": base_env.action_spec.space.minimum, "max": base_env.action_spec.space.maximum, },
Environments which have a batch_size in their default version like vmas (which is vectorized) cannot do this.
I would have to do "min": base_env.action_spec.space.minimum[0]
I'm not sure I get it all, I would need more details
You have an environment with a batch-size of [B].
It has specs of batch size [B, *R].
This returns tensordicts of batch size [B, T] or possibly [B, T, *R].
The issue is that you provide a spec of size [B, *R] and then squash the dims into [B * T * R]; that is not expected to work.
If you could squash dims only into [B*T, R] and pass specs of shape [*R], it would be better. In general, I thought our convention was that all batches that go before the time dimension are irrelevant to MARL/VMAS, no?
To take your example, maybe instead of td.view(-1) just do td.flatten(0, 1) and keep the last dim.
If you have another example without SerialEnv where things break, I'd be happy to give it a look.
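The difference between td.view(-1) and td.flatten(0, 1) suggested above comes down to which batch dimensions get merged. The sketch below only does the shape arithmetic; flatten_shape is a hypothetical helper, and the values of B, T and R are illustrative.

```python
from math import prod


def flatten_shape(shape, start_dim, end_dim):
    """Shape obtained by merging dims start_dim..end_dim (inclusive) into one."""
    merged = prod(shape[start_dim:end_dim + 1])
    return tuple(shape[:start_dim]) + (merged,) + tuple(shape[end_dim + 1:])


B, T, R = 3, 10, 4  # envs, time steps, one extra trailing batch dim (illustrative)
batch_size = (B, T, R)

# Merging every dim, as reshape(-1) / view(-1) would: R is lost in the batch.
fully_flat = flatten_shape(batch_size, 0, len(batch_size) - 1)

# Merging only env and time dims, as flatten(0, 1) would: R is kept last.
keep_last = flatten_shape(batch_size, 0, 1)
```

With the trailing dimension preserved, bounds of shape [*R, ...] still line up with the data after flattening.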
> This returns tensordicts of batch size [B, T] or possibly [B, T, *R]
[B,T]
> If you could squash dims only in [B*T, R] and pass specs of shape [*R] it would be better.
Yep, that is what I am doing. The only thing is that if you have a spec of shape [B,R], there is no function in the specs to get R. You have to get the batch_size B from outside the spec (e.g. from the env) to then know what *R is.
I can just manually get [*R] from the specs; that fixes the issue.
> In general, I thought our convention was that all batches that go before the time dimension are irrelevant to MARL/VMAS no?
Yes, they are irrelevant to MARL. vmas has one dimension before the time one (in the batch), which is the number of envs (like in brax).
For the example, think of brax (in the following pseudocode):
env = BraxEnv("ant", batch_size=(64,32))
env.action_spec.shape # (64,32,n_actions)
env.action_spec.space.minimum[:] = -0.5
distribution_kwargs={
    "min": env.action_spec.space.minimum, # this will not work, I first need to remove the batch size
    "max": env.action_spec.space.maximum[0,0], # this will work, batch size is gone from the spec
},
Does this make sense?
Yeah, I think that here the best thing is to implement a __getitem__ for the specs such that you can provide the spec with the minimal shape to the module / dist kwargs. Other than that, I'm not sure how to deal with this.
Having a separation in specs between batch_dim and feature_dim is not possible, right? Something like distributions in torch.
Why not just __getitem__?
This would solve the problem as well without adding an extra level of complexity.
If we do introduce 2 types of shapes (which always confused me with the distributions) we're asking the users to learn more than they need for most use cases. Then we'll eventually need to port that to the envs bc it's the natural thing to do, then someone will ask for that in tensordict too...
I'm a bit afraid that this will be a considerable amount of work that will never end.
Yeah, I just wanted to know what you thought about it, and this makes sense. Although I would say that tensordict has this a bit already: we use batch_size to limit reshaping and similar operations to those dimensions.
But I agree, let's not add complexity
Update on this:
with spec indexing we would be able to do
with no_batch_size():
    env.action_spec
using a function like
def _get_spec_without_batch(self, spec_name):
    spec = getattr(self, spec_name)
    # one 0 index per batch dimension, so only feature dimensions remain
    idx = (0,) * len(self.batch_size)
    return spec[idx]
We could then feed these unbatched specs to safe modules.
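To make the indexing idea concrete, here is a minimal sketch built on a toy spec class. ToyBoundedSpec, full and spec_without_batch are hypothetical names and not the torchrl API; the bounds are stored as nested lists, and the sketch assumes they are identical across the batch, so that taking index 0 along every batch dimension loses nothing.

```python
from dataclasses import dataclass


@dataclass
class ToyBoundedSpec:
    """Hypothetical stand-in for a bounded spec; not the torchrl class."""
    shape: tuple
    minimum: list  # nested lists laid out according to `shape`
    maximum: list

    def __getitem__(self, idx):
        if not isinstance(idx, tuple):
            idx = (idx,)
        lo, hi = self.minimum, self.maximum
        for i in idx:  # integer indices only: each one drops a leading dim
            lo, hi = lo[i], hi[i]
        return ToyBoundedSpec(self.shape[len(idx):], lo, hi)


def spec_without_batch(spec, batch_size):
    """Index every batch dimension at 0, leaving only the feature dimensions."""
    return spec[(0,) * len(batch_size)]


def full(shape, value):
    """Build a nested list of the given shape filled with `value`."""
    if not shape:
        return value
    return [full(shape[1:], value) for _ in range(shape[0])]


batch_size = (64, 32)  # same batch size as the brax pseudocode
n_actions = 8          # assumed number of actions
shape = batch_size + (n_actions,)
batched = ToyBoundedSpec(shape, full(shape, -0.5), full(shape, 1.0))
unbatched = spec_without_batch(batched, batch_size)
```

The unbatched spec has shape (n_actions,), which is the form that can be handed to a safe module or to the distribution kwargs regardless of how the data is later reshaped.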
Spec indexing introduced in #1105
Closed by #1105
Let's say we have 2 environments:
env for collection, with action spec of shape [500,3,2] and batch_size [500]
test_env for testing, with action spec of shape [1,3,2] and batch_size [1]
We create a safe policy and collect with it in the env.
Now, if we run test_env.rollout(policy_module) we get an error when the policy tries to project its action to the test_env.action_spec, because it was created using env.action_spec.
What is the cleanest way to solve this use case?
Proposed solution
Make the project methods more flexible? Or is the solution to keep one policy for training and a different one for testing, each with its own spec?