YeWR / EfficientZero

Open-source codebase for EfficientZero, from "Mastering Atari Games with Limited Data" at NeurIPS 2021.
GNU General Public License v3.0

Question about the dynamics network #17

Open henrycharlesworth opened 2 years ago

henrycharlesworth commented 2 years ago

Hi,

I was just wondering if you could explain/give some motivation for why the dynamics network works as it does.

I'm looking at a simple Atari example, and when I'm inside: def dynamics(self, encoded_state, reward_hidden, action):

the encoded state is [2, 64, 6, 6] (batch size of 2, just as a test), and the actions tensor is [2, 1] (integers between 1 and 4).

You then define "actions_one_hot" as torch.ones(2, 1, 6, 6) and compute actions_one_hot = actions[:, :, None, None] * actions_one_hot / self.action_space_size, which gives actions_one_hot as [2, 1, 6, 6], with the values copied along the final two dimensions (so each action value is copied 36 times here). Then you concatenate with the encoded state along dim=1 to give a final state which is [2, 65, 6, 6].
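To make sure I'm reading it right, here is a minimal standalone sketch of the tensor manipulation I described above (tensor names and action_space_size are taken from the discussion; the random encoded state is just a stand-in):

```python
import torch

batch, channels, h, w = 2, 64, 6, 6
action_space_size = 4

encoded_state = torch.randn(batch, channels, h, w)  # [2, 64, 6, 6]
actions = torch.tensor([[1], [3]])                  # [2, 1]

# Fill a [2, 1, 6, 6] plane with ones, then scale each plane by its
# normalized action value -- every spatial cell holds the same number.
actions_one_hot = torch.ones(batch, 1, h, w)
actions_one_hot = actions[:, :, None, None] * actions_one_hot / action_space_size

# Concatenate along the channel dimension -> [2, 65, 6, 6]
state_action = torch.cat((encoded_state, actions_one_hot), dim=1)
```

So action 1 fills its plane with 0.25 and action 3 fills its plane with 0.75, and the dynamics network's first convolution sees that as one extra input channel.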

Is this a standard thing to do / something that's been done elsewhere? It just feels a bit odd to me. Firstly, the actions are not actually "one-hot encoded" here, so the variable name is a little misleading (though that doesn't really matter, I guess). I suppose it makes sense in that you probably want to be able to apply convolutions to the joint state/action within the dynamics network. And with n_actions=4 this seems fine, but it feels like this approach would break down with a larger discrete action space, right?

Anyway if you have the time I'd be interested to hear your motivation/reasoning behind this, thanks!

YeWR commented 2 years ago

Thank you for your comments.

Here we broadcast the actions into planes so that their shape matches the shape of the feature planes (e.g. the feature is B x 64 x 6 x 6 and the action plane is B x 1 x 6 x 6). And we divide by action_space_size to scale the actions.

I agree that it is not a very natural way to encode the actions, because the distance between (a=1, a=2) differs from that between (a=1, a=3). A more natural way would be to broadcast the actions into a B x action_space x 6 x 6 tensor of one-hot planes. But that is a large tensor because of the spatial dimensions, especially when the action space is large. Moreover, even with an action space of 18 or 81, the current implementation still works well. So we just keep the current implementation.
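For comparison, a minimal sketch of the one-hot-plane alternative mentioned above (this is not the code in the repository, just an illustration of the trade-off): each action becomes action_space_size binary planes instead of one scaled plane, which removes the artificial ordering between action ids at the cost of more channels.

```python
import torch
import torch.nn.functional as F

batch, channels, h, w = 2, 64, 6, 6
action_space_size = 4

encoded_state = torch.randn(batch, channels, h, w)  # stand-in features
actions = torch.tensor([1, 3])                      # [2]

# One-hot over the action dimension, then broadcast to spatial planes:
# [2, 4] -> [2, 4, 6, 6]. Exactly one plane per sample is all ones.
one_hot = F.one_hot(actions, num_classes=action_space_size).float()
action_planes = one_hot[:, :, None, None].expand(-1, -1, h, w)

# Concatenating now adds action_space_size channels instead of one:
# [2, 64 + 4, 6, 6] here, but [2, 64 + 18, 6, 6] for full Atari.
state_action = torch.cat((encoded_state, action_planes), dim=1)
```

With 18 actions this adds 18 x 6 x 6 values per sample per unrolled step, versus 36 for the scaled single plane, which is presumably the memory cost being weighed here.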

By the way, it would be interesting to find a better method for representing actions under spatial features (rather than flattened features). Hope this helps :)

henrycharlesworth commented 2 years ago

Thanks for the reply.

I guess my thought was that it'd be better to broadcast onto a spatial dimension instead, so you could one-hot encode the actions and end up with a B x 64 x 6 x 10 tensor (or B x 64 x 6 x 24 if you have 18 actions). But I can see that for a large number of actions this would also become a large tensor.
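To spell out what I mean, here's a rough sketch (my own illustration, not code from the repo): append one-hot action columns along the last spatial dimension, repeated across all 64 feature channels and all rows, giving B x 64 x 6 x (6 + n_actions).

```python
import torch
import torch.nn.functional as F

batch, channels, h, w = 2, 64, 6, 6
n_actions = 4

encoded_state = torch.randn(batch, channels, h, w)  # stand-in features
actions = torch.tensor([1, 3])

one_hot = F.one_hot(actions, num_classes=n_actions).float()   # [2, 4]
# Broadcast to [2, 64, 6, 4]: the same one-hot columns are repeated
# for every channel and every row of the feature map.
action_cols = one_hot[:, None, None, :].expand(-1, channels, h, -1)

# Concatenate along width: [2, 64, 6, 6 + 4] = [2, 64, 6, 10]
state_action = torch.cat((encoded_state, action_cols), dim=3)
```

The convolutions would then see the action as extra columns rather than extra channels, which keeps the channel count fixed but grows the spatial extent with n_actions.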