kzl / decision-transformer

Official codebase for Decision Transformer: Reinforcement Learning via Sequence Modeling.

aligning action embeddings to other embeddings at line 237 #20

Closed loct824 closed 3 years ago

loct824 commented 3 years ago

At https://github.com/kzl/decision-transformer/blob/master/atari/mingpt/model_atari.py, line 237:

token_embeddings[:,2::3,:] = action_embeddings[:,-states.shape[1] + int(targets is None):,:]
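
For context, line 237 sits in a block that interleaves the three embedding streams into one token sequence, roughly (paraphrased from the same file):

token_embeddings[:, ::3, :] = rtg_embeddings     # return-to-go tokens at positions 0, 3, 6, ...
token_embeddings[:, 1::3, :] = state_embeddings  # state tokens at positions 1, 4, 7, ...
token_embeddings[:, 2::3, :] = action_embeddings[:, -states.shape[1] + int(targets is None):, :]  # action tokens at positions 2, 5, 8, ...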

I am not quite sure about the purpose of the targets is None check. It seems to me it handles two cases of input:

1. (r_0, s_0, a_0, r_1, s_1, a_1, ..., r_k, s_k, a_k), where we have the action for every state over timesteps 0 through k; in this case targets = actions.

2. (r_0, s_0, a_0, r_1, s_1, a_1, ..., r_k, s_k), where the last action a_k is still to be predicted from s_k; targets is None in this case (or could we still pass targets = (a_0, a_1, ..., a_{k-1})?).

However, it looks to me like the quoted line would lay out the token embeddings as follows when targets is absent:

(r_0, s_0, a_1, r_1, s_1, a_2, ..., r_k, s_k)

i.e. there is a misalignment between states and actions, since the action slice starts one position too far to the right. To me it should be written as

token_embeddings[:,2::3,:] = action_embeddings[:,-states.shape[1] : None if targets else -1,:]

Please see if I have misunderstood the code.

lili-chen commented 3 years ago

In case 2, we are in the evaluation loop, so actions doesn't contain a_k; the last element is a_{k-1}. The targets is None check ensures that we select (a_0, a_1, ..., a_{k-1}) and not (a_{-1}, a_0, ..., a_{k-1}). So I think line 237 is correct, but let me know if I missed something!
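
Here is a minimal, runnable sketch of that indexing with dummy tensors (shapes and names are illustrative, not the repo's exact code). With T = states.shape[1], training slices the last T action embeddings into T action slots, while evaluation slices the last T-1, dropping the stale leading action:

import torch

B, T, D = 1, 4, 8  # batch, timesteps (k+1), embedding dim; arbitrary values

for targets in (torch.zeros(B, T), None):  # case 1 (training) vs case 2 (evaluation)
    n_tokens = T * 3 - int(targets is None)   # evaluation drops the final action slot
    token_embeddings = torch.zeros(B, n_tokens, D)
    rtg_embeddings = torch.randn(B, T, D)     # (r_0, ..., r_k)
    state_embeddings = torch.randn(B, T, D)   # (s_0, ..., s_k)
    action_embeddings = torch.randn(B, T, D)  # train: (a_0, ..., a_k); eval: (a_{-1}, a_0, ..., a_{k-1})

    token_embeddings[:, ::3, :] = rtg_embeddings
    token_embeddings[:, 1::3, :] = state_embeddings
    # training: slice is [-T:]   -> all T actions fill the T action slots
    # eval:     slice is [-T+1:] -> the last T-1 actions (a_0, ..., a_{k-1}) fill
    #           the T-1 slots, dropping the stale a_{-1} at the front
    token_embeddings[:, 2::3, :] = action_embeddings[:, -T + int(targets is None):, :]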

loct824 commented 3 years ago

So you mean in case 2, the returns-to-go, states, and actions input to the model would be:

r = (r_0, r_1, ..., r_k)
s = (s_0, s_1, ..., s_k)
a = (a_{-1}, a_0, ..., a_{k-1})

Is that the case?

When should I give / not give the targets as input? Could you give some more guidance on this?

lili-chen commented 3 years ago

Yes, those are what r, s, and a would be in case 2.

We give the targets as input during training and do not give them during evaluation. This is already handled in the code; see https://github.com/kzl/decision-transformer/blob/master/atari/mingpt/trainer_atari.py#L103 for training and https://github.com/kzl/decision-transformer/blob/master/atari/mingpt/utils.py#L45 for evaluation.
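
To make the pattern concrete, here is a hedged, self-contained sketch of the two call patterns. The forward below is a hypothetical stand-in showing how targets gates the loss, not the repo's actual forward (which lives in mingpt/model_atari.py):

import torch
import torch.nn.functional as F

def forward(states, actions, targets=None):
    n_actions = 4
    n_tokens = states.shape[1] * 3 - int(targets is None)
    logits = torch.randn(states.shape[0], n_tokens, n_actions)  # dummy predictions
    loss = None
    if targets is not None:
        # training: predictions made at the state tokens are supervised
        # against the actions actually taken
        action_logits = logits[:, 1::3, :]
        loss = F.cross_entropy(action_logits.reshape(-1, n_actions),
                               targets.reshape(-1))
    return logits, loss

states = torch.randn(2, 5, 8)       # dummy (batch, T, state_dim)
actions = torch.randint(4, (2, 5))  # dummy discrete actions

# training (cf. trainer_atari.py#L103): targets are the actions themselves
_, train_loss = forward(states, actions, targets=actions)

# evaluation (cf. utils.py#L45): no targets, so loss stays None and we just
# read the next-action logits off the output
eval_logits, _ = forward(states, actions, targets=None)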