kzl / decision-transformer

Official codebase for Decision Transformer: Reinforcement Learning via Sequence Modeling.
MIT License

state and action prediction #5

Closed mehdimashayekhi closed 3 years ago

mehdimashayekhi commented 3 years ago

Hi, interesting paper and thanks for sharing the code. Quick question: here https://github.com/kzl/decision-transformer/blob/d28039e97a30edaa6839333a8e12661a89ce0861/gym/decision_transformer/models/decision_transformer.py#L96 and here https://github.com/kzl/decision-transformer/blob/d28039e97a30edaa6839333a8e12661a89ce0861/gym/decision_transformer/models/decision_transformer.py#L97, shouldn't they be state_preds = self.predict_state(x[:,1]) and action_preds = self.predict_action(x[:,2]) instead?

kzl commented 3 years ago

The wording is not super clear, but it should be correct:

In line 74 (which is a bit cryptic), we form the sequence (R_1, s_1, a_1, R_2, s_2, a_2, ...).

When we unfold the tokens in line 92, x[:,0] corresponds to (R_1, R_2, ...) and x[:,1] corresponds to (s_1, s_2, ...).

As a result, it makes sense because when you predict the action a_1, you want to condition on R_1, s_1 (thus being x[:,1]).

Not in the paper, but in a dynamics model you'd want to predict s_2 from R_1, s_1, a_1, and as a result you'd want to use x[:,2] to add the action to your conditioning set.
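To make the index bookkeeping above concrete, here is a minimal numpy sketch (not the repo's actual PyTorch code; shapes B, T, D and the embedding arrays are hypothetical stand-ins) of interleaving per-modality embeddings into (R_1, s_1, a_1, R_2, s_2, a_2, ...) and then unfolding them back, showing that x[:,0], x[:,1], x[:,2] line up with returns, states, and actions:

```python
import numpy as np

# Hypothetical sizes: batch B, context length T, embedding dim D.
B, T, D = 2, 4, 8
rng = np.random.default_rng(0)

# Stand-ins for the per-modality token embeddings.
returns_emb = rng.normal(size=(B, T, D))
state_emb = rng.normal(size=(B, T, D))
action_emb = rng.normal(size=(B, T, D))

# Interleave into (R_1, s_1, a_1, R_2, s_2, a_2, ...):
# stack -> (B, 3, T, D), transpose -> (B, T, 3, D), reshape -> (B, 3T, D).
stacked = np.stack([returns_emb, state_emb, action_emb], axis=1)
seq = stacked.transpose(0, 2, 1, 3).reshape(B, 3 * T, D)

# After the transformer, the inverse reshape recovers per-modality streams.
x = seq.reshape(B, T, 3, D).transpose(0, 2, 1, 3)  # (B, 3, T, D)

# x[:, 0] lines up with the return tokens, x[:, 1] with states, x[:, 2] with actions.
assert np.allclose(x[:, 0], returns_emb)
assert np.allclose(x[:, 1], state_emb)
assert np.allclose(x[:, 2], action_emb)
```

The assertions pass because the reshape/transpose at the end exactly inverts the interleaving at the start; under a causal mask, the token at position 3t+2 (the action slot, x[:,2,t]) is the last of the triple and sees R_t, s_t, a_t and everything earlier.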

mehdimashayekhi commented 3 years ago

Thanks for your reply. Maybe I am missing something: you say "x[:,1] corresponds to (s_1, s_2, ...)", which is exactly my point, yet in the code you do state_preds = self.predict_state(x[:,2]). Shouldn't it be state_preds = self.predict_state(x[:,1])?

kzl commented 3 years ago

x[:,1,t] corresponds to the same token ("vertical column") as s_t, and x[:,2,t] corresponds to a_t.

When we do state prediction, we typically think of a dynamics model as s_{t+1} = f(s_t, a_t). In our sequence modeling framework this instead becomes s_{t+1} = f(a_t, s_t, R_t, a_{t-1}, s_{t-1}, R_{t-1}, ...).

The embedding x[:,2,t] has access (via the causal self-attention mask) to all of those tokens. I.e., we should predict the next state after seeing the action, rather than directly predicting it from the previous state alone (which would be something like s_{t+1} = f(s_t)).

Actually, the return prediction is a bit weird right now in the code. Using a similar argument, you'd want to use x[:,1] if you want the return prediction to be like a value function (conditioned on s), or x[:,2] to be like a Q-function (conditioned on s and a).
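The head-wiring argument in this comment can be sketched in a few lines of numpy (again a hypothetical stand-in, not the repo's code: x is the unfolded (B, 3, T, D) tensor, and the W_* matrices are stand-ins for the predict_state / predict_action / predict_return linear heads):

```python
import numpy as np

# Hypothetical sizes: batch B, context T, embed dim D, env dims.
B, T, D, state_dim, act_dim = 2, 4, 8, 3, 2
rng = np.random.default_rng(1)

# x: unfolded transformer output; channel 0 = returns, 1 = states, 2 = actions.
x = rng.normal(size=(B, 3, T, D))

# Stand-ins for the linear prediction heads.
W_action = rng.normal(size=(D, act_dim))
W_state = rng.normal(size=(D, state_dim))
W_return = rng.normal(size=(D, 1))

# Predict a_t from the stream ending at s_t -> condition on x[:, 1].
action_preds = x[:, 1] @ W_action   # (B, T, act_dim)

# Predict s_{t+1} from the stream ending at a_t -> condition on x[:, 2].
state_preds = x[:, 2] @ W_state     # (B, T, state_dim)

# Return prediction: x[:, 1] behaves like a value function (conditions on s_t);
# swapping in x[:, 2] would behave like a Q-function (conditions on s_t and a_t).
return_preds = x[:, 1] @ W_return   # (B, T, 1)
```

The choice of channel is just a choice of how much of the trajectory prefix each head gets to see before making its prediction.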

mehdimashayekhi commented 3 years ago

thanks much, got it, makes sense!!