kzl / decision-transformer

Official codebase for Decision Transformer: Reinforcement Learning via Sequence Modeling.
MIT License

Global position embedding and timesteps look wrong in atari #51

Closed nzw0301 closed 1 year ago

nzw0301 commented 1 year ago

I'm not familiar with position encoding, but if my understanding is correct, the atari code uses global_pos_emb for only a single timestep per sample in each batch. Is this intended?

Essentially, the current code computes the global position encoding in the following way:

import torch

max_time_step = 10
emb_dim = 7
batch_size = 2
timesteps = torch.tensor([1, 3]).view(batch_size, 1, 1)  # the dataset implementation returns the relative index, within an episode, of the first state

global_pos_emb = torch.rand(1, max_time_step, emb_dim)
all_global_pos_emb = torch.repeat_interleave(
    global_pos_emb, batch_size, dim=0
)  # (batch_size, max_time_step, emb_dim)

torch.gather(
    all_global_pos_emb,
    1,
    torch.repeat_interleave(timesteps, emb_dim, dim=-1),
)
# result shape is (batch_size, 1, emb_dim), i.e. a single global embedding per sample
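
For contrast, and continuing the snippet above, a hypothetical per-token gather (not what the repository does; per_token_timesteps is an illustrative name) would pass one episode timestep per sequence position and return a distinct embedding for every element:

seq_len = 4
per_token_timesteps = torch.tensor([[1, 2, 3, 4], [3, 4, 5, 6]]).view(batch_size, seq_len, 1)
torch.gather(
    all_global_pos_emb,
    1,
    torch.repeat_interleave(per_token_timesteps, emb_dim, dim=-1),
)
# shape is (batch_size, seq_len, emb_dim)
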
lili-chen commented 1 year ago

I believe this is the intended behavior. The global position embedding is the same for every element of the sequence, while the local position embedding differs for each element of the sequence. In any case, we found that the specifics of the position embedding didn't make much of a difference in the results. Hope this helps!
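
For concreteness, here is a minimal sketch of that combination (not the repository code; names such as token_embeddings and local_pos_emb are illustrative): the gathered global embedding is a single vector per sample that broadcasts across the context window, while the local embedding varies by position.

import torch

batch_size, context_len, emb_dim, max_time_step = 2, 4, 7, 10

token_embeddings = torch.rand(batch_size, context_len, emb_dim)
timesteps = torch.tensor([1, 3]).view(batch_size, 1, 1)  # episode timestep of the first state

global_pos_emb = torch.rand(1, max_time_step, emb_dim)  # indexed by episode timestep
local_pos_emb = torch.rand(1, context_len, emb_dim)     # indexed by position in the context window

all_global_pos_emb = torch.repeat_interleave(global_pos_emb, batch_size, dim=0)
global_part = torch.gather(
    all_global_pos_emb, 1, torch.repeat_interleave(timesteps, emb_dim, dim=-1)
)  # (batch_size, 1, emb_dim): the same vector for every element of the sequence

position_embeddings = global_part + local_pos_emb  # broadcasts to (batch_size, context_len, emb_dim)
x = token_embeddings + position_embeddings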

nzw0301 commented 1 year ago

Thank you for your clarification!