The wording is not super clear but it should be correct:
In line 74 (which is a bit cryptic), we form the sequence (R_1, s_1, a_1, R_2, s_2, a_2, ...).
When we unfold the tokens in line 92, x[:,0] corresponds to (R_1, R_2, ...) and x[:,1] corresponds to (s_1, s_2, ...).
As a result, the indexing makes sense: when you predict the action a_1, you want to condition on R_1 and s_1, so you use x[:,1].
Not in the paper, but in a dynamics model you'd want to predict s_2 from R_1, s_1, a_1, and as a result you'd want to use x[:,2] to add the action to your conditioning set.
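For illustration, here is a minimal, self-contained sketch of that interleave-then-unfold step (dummy tensors and arbitrary shapes, not the exact repo code), just to show which channel of x each modality ends up in:

```python
import torch

batch_size, seq_length, hidden_size = 2, 4, 8

# Stand-in per-timestep embeddings; in the repo these come from learned
# embedding layers applied to returns-to-go, states, and actions.
return_embeddings = torch.randn(batch_size, seq_length, hidden_size)
state_embeddings  = torch.randn(batch_size, seq_length, hidden_size)
action_embeddings = torch.randn(batch_size, seq_length, hidden_size)

# Interleave into (R_1, s_1, a_1, R_2, s_2, a_2, ...), as around line 74.
stacked = torch.stack(
    (return_embeddings, state_embeddings, action_embeddings), dim=1
)                                                  # (batch, 3, seq, hidden)
interleaved = stacked.permute(0, 2, 1, 3).reshape(
    batch_size, 3 * seq_length, hidden_size
)                                                  # token order: R_1, s_1, a_1, R_2, ...

# ... in the model, the interleaved sequence is fed through the causal transformer here ...

# Unfold back, as around line 92: channel 0 holds the R tokens,
# channel 1 the s tokens, channel 2 the a tokens.
x = interleaved.reshape(batch_size, seq_length, 3, hidden_size).permute(0, 2, 1, 3)

assert torch.equal(x[:, 0], return_embeddings)
assert torch.equal(x[:, 1], state_embeddings)
assert torch.equal(x[:, 2], action_embeddings)
```

(The asserts only hold here because the transformer is skipped; they are just meant to show the index-to-modality correspondence.)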
Thanks for your reply; maybe I am missing something. You say "x[:,1] corresponds to (s_1, s_2, ...)", which is exactly my point, but in the code you do state_preds = self.predict_state(x[:,2]). Shouldn't it be state_preds = self.predict_state(x[:,1])?
x[:,1,t] corresponds to the same token ("vertical column") as s_t, and x[:,2,t] corresponds to a_t.
When we do state prediction, we typically think of a dynamics model as s_{t+1} = f(s_t, a_t). In our sequence modeling framework this instead becomes s_{t+1} = f(a_t, s_t, R_t, a_{t-1}, s_{t-1}, R_{t-1}, ...).
The embedding x[:,2,t] has access (via the causal self-attention mask) to all of those tokens. I.e., we should predict the state after seeing the action, rather than directly predicting the next state after seeing only the previous state (which would be something like s_{t+1} = f(s_t)).
Actually, the return prediction is a bit weird right now in the code. Using a similar argument, you'd want to use x[:,1] if you want the return prediction to be like a value function (conditioned on s), or x[:,2] to be like a Q-function (conditioned on s and a).
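To summarize the discussion, here is a self-contained sketch of the three prediction heads with stand-in linear layers (the shapes, names, and head definitions are illustrative assumptions, not the repo's exact modules), annotated with the conditioning set each index gives you:

```python
import torch
import torch.nn as nn

batch_size, seq_length, hidden_size = 2, 4, 8
state_dim, act_dim = 11, 3

# Stand-in prediction heads; in the repo these are attributes of the model.
predict_return = nn.Linear(hidden_size, 1)
predict_state  = nn.Linear(hidden_size, state_dim)
predict_action = nn.Sequential(nn.Linear(hidden_size, act_dim), nn.Tanh())

# x has shape (batch, 3, seq, hidden):
# x[:,0] are the R tokens, x[:,1] the s tokens, x[:,2] the a tokens.
x = torch.randn(batch_size, 3, seq_length, hidden_size)

# Q-function-like: conditions on R_t, s_t, a_t (and the past).
# Using x[:, 1] instead would make it value-function-like (no a_t).
return_preds = predict_return(x[:, 2])

# Dynamics-style prediction of s_{t+1}: condition on everything up through a_t,
# which is why the action token x[:, 2] is used rather than x[:, 1].
state_preds = predict_state(x[:, 2])

# Action prediction a_t: condition on R_t, s_t (and the past), i.e. the state token x[:, 1].
action_preds = predict_action(x[:, 1])
```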
thanks much, got it, makes sense!!
Hi, interesting paper and thanks for sharing the code. Quick question: here https://github.com/kzl/decision-transformer/blob/d28039e97a30edaa6839333a8e12661a89ce0861/gym/decision_transformer/models/decision_transformer.py#L96 and here https://github.com/kzl/decision-transformer/blob/d28039e97a30edaa6839333a8e12661a89ce0861/gym/decision_transformer/models/decision_transformer.py#L97, shouldn't they be
state_preds = self.predict_state(x[:,1])
and action_preds = self.predict_action(x[:,2])
instead?