thu-ml / tianshou

An elegant PyTorch deep reinforcement learning library.
https://tianshou.org
MIT License

Implement Decision Transformer for offline RL #626

Open nuance1979 opened 2 years ago

nuance1979 commented 2 years ago

I wonder if anyone is interested in implementing Decision Transformer in Tianshou. I did some research and here's my proposal:

What do you think? @Trinkle23897 @ChenDRAG

Trinkle23897 commented 2 years ago

Decision Transformer is a recurrent/autoregressive-style

Should we first fix tons of issues for RNN?

nuance1979 commented 2 years ago

Should we first fix tons of issues for RNN?

Yes, we should. Who is the right person to do it?

Trinkle23897 commented 2 years ago

I can nudge gogoduan.

Trinkle23897 commented 2 years ago

gogoduan says he will be available after May 12 and throughout the summer break.

guillefix commented 2 years ago

I would be interested in this. I haven't given it much thought, but I think it would be best to start by implementing upside-down RL, or more generally HIM (https://arxiv.org/abs/2111.10364), in a policy-agnostic way (the general idea doesn't even need to be recurrent). The idea of using transformers or RNNs as the architecture is orthogonal to the RL method, I think.

Trinkle23897 commented 2 years ago

I would be interested in this.

Thanks in advance!

nuance1979 commented 2 years ago

I tried to add a recurrent variant of PPO in atari_ppo.py here and here. (Ref: CleanRL's version.) However, of all the Atari games I tried, only Enduro reached a reasonable best reward of 1311. Breakout performed worst, with a best reward of 52. The other games reached best rewards of roughly 1000, which is significantly worse than the non-recurrent variant of PPO, not to mention other, even better-performing policies.

I believe something is indeed wrong, most probably in the interaction between the collector and the recurrent net. @gogoduan Please feel free to use my code as a starting point. I hope you can help us debug this.
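
For reference, the hidden-state handling in question looks roughly like this (a minimal PyTorch sketch, not Tianshou's actual code; the class, shapes, and reset-on-done masking are illustrative assumptions modeled on CleanRL's recurrent PPO):

```python
import torch
import torch.nn as nn

class RecurrentActor(nn.Module):
    """Minimal GRU actor: the caller (i.e. the collector) must carry `hidden`
    across steps and zero it for every env whose episode just ended."""

    def __init__(self, obs_dim: int, act_dim: int, hidden_size: int = 128):
        super().__init__()
        self.encoder = nn.Linear(obs_dim, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True)
        self.head = nn.Linear(hidden_size, act_dim)

    def forward(self, obs, hidden, done):
        # obs: (batch, obs_dim); hidden: (1, batch, hidden_size); done: (batch,)
        # Reset the hidden state of envs whose previous step finished an episode.
        hidden = hidden * (1.0 - done).view(1, -1, 1)
        x = torch.relu(self.encoder(obs)).unsqueeze(1)  # (batch, 1, hidden_size)
        x, hidden = self.gru(x, hidden)
        return self.head(x.squeeze(1)), hidden

# One step for 8 parallel envs with 4-dim observations.
actor = RecurrentActor(obs_dim=4, act_dim=6)
hidden = torch.zeros(1, 8, 128)   # fresh state at the start of collection
obs = torch.zeros(8, 4)
done = torch.zeros(8)             # 1.0 where the previous step ended an episode
logits, hidden = actor(obs, hidden, done)
```

The suspicion is that somewhere along the Collector → ReplayBuffer → learn path this done-based reset, or the per-env ordering of the hidden states, is not preserved.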

Trinkle23897 commented 2 years ago

I believe something is indeed wrong

#486

nuance1979 commented 2 years ago

Any update on this? @gogoduan

AlessandroZavoli commented 2 years ago

I'm interested in Decision Transformers too.

atsushi3110 commented 1 year ago

Hi @nuance1979! Thank you for your efforts and contributions.

While I am not an expert in RL, I hope this helps :)

It's possible that next_done elements are missing from the collected data, so the (h, c) states of the LSTMs are not reset to zeros at the start of the next episode during rollouts (is this handled by the Collector in this repo?).

In https://github.com/thu-ml/tianshou/issues/486, @Trinkle23897 wrote:

in buffer.sample(). And for the corner case: if all episodes have length=3 but we would like to sample trajectories of length=4:

time: 0 1 2 0 1 2 0 1 2 ...
done: F F T F F T F F T ...
samples from the above (all possible cases with the current implementation):
0 0 1 2
0 0 0 1
0 0 0 0

and this could be changed as follows:

time_____:0 1 2 0 1 2 0 1 2 
done_____:F F T F F T F F T 
next_done:? F F T F F T F F 

The ? in the next_done row means the initial np.zeros at https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/ppo_atari_lstm.py#L223

A sample from the above with num_stack=7 would be:

 0 1 2 0 1 2 0 
 1 2 0 1 2 0 1 
 2 0 1 2 0 1 2 
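
In code, the next_done row above can be derived from the stored done flags roughly like this (an illustrative numpy snippet, not existing Tianshou code):

```python
import numpy as np

done = np.array([0, 0, 1, 0, 0, 1, 0, 0, 1], dtype=np.float32)  # F F T F F T F F T

# next_done[t] says whether the observation at step t starts a new episode,
# i.e. it is the previous step's done flag, with zeros at the very beginning
# (the "?" entry in the table above).
next_done = np.zeros_like(done)
next_done[1:] = done[:-1]
print(next_done)  # [0. 0. 0. 1. 0. 0. 1. 0. 0.]
```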

As @vwxyzjn and @araffin pointed out in https://github.com/thu-ml/tianshou/issues/486, the states may be reset multiple times during training, depending on the number of dones in the collected data; in CleanRL's source these flags are called next_done.

Thus, the calls to get_value and get_action_and_value perform the reset by using the next_done boolean flag as (1.0 - d), here: https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/ppo_atari_lstm.py#L152-L153

Then, the forward passes of the policy net (Actor) and the value net (Critic) set the (h, c) states of the LSTMs to zeros according to the next_done flags, as leaf nodes of the computational graph, in L244-L248 (Actor and Critic) and L263-L279 (Critic).

https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/ppo_atari_lstm.py#L244-L248

https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/ppo_atari_lstm.py#L263-L279
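
The pattern in those lines is roughly the following (a paraphrased, self-contained sketch in the spirit of ppo_atari_lstm.py, not a verbatim copy; shapes follow CleanRL's (seq_len, batch, ...) layout):

```python
import torch
import torch.nn as nn

def masked_lstm_forward(lstm: nn.LSTM, features, lstm_state, done):
    """Run the LSTM step by step, zeroing (h, c) for every sample whose
    done flag says a new episode starts at that step.

    features:   (seq_len, batch, input_size)
    lstm_state: tuple of two (num_layers, batch, hidden_size) tensors
    done:       (seq_len, batch) float tensor of next_done flags
    """
    outputs = []
    for x_t, d_t in zip(features, done):
        mask = (1.0 - d_t).view(1, -1, 1)
        h, lstm_state = lstm(
            x_t.unsqueeze(0),
            (mask * lstm_state[0], mask * lstm_state[1]),
        )
        outputs.append(h)
    return torch.cat(outputs), lstm_state

# 5 time steps, 8 envs, 64-dim features, 128-dim hidden state.
lstm = nn.LSTM(64, 128)
feats = torch.zeros(5, 8, 64)
state = (torch.zeros(1, 8, 128), torch.zeros(1, 8, 128))
dones = torch.zeros(5, 8)
out, state = masked_lstm_forward(lstm, feats, state, dones)
```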

Batch and/or ReplayBuffer instances would have to save next_dones and next_lstm_state attributes in Collector.collect, so that the next_dones boolean flags can reset next_lstm_state to np.zeros, like ppo_atari_lstm.py#L152-L153.
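
Concretely, the extra per-step data Collector.collect would need to record might look like this (a hypothetical sketch; next_done and lstm_state are proposed keys, not existing Tianshou fields):

```python
import numpy as np
from tianshou.data import Batch

# Dummy per-step values standing in for what the Collector sees
# (1 environment, 4-dimensional observations, 128-dim LSTM state).
obs = np.zeros((1, 4), dtype=np.float32)
obs_next = np.zeros((1, 4), dtype=np.float32)
act = np.zeros((1,), dtype=np.int64)
rew = np.zeros((1,), dtype=np.float32)
done = np.zeros((1,), dtype=bool)
done_prev = np.zeros((1,), dtype=bool)           # done flag of the previous step
h = c = np.zeros((1, 1, 128), dtype=np.float32)  # LSTM state that produced `act`

# The proposed extra keys travel with the transition into the ReplayBuffer,
# so that learn() knows where to zero the recurrent state.
step_data = Batch(
    obs=obs, act=act, rew=rew, done=done, obs_next=obs_next,
    next_done=done_prev,
    lstm_state=Batch(h=h, c=c),
)
```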

nuance1979 commented 1 year ago

@atsushi3110 you are right. The core problem is that Tianshou's ReplayBuffer flattens all replay data into a one-dimensional sequence (while preserving episodic order) and then samples mini-batches from it. This is good enough for non-recurrent models, but a recurrent model needs mini-batches whose rows come from distinct episodes, such that the next mini-batch contains the episodic "next" data point for each row (or the start of a new episode).

We need to implement a new method, Batch.split_for_recursive_model(), which returns this kind of mini-batch, along with done flags so the LSTM states can be properly reset at episode boundaries, and use it here.
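
A rough sketch of what such a method could do at the index level (hypothetical and untested; only the sampling/reset logic is shown, not the actual Batch integration):

```python
import numpy as np

def split_for_recursive_model(done: np.ndarray, batch_size: int):
    """Yield (indices, reset) pairs: indices[i] is the buffer index row i
    processes at this step, and consecutive yields advance every row one
    step through its own episode.  reset[i] is True when row i has just
    switched to a new episode, i.e. its recurrent state must be zeroed
    before this mini-batch.  Leftover data is simply dropped when fewer
    than batch_size episodes remain."""
    flat = np.arange(len(done))
    episodes = [ep.tolist() for ep in np.split(flat, np.flatnonzero(done) + 1) if len(ep)]
    if len(episodes) < batch_size:
        return
    slots = [episodes.pop(0) for _ in range(batch_size)]  # one episode per row
    reset = np.ones(batch_size, dtype=bool)               # every row starts fresh
    while True:
        yield np.array([ep[0] for ep in slots]), reset.copy()
        reset[:] = False
        for i in range(batch_size):
            slots[i] = slots[i][1:]
            if not slots[i]:              # episode finished: load the next one
                if not episodes:
                    return
                slots[i] = episodes.pop(0)
                reset[i] = True
```

In learn(), row i would then carry its LSTM state over from the previous mini-batch and zero it wherever reset[i] is True.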

I believe this should work. I was planning to implement PPO+LSTM on Atari and use CleanRL's implementation to verify it; however, I don't have the spare time to work on it myself. Once that's done, it would be easy to continue this PR and get a working implementation of Decision Transformer.

If you could submit a PR, I'd be more than happy to review it and give you feedback.

MischaPanch commented 1 year ago

#937 needs to be resolved before this can be worked on. PR #640 can be used as inspiration or a basis for continuing this work.