nicklashansen / tdmpc2

Code for "TD-MPC2: Scalable, Robust World Models for Continuous Control"
https://www.tdmpc2.com
MIT License

Some confusion over the "task embedding" #19

Closed · edwhu closed this 4 months ago

edwhu commented 4 months ago

Hi, thanks for the great work. Coming here after reading the paper.

I'm a bit confused about how the task embedding is initialized during training. The main text doesn't explain how the task embedding $e$ is chosen / inferred during training, which implied to me that it is just a single learnable vector.

However, in the appendix, I see that the task embedding actually has shape $(T, 96)$. That seems to suggest that when we train on a multi-task trajectory dataset, we know the task ID of each trajectory and can select the corresponding 96-dimensional row as $e$. Is this correct?
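In other words, is the mechanism roughly the following? (A hypothetical sketch of my understanding; the names and the task count are mine, not from the paper.)

```python
import torch
import torch.nn as nn

T, task_dim = 10, 96          # T tasks (placeholder count), 96-dim embeddings
task_embeddings = nn.Parameter(torch.randn(T, task_dim))  # one learnable row per task
task_id = 3                   # known from the trajectory's label
e = task_embeddings[task_id]  # (96,) vector used as the task embedding
```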

But this also seems to clash with the first-layer dimensions of the architecture. For example, the encoder's in_features=512+T+A, but shouldn't it be $512+96+A$?

I might be missing something obvious, so I apologize in advance if that's the case.

[Screenshot: architecture summary from the appendix showing the first-layer dimensions]
nicklashansen commented 4 months ago

Hi @edwhu, thanks for your interest in our work! I realize that our notation in the paper is a bit inconsistent with respect to this. To be clear: we have a learnable embedding for each task, and we use the associated task ID to look up the relevant task embedding during training / inference.

The initialization happens here:
https://github.com/nicklashansen/tdmpc2/blob/57158282b46ebc5c329c5be9cfe2b0094126d1ca/tdmpc2/common/world_model.py#L20

And the retrieval happens here:
https://github.com/nicklashansen/tdmpc2/blob/57158282b46ebc5c329c5be9cfe2b0094126d1ca/tdmpc2/common/world_model.py#L78-L92

I hope this clears things up!
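In sketch form, the mechanism looks roughly like this (a minimal illustrative example, not the exact code from `world_model.py`; the task count, dimensions, and the `with_task` helper are placeholders):

```python
import torch
import torch.nn as nn

# Illustrative sketch: one learnable 96-dim embedding per task,
# retrieved by integer task ID. Not the verbatim repo code.
num_tasks, task_dim = 80, 96  # e.g. 80 tasks in the multi-task setting

task_emb = nn.Embedding(num_tasks, task_dim)  # the learnable table ("initialization")

def with_task(x: torch.Tensor, task_id: torch.Tensor) -> torch.Tensor:
    """Look up each sample's task embedding and concatenate it onto the features ("retrieval")."""
    e = task_emb(task_id)             # (batch, 96)
    return torch.cat([x, e], dim=-1)  # (batch, feature_dim + 96)

# During training the task ID comes from the trajectory's label;
# at inference it identifies which task the agent is evaluated on.
obs_feat = torch.randn(4, 512)
task_ids = torch.tensor([0, 3, 3, 7])
inp = with_task(obs_feat, task_ids)   # shape: (4, 608)
```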

edwhu commented 4 months ago

Thank you, that makes sense.