Closed Pulsar110 closed 4 months ago
I ran into the same question. I suspect the reason nothing breaks is that the Decision Transformer never actually uses return_preds (the predicted return at the next timestep) or state_preds. If this were the Trajectory Transformer, which does consume those predictions, a bug would show up.
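To make the point above concrete, here is a minimal sketch (not the repo's code; shapes and names are illustrative) of why a wrong return/state head would go unnoticed: the training objective regresses only on the action predictions, so return_preds and state_preds are computed but never enter the loss.

```python
import numpy as np

# Illustrative sizes, not taken from the repo.
B, T, act_dim = 2, 5, 3
rng = np.random.default_rng(0)

# Pretend these came out of the model's three prediction heads.
action_preds = rng.standard_normal((B, T, act_dim))
return_preds = rng.standard_normal((B, T, 1))   # computed but unused
state_preds = rng.standard_normal((B, T, 4))    # computed but unused
action_target = rng.standard_normal((B, T, act_dim))

# The DT objective is an MSE on actions only, so any bug in the
# return/state heads would be silent during training.
loss = np.mean((action_preds - action_target) ** 2)
```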
I think the comments in min-decision-transformer are easier to understand: https://github.com/nikhilbarhate99/min-decision-transformer/blob/d6694248b48c57c84fc7487e6e8017dcca861b02/decision_transformer/model.py#L152. The new action is conditioned on everything up to and including r_t, s_t, while the new state and new return are conditioned on everything up to and including r_t, s_t, a_t. So the original implementation is right, I think? @nawta @Pulsar110 As for x[:,0], I don't think we can get anything useful out of that.
Hi -
You can view next-token prediction like this:

| ind | input | transformer | output |
|-----|-------|-------------|--------|
| 0 | R_t | --> | x[:,0] |
| 1 | s_t | --> | x[:,1] |
| 2 | a_t | --> | x[:,2] |

Hence x[:,0] sees R_t; x[:,1] sees R_t and s_t; and x[:,2] sees all three.
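The visibility pattern above falls directly out of interleaving the three token types and applying a causal mask. A small sketch (illustrative only, not the repo's code): each timestep contributes three tokens in the order (R_t, s_t, a_t), and the output at position i can attend to positions 0..i.

```python
# Interleave three tokens per timestep: (R_t, s_t, a_t).
T = 4  # illustrative number of timesteps
tokens = []
for t in range(T):
    tokens += [f"R_{t}", f"s_{t}", f"a_{t}"]

def visible(i):
    """Tokens the causal-masked output at flat position i can attend to."""
    return tokens[: i + 1]

# After reshaping the flat outputs back to (T, 3, ...), x[:, 0] sits
# above R_t, x[:, 1] above s_t, and x[:, 2] above a_t:
print(visible(0))  # ['R_0']
print(visible(1))  # ['R_0', 's_0']
print(visible(2))  # ['R_0', 's_0', 'a_0']
```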
Therefore, with the above formulation of the predictions, we have:

- return_preds = Q(R_t, s_t, a_t) (for time t+1), matching traditional Q functions
- state_preds = f(R_t, s_t, a_t) (for time t+1), matching traditional dynamics models s_{t+1} = f(s_t, a_t)
- action_preds = pi(R_t, s_t) (for time t), which is the standard return-conditioned policy formulation

If we wanted to make a prediction using x[:,0], it would be akin to predicting the state given only the return. In most RL settings, you observe the state and reward simultaneously right after you take the action, and hence there is no need to predict the state, which is why there is no corresponding predictor for index 0. In fact, some follow-ups to the original work combine both the return and state into one token in order to remove this redundancy.
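The head wiring described above can be sketched as follows. This is a NumPy stand-in for the PyTorch `nn.Linear` heads in the linked gym implementation; all sizes are illustrative, and the second axis of `x` indexes the token type (0 = return, 1 = state, 2 = action).

```python
import numpy as np

rng = np.random.default_rng(0)
hidden_size, state_dim, act_dim = 8, 3, 2  # illustrative sizes
B, T = 1, 5                                # batch, timesteps

# Transformer outputs reshaped to (B, 3, T, hidden_size).
x = rng.standard_normal((B, 3, T, hidden_size))

def linear(out_dim):
    """Stand-in for an nn.Linear prediction head (weights only, no bias)."""
    W = rng.standard_normal((hidden_size, out_dim))
    return lambda h: h @ W

predict_return = linear(1)
predict_state = linear(state_dim)
predict_action = linear(act_dim)

return_preds = predict_return(x[:, 2])  # conditioned on (R_t, s_t, a_t)
state_preds = predict_state(x[:, 2])    # conditioned on (R_t, s_t, a_t)
action_preds = predict_action(x[:, 1])  # conditioned on (R_t, s_t)

print(return_preds.shape)  # (1, 5, 1)
print(state_preds.shape)   # (1, 5, 3)
print(action_preds.shape)  # (1, 5, 2)
```

Reading `x[:, 2]` for the return and state heads is what makes those predictions functions of the full (R_t, s_t, a_t) context, while `x[:, 1]` restricts the action head to (R_t, s_t).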
Sorry I have been away for a very long time.
Best, Kevin
From the code in here: https://github.com/kzl/decision-transformer/blob/master/gym/decision_transformer/models/decision_transformer.py#L92-L99
I'm not sure I understand why self.predict_return(x[:, 2]) or self.predict_state(x[:, 2]) is predicting the return/next state given the state and action. From the comment at the top, x[:, 2] corresponds only to the action tokens? Am I missing something?
And if this code is correct, what is the use of x[:, 0]?
I have also asked this question in the huggingface/transformers repo: https://github.com/huggingface/transformers/issues/27916