Closed: kristian-georgiev closed this issue 3 years ago.
@kristian-georgiev thanks for the issue and the thorough report!
@krzentner and @avnishn have been working with MAML lately -- perhaps they can take a look?
Here's a very cursory read:
It looks like the error happens when stacking observations from two different tasks. Within the MAML loss function, these should have shape `[batch, time, obs dimensions...]`. In this case, one task has a batch size of 1 and the other a batch size of 2, I think because you set the batch size very small in `trainer.train(n_epochs=3, batch_size=32)` (`batch_size` is measured in time steps). A batch size this small could bias training towards the overrepresented task, but it isn't logically wrong, so I think you've encountered a bug that earlier users, who use larger batch sizes, probably never triggered.
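To illustrate the mechanics, here is a minimal sketch with made-up episode lengths (PointEnv episodes can end early when the goal is reached; the real sampler logic is simplified here):

```python
# Sketch with hypothetical episode lengths: a sampler that collects
# whole episodes until a time-step budget is met can return a
# different number of episodes for each task.
def episodes_collected(episode_lengths, budget):
    """Count episodes gathered before the step budget is exhausted."""
    steps, episodes = 0, 0
    for length in episode_lengths:
        steps += length
        episodes += 1
        if steps >= budget:
            break
    return episodes

print(episodes_collected([14, 25], budget=32))  # task A: 2 episodes
print(episodes_collected([32], budget=32))      # task B: 1 episode
# The per-task observation tensors then have batch dimensions 2 and 1,
# which is exactly the 1-vs-2 mismatch described above.
```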
I am not certain (@naeioi, @krzentner, or @avnishn please check), but I think that `torch.stack` is not the right thing to do here, and we actually want `torch.cat`, which will join the batch dimensions into one big batch. I think that `torch.stack` worked previously because it's common practice for all tasks to see the same number of trajectories, in which case `torch.stack` doesn't complain about mismatched sizes in the batch dimension.
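A standalone PyTorch snippet showing the difference (shapes chosen to mirror the 1-vs-2 mismatch above; this is not garage code):

```python
import torch

# Observations from two tasks, shaped [batch, time, obs_dim].
task_a = torch.randn(2, 100, 4)  # two episodes of 100 steps each
task_b = torch.randn(1, 100, 4)  # only one episode fit in the budget

# torch.stack requires every tensor to have the same shape, so the
# mismatched batch dimensions raise a RuntimeError:
try:
    torch.stack([task_a, task_b])
except RuntimeError as err:
    print('stack failed:', err)

# torch.cat joins along the batch dimension instead, yielding one
# big batch of shape [3, 100, 4]:
print(torch.cat([task_a, task_b], dim=0).shape)
```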
I confirm that this solves the problem. Thanks for the quick response!
Great! Thanks for providing a script that demonstrated the issue, it made debugging the problem much easier.
Hi, thanks for the amazing library!
I am trying to use `MAMLVPG` with `PointEnv`. I made minimal modifications to the MAML VPG half-cheetah-dir example, but ran into a somewhat bizarre seed-dependent issue: rarely, one of the samples has multiple sets of observations and thus a different shape. I suspect the issue comes from the interaction of `PointEnv` with `SetTaskSampler`, but I am not certain what exactly causes this behavior. Below is a minimal example that reproduces this issue:
On the third epoch, this fails with a shape-mismatch error from `torch.stack`, because one task contributes a different number of observations than the others.
I am on the main branch (currently up to https://github.com/rlworkgroup/garage/commit/82b5c33ae0796489a00391f80cb94e41657f5962).
Changing the seed changes the epoch and sample index at which the bug occurs.