jannerm / trajectory-transformer

Code for the paper "Offline Reinforcement Learning as One Big Sequence Modeling Problem"
https://trajectory-transformer.github.io
MIT License
455 stars, 63 forks

purpose of pad_to_full_observation #3

Closed: Howuhh closed this issue 2 years ago

Howuhh commented 2 years ago

Hi! First of all, thank you for such interesting work!

I'm trying to figure out how trajectories are represented in this work. As far as I understand, the transformer blocks output a tensor of shape [batch, block_size, embedding_dim]. In a standard transformer we would pass this straight to a head, for example nn.Linear(embedding_dim, vocab_size), and get logits for prediction.
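For comparison, the conventional setup described here can be sketched in NumPy: one weight matrix, shared by every position, maps each embedding to vocabulary logits. NumPy stands in for the nn.Linear call, and the sizes are arbitrary illustrative values, not taken from the repository.

```python
import numpy as np

def shared_head(x, weight, bias):
    """Apply the same linear head to every position in the sequence.

    x:      [batch, block_size, embedding_dim]
    weight: [embedding_dim, vocab_size]
    bias:   [vocab_size]
    returns logits of shape [batch, block_size, vocab_size]
    """
    return x @ weight + bias

rng = np.random.default_rng(0)
batch, block_size, embedding_dim, vocab_size = 2, 10, 8, 100
x = rng.standard_normal((batch, block_size, embedding_dim))
weight = rng.standard_normal((embedding_dim, vocab_size))
bias = np.zeros(vocab_size)

logits = shared_head(x, weight, bias)
print(logits.shape)  # (2, 10, 100)
```

Every position uses the same `weight`, so nothing in this head distinguishes a token that encodes one part of a transition from a token that encodes another.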

Why wouldn't that work here? What is the intuition behind the padding and reshape (and the EinLinear layer) that you use? It doesn't seem to be mentioned in the paper.

Also, what is the stop token? There seem to be no special cases for termination in the beam plan. Is it only used for done?

Thanks!

jannerm commented 2 years ago

Each dimension of the transition has a different meaning (e.g., a state dimension versus an action dimension), so each dimension gets its own linear head. I batch over all of the different linear heads using einsum in the EinLinear module. Because these linear heads are batched together, each head must receive the same number of inputs: we can't have N full transitions plus one partial transition containing only half of a state. To ensure that all transitions have the same dimension, I pad the last one so that it is transition_dim tokens long.
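The explanation above can be sketched in NumPy as follows. This is a minimal sketch that assumes the transformer output has shape [batch, block_size, embedding_dim], that the first token of the sequence is the first dimension of a transition, and that block_size need not be a multiple of transition_dim. The function names `pad_to_full_transition` and `per_dimension_heads` are hypothetical, and the einsum subscripts are illustrative and may differ from the repository's EinLinear. The sketch zero-pads the sequence up to a whole number of transitions, reshapes it so that tokens are grouped by their position within a transition, applies a separate weight matrix to each position with a single einsum call, and then discards the logits for the padding tokens.

```python
import numpy as np

def pad_to_full_transition(x, transition_dim):
    """Zero-pad the token axis so its length is a multiple of transition_dim.

    x: [batch, block_size, embedding_dim]
    returns (padded x, number of padding tokens added)
    """
    batch, t, embd = x.shape
    n_pad = (-t) % transition_dim  # 0 when t is already a multiple
    pad = np.zeros((batch, n_pad, embd), dtype=x.dtype)
    return np.concatenate([x, pad], axis=1), n_pad

def per_dimension_heads(x, weight, bias, transition_dim):
    """Apply a different linear head to each dimension of the transition.

    x:      [batch, block_size, embedding_dim]
    weight: [transition_dim, vocab_size, embedding_dim]  (one matrix per dimension)
    bias:   [transition_dim, vocab_size]
    returns logits of shape [batch, block_size, vocab_size]
    """
    batch, t, embd = x.shape
    x_pad, _ = pad_to_full_transition(x, transition_dim)
    n_transitions = x_pad.shape[1] // transition_dim
    # Group tokens by transition: [batch * n_transitions, transition_dim, embd].
    # This reshape only works because padding made the length divisible.
    x_grouped = x_pad.reshape(batch * n_transitions, transition_dim, embd)
    # Head e is applied to position e of every transition:
    # out[b, e, i] = sum_j weight[e, i, j] * x_grouped[b, e, j]
    logits = np.einsum("eij,bej->bei", weight, x_grouped) + bias
    logits = logits.reshape(batch, n_transitions * transition_dim, -1)
    # Drop the logits produced for padding tokens.
    return logits[:, :t]

rng = np.random.default_rng(0)
batch, block_size, embd, vocab, transition_dim = 2, 10, 8, 100, 4
x = rng.standard_normal((batch, block_size, embd))
weight = rng.standard_normal((transition_dim, vocab, embd))
bias = np.zeros((transition_dim, vocab))

logits = per_dimension_heads(x, weight, bias, transition_dim)
print(logits.shape)  # (2, 10, 100)
```

Without the padding, 10 tokens could not be reshaped into groups of 4, which is the constraint described in the comment. With it, token 5 (position 1 within its transition) is scored by `weight[1]`, while token 8 (position 0) is scored by `weight[0]`.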

Howuhh commented 2 years ago

Thanks for the answer! Makes sense now. Interestingly enough, I removed that part from my reimplementation (since I didn't understand it), and it still works as expected (tested on imitation learning).