takuseno / d3rlpy

An offline deep reinforcement learning library
https://takuseno.github.io/d3rlpy
MIT License

Differences in RTG computation between inference and training time #401

Closed: john-galt-10 closed this issue 1 month ago

john-galt-10 commented 2 months ago

I'm working with the Decision Transformer implementation of this library.

The computation of the RTG seems to be inconsistent between training and inference time.

Inside the BasicTrajectorySlicer class (used at training time), the returns-to-go for each step are computed as:

```python
rewards = episode.rewards[start:end]
ret = np.sum(episode.rewards[start:])
all_returns_to_go = ret - np.cumsum(episode.rewards[start:], axis=0)
returns_to_go = all_returns_to_go[:actual_size].reshape((-1, 1))
```
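To make the training-time behaviour easy to check, here is a minimal standalone reproduction of the snippet quoted above. It is a sketch, not library code: the function name `training_rtg_before_fix` is hypothetical, rewards are passed as a plain 1-D list instead of an `episode` object, and `actual_size` is assumed to be the length of the `rewards[start:end]` slice.

```python
import numpy as np


def training_rtg_before_fix(rewards, start, end):
    """Hypothetical reproduction of the quoted slicer logic (pre-fix)."""
    rewards = np.asarray(rewards, dtype=np.float32)
    # Assumption: actual_size is the length of the sliced window.
    actual_size = rewards[start:end].shape[0]
    # Total return from `start` to the end of the episode.
    ret = np.sum(rewards[start:])
    # Subtracting the inclusive cumulative sum also removes the reward of
    # the current step, so every RTG is shifted by one step.
    all_returns_to_go = ret - np.cumsum(rewards[start:], axis=0)
    return all_returns_to_go[:actual_size].reshape((-1, 1))


print(training_rtg_before_fix([1, 2, 3, 4, 5], start=0, end=5).ravel().tolist())
# [14.0, 12.0, 9.0, 5.0, 0.0]
```

Because `np.cumsum` is inclusive, the first step's RTG already excludes its own reward, which is the source of the shift described below.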

Inside the StatefulTransformerWrapper, the RTG at inference time starts at an initial value, TARGET_RETURN, and is then decreased by each reward after it is received.
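The inference-time rule can be sketched in a few lines. This is an illustration of the behaviour described above, not the wrapper's actual code: `inference_rtgs` is a hypothetical helper that, given a target return and the rewards an episode will produce, lists the RTG the policy would be conditioned on at each step.

```python
def inference_rtgs(target_return, rewards):
    """Hypothetical helper mimicking the described inference-time RTG update."""
    rtgs = []
    rtg = float(target_return)
    for reward in rewards:
        rtgs.append(rtg)  # the policy acts while conditioned on the current RTG
        rtg -= reward     # afterwards, subtract the reward just received
    return rtgs


print(inference_rtgs(15, [1, 2, 3, 4, 5]))
# [15.0, 14.0, 12.0, 9.0, 5.0]
```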

Let's suppose a training trajectory has received rewards like [ 1, 2, 3, 4, 5 ]. RTGs used at training time would turn out to be: [ 14, 12, 9, 5, 0 ]. If we see an identical trajectory at test time and choose TARGET_RETURN=15 (the total return of the aforementioned trajectory), we would like to see the same sequence of RTGs but we actually experience something different: at t=0 we would condition on the chosen target return 15, at t=1 we would decrease it to 14 and so on.

To sum up, the RTG sequence seen at training time would be [ 14, 12, 9, 5, 0 ], while at inference time it would be [ 15, 14, 12, 9, 5 ]. I would expect the training-time sequence to match the latter.
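For reference, a training-time computation that matches the inference-time sequence conditions each step on the sum of rewards from that step (inclusive) to the end of the episode, i.e. a reverse cumulative sum. The sketch below is one way to write it under the same assumptions as before (1-D rewards, hypothetical function name `training_rtg_consistent`); it is not necessarily how the library implements the fix.

```python
import numpy as np


def training_rtg_consistent(rewards, start, end):
    """Hypothetical RTG computation that includes the current step's reward."""
    rewards = np.asarray(rewards, dtype=np.float32)
    actual_size = rewards[start:end].shape[0]
    # Reverse cumulative sum: rtg[t] = rewards[t] + rewards[t+1] + ... + rewards[-1]
    suffix_sums = np.cumsum(rewards[start:][::-1], axis=0)[::-1]
    return suffix_sums[:actual_size].reshape((-1, 1))


print(training_rtg_consistent([1, 2, 3, 4, 5], start=0, end=5).ravel().tolist())
# [15.0, 14.0, 12.0, 9.0, 5.0]
```

An equivalent form that stays closer to the original snippet is `ret - np.cumsum(...) + rewards`, i.e. adding back the current step's reward that the inclusive cumulative sum removed.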

Is there any particular reason behind this choice? Thank you in advance!!

takuseno commented 2 months ago

@john-galt-10 Thanks for the issue. It's actually a good point, and I realized that the current implementation is wrong. In the latest commit, I've fixed this issue so that the calculation matches between training and evaluation: https://github.com/takuseno/d3rlpy/commit/8436ca128d820a32d2fc7776bae5d97897e8968f

If you install d3rlpy from source, you can use the fixed version. Thank you for reporting this!

takuseno commented 1 month ago

I believe that the commit addressed the issue. Feel free to reopen this if there is any further discussion.