kzl / decision-transformer

Official codebase for Decision Transformer: Reinforcement Learning via Sequence Modeling.

Question about the output of the decision transformer #67

Closed · Pulsar110 closed this issue 4 months ago

Pulsar110 commented 8 months ago

From the code in here: https://github.com/kzl/decision-transformer/blob/master/gym/decision_transformer/models/decision_transformer.py#L92-L99

        # reshape x so that the second dimension corresponds to the original
        # returns (0), states (1), or actions (2); i.e. x[:,1,t] is the token for s_t
        x = x.reshape(batch_size, seq_length, 3, self.hidden_size).permute(0, 2, 1, 3)

        # get predictions
        return_preds = self.predict_return(x[:,2])  # predict next return given state and action
        state_preds = self.predict_state(x[:,2])    # predict next state given state and action
        action_preds = self.predict_action(x[:,1])  # predict next action given state

I'm not sure I understand why self.predict_return(x[:, 2]) or self.predict_state(x[:, 2]) predicts the return/next state given the state and action. According to the comment at the top, x[:, 2] corresponds only to the action tokens. Am I missing something?

And if this code is correct, what is the use of x[:, 0]?
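For reference, here is a minimal standalone check I put together (dummy tensors only; the stacking mirrors how I read the earlier part of forward(), so treat it as a sketch rather than the repo's exact code). It does seem to confirm that x[:,2] holds the outputs at the action positions:

    import torch

    batch_size, seq_length, hidden_size = 2, 4, 8

    # stand-ins for the per-timestep return/state/action embeddings
    returns_emb = torch.randn(batch_size, seq_length, hidden_size)
    state_emb   = torch.randn(batch_size, seq_length, hidden_size)
    action_emb  = torch.randn(batch_size, seq_length, hidden_size)

    # interleave into (R_1, s_1, a_1, R_2, s_2, a_2, ...) -> (batch, 3*seq_length, hidden)
    stacked = torch.stack((returns_emb, state_emb, action_emb), dim=1) \
                   .permute(0, 2, 1, 3) \
                   .reshape(batch_size, 3 * seq_length, hidden_size)

    # pretend the transformer is the identity, then apply the quoted reshape/permute
    x = stacked.reshape(batch_size, seq_length, 3, hidden_size).permute(0, 2, 1, 3)

    assert torch.equal(x[:, 0], returns_emb)  # return-token positions
    assert torch.equal(x[:, 1], state_emb)    # state-token positions
    assert torch.equal(x[:, 2], action_emb)   # action-token positions

So x[:,2] really does sit at the a_t positions, which is why the "given state and action" comments confuse me.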

I have also asked this question in the huggingface/transformers repo: https://github.com/huggingface/transformers/issues/27916

nawta commented 8 months ago

I had the same question. I think the reason this isn't a problem is that Decision Transformer never actually uses return_preds (the return at the next timestep) or state_preds. If this were Trajectory Transformer, a bug like that would actually show up.

yangyichu commented 8 months ago

I think the comments in min-decision-transformer are easier to understand: https://github.com/nikhilbarhate99/min-decision-transformer/blob/d6694248b48c57c84fc7487e6e8017dcca861b02/decision_transformer/model.py#L152

The new action is conditioned on everything up through r_t, s_t, while the new state and new return are conditioned on everything up through r_t, s_t, a_t. So the original implementation is right, I think? @nawta @Pulsar110

As for x[:,0], I don't think we can get anything useful out of it.
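To spell that out (my own naming below, not code from either repo), the conditioning per head looks roughly like this:

    # which tokens each prediction head has effectively attended to at timestep t,
    # given the interleaved order (..., r_t, s_t, a_t, r_{t+1}, ...) and a causal mask
    conditioning = {
        "action_preds (read from the s_t slot, x[:,1])": "..., r_t, s_t",
        "state_preds  (read from the a_t slot, x[:,2])": "..., r_t, s_t, a_t",
        "return_preds (read from the a_t slot, x[:,2])": "..., r_t, s_t, a_t",
    }
    for head, context in conditioning.items():
        print(f"{head} is conditioned on {context}")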

kzl commented 4 months ago

Hi -

You can view next-token prediction like this:

ind | input | transformer | output
  0 |   R_t |     -->     | x[:,0]
  1 |   s_t |    \-->     | x[:,1]
  2 |   a_t |    \-->     | x[:,2]

Hence x[:,0] sees R_t; x[:,1] sees R_t and s_t; and x[:,2] sees all three.

Therefore, with the above formulation of the predictions, we have:

  1. return_preds = Q(R_t, s_t, a_t) (for time t+1), matching traditional Q functions
  2. state_preds = f(R_t, s_t, a_t) (for time t+1), matching traditional dynamics models s_{t+1} = f(s_t, a_t)
  3. action_preds = pi(R_t, s_t) (for time t), which is the standard return-conditioned policy formulation
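If it helps, here is a tiny runnable check of the table above (not from the repo; just a toy lower-triangular mask over the interleaved sequence) that spells out what each output slot can see:

    import torch

    T = 3                                   # timesteps
    L = 3 * T                               # interleaved length: (R_0, s_0, a_0, R_1, ...)
    labels = [f"{m}_{t}" for t in range(T) for m in ("R", "s", "a")]

    # GPT-style causal mask: position i attends to positions j <= i
    mask = torch.tril(torch.ones(L, L)).bool()

    def visible(t, ind):                    # ind: 0 = return, 1 = state, 2 = action
        i = 3 * t + ind                     # flat position of that token in the sequence
        return [labels[j] for j in range(L) if mask[i, j]]

    print(visible(1, 0))  # x[:,0] slot at t=1: [..., 'R_1']                -> unused
    print(visible(1, 1))  # x[:,1] slot at t=1: [..., 'R_1', 's_1']         -> predict_action
    print(visible(1, 2))  # x[:,2] slot at t=1: [..., 'R_1', 's_1', 'a_1']  -> predict_return / predict_state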

If we wanted to make a prediction using x[:,0], it would be akin to predicting state given return. In most RL settings, you observe the state and reward simultaneously right after you take the action, and hence there is no need to predict the state, which is why there is no corresponding predictor for index 0. In fact, some followups to the original work combine both the return and state into one token in order to remove this redundancy.

Sorry I have been away for a very long time.

Best, Kevin