danijar / dreamerv3

Mastering Diverse Domains through World Models
https://danijar.com/dreamerv3
MIT License

Lambda Return Calculation Bug? #102

Closed rschiewer closed 5 months ago

rschiewer commented 9 months ago

It seems to me that the score() function of the agent class counts the bootstrap time step twice. Here's the current code:

```python
def score(self, traj, actor=None):
  rew = self.rewfn(traj)
  assert len(rew) == len(traj['action']) - 1, (
      'should provide rewards for all but last action')
  discount = 1 - 1 / self.config.horizon
  disc = traj['cont'][1:] * discount
  value = self.net(traj).mean()
  vals = [value[-1]]
  interm = rew + disc * value[1:] * (1 - self.config.return_lambda)
  for t in reversed(range(len(disc))):
    vals.append(interm[t] + disc[t] * self.config.return_lambda * vals[-1])
  ret = jnp.stack(list(reversed(vals))[:-1])
  return rew, ret, value[:-1]
```

The line `vals = [value[-1]]` initializes the container for the lambda returns with the bootstrap value, and the line `interm = rew + disc * value[1:] * (1 - self.config.return_lambda)` prepares the intermediate values for the calculation. The last entry of `interm` should therefore contain `rew[-1] + disc[-1] * value[-1] * (1 - self.config.return_lambda)`. In the first iteration of the loop we then essentially compute `vals.append((rew[-1] + disc[-1] * value[-1] * (1 - self.config.return_lambda)) + disc[-1] * self.config.return_lambda * value[-1])`, or am I misunderstanding something here? This doesn't seem right to me. Even if the last entry of `vals` is discarded when reversing and stacking the list, it would still influence all previous time steps through the recursive dependency, no?

In the DreamerV2 code this is handled differently: the last time steps are omitted from the loop. However, I struggle to fully understand the `static_scan()` function used in the DreamerV2 code, so there might be more going on in the background.

I'm really unsure whether I'm missing something here, so my apologies if this actually turns out to be a non-issue in the end.
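For concreteness, the first loop iteration can be traced with toy scalar numbers (hypothetical values of my own choosing, with the array shapes reduced to scalars for readability):

```python
# Hypothetical toy numbers for the last imagined step.
rew_last, disc_last, v_boot, lam = 0.5, 0.98, 2.0, 0.95

# interm[-1] as built by score(); note value[1:][-1] is the bootstrap value.
interm_last = rew_last + disc_last * v_boot * (1 - lam)

# First loop iteration (t == len(disc) - 1), with vals seeded as [v_boot]:
appended = interm_last + disc_last * lam * v_boot

print(appended)
```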

belerico commented 5 months ago

For the last step, as you have written it (with $v_{t+1}$ the bootstrap value, since `interm` uses `value[1:]`):

$r_t + d_t \cdot v_{t+1} \cdot (1-\lambda) + d_t \cdot \lambda \cdot v_{t+1} = r_t + d_t \cdot v_{t+1}$

So the $(1-\lambda)$ and $\lambda$ terms merge, and the bootstrap enters the last step exactly once rather than twice.
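A quick numerical check (a plain NumPy sketch with made-up shapes and values, not the repo's code) confirms that the recursion in `score()` matches the textbook lambda-return, so the bootstrap is not double counted:

```python
import numpy as np

# Toy trajectory: T imagined steps, so T rewards/discounts and T + 1 values.
rng = np.random.default_rng(0)
T = 5
rew = rng.normal(size=T)
disc = rng.uniform(0.9, 1.0, size=T)
value = rng.normal(size=T + 1)
lam = 0.95

# Recursion as in score(): seed with the bootstrap value, sweep backwards.
interm = rew + disc * value[1:] * (1 - lam)
vals = [value[-1]]
for t in reversed(range(T)):
    vals.append(interm[t] + disc[t] * lam * vals[-1])
ret = np.stack(list(reversed(vals))[:-1])

# Reference lambda-return: G_t = r_t + d_t * ((1 - lam) * v_{t+1} + lam * G_{t+1}),
# bootstrapped with G_T = v_T.
ref = np.empty(T)
last = value[-1]
for t in reversed(range(T)):
    last = rew[t] + disc[t] * ((1 - lam) * value[t + 1] + lam * last)
    ref[t] = last

assert np.allclose(ret, ref)  # identical for every time step
```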

rschiewer commented 5 months ago

Wow, that is very true now that I look at it with some distance. I'm sorry for not having seen this myself. I guess when I tried to wrap my head around this back then, I got caught up in the details. Thanks for still taking the time to answer my question!