google-deepmind / acme

A library of reinforcement learning components and agents
Apache License 2.0

Question about R2D2 loss, masking, and episode boundaries #239

Closed wcarvalho closed 2 years ago

wcarvalho commented 2 years ago

This is the R2D2 loss: https://github.com/deepmind/acme/blob/b9ba5953495c76b20b0d7a2e4b2e7301a828b75c/acme/agents/jax/r2d2/learning.py#L140

I believe data.discounts specifies episode boundaries. Inspecting the loss at steps where data.discounts is 0, I found non-zero loss values being summed into the overall loss. Why aren't those values masked out, e.g. with something like the following:

mask = (data.discounts > 0).astype(jnp.float32)[:-1]  # 1 where the discount is non-zero
loss = 0.5 * jnp.square(batch_td_error)
loss = loss * mask                                     # zero out steps past an episode boundary
batch_loss = loss.sum(0) / (mask.sum(0) + 1e-5)        # average over valid steps only

Am I misunderstanding something?
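
For concreteness, here is a runnable sketch of what I mean with made-up numbers (the [T, B] discount layout and the [T - 1, B] TD-error shape are my assumptions about the learner, not taken from the code):

import jax.numpy as jnp

T, B = 4, 2
# Assumed layout: discounts are [T, B]; the second trajectory terminates early.
discounts = jnp.array([[1., 1.],
                       [1., 0.],   # discount hits 0 at termination
                       [0., 0.],   # post-termination / padded steps
                       [0., 0.]])
batch_td_error = jnp.ones((T - 1, B))  # pretend every step produced TD error 1

mask = (discounts > 0).astype(jnp.float32)[:-1]  # [T - 1, B]
loss = 0.5 * jnp.square(batch_td_error) * mask   # boundary steps contribute zero loss
batch_loss = loss.sum(0) / (mask.sum(0) + 1e-5)  # average over valid steps per batch element
print(batch_loss)  # compare to an unmasked mean over all T - 1 steps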

wcarvalho commented 2 years ago

As a note, when I add masking like the above, performance degrades significantly.

One difference I see is that, without masking, Q-values seem to be lower. Right now, I think it's because, without a mask, this loss encourages the Q-function to predict 0 when the data is all zeros, which potentially helps with over-estimation or exploding Q-values? Really confused by this. Thanks again!
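
To spell out that intuition a bit more (just my reading, ignoring R2D2's value transform and n-step details for simplicity): on a zero-padded step both the reward and the discount are 0, so the bootstrapped target collapses to 0 and the unmasked TD error is just the predicted Q-value itself:

r, discount = 0.0, 0.0   # zero-padded step past the episode boundary
q_next_max = 5.0         # whatever the network predicts for the padded next state
q_pred = 2.3             # prediction at the padded step

target = r + discount * q_next_max  # = 0, regardless of q_next_max
td_error = q_pred - target          # = q_pred, so the squared loss pulls q_pred toward 0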

wcarvalho commented 2 years ago

So I found proper masking to be critical for getting successor-feature-based agents to generalize. I compute my mask as follows:

mask = jnp.concatenate((jnp.ones((1, B)), data.discount[:-1]), axis=0)

I found that I need to prepend ones along the time axis to account for discount=0 when the episode is done. Consider a length-3 sequence where the episode terminates after the second state. This leads to:

states:   s1   s2   s3
done:      0    1    0
gammas:    1    0    0
mask:      1    1    0
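
For reference, a minimal sketch checking that mask on the example above (B=1; the discount values are just the gammas row, nothing taken from the learner):

import jax.numpy as jnp

B = 1
# Discounts for the length-3 example: 1 at s1, 0 at the terminal s2, 0 at s3.
discount = jnp.array([[1.], [0.], [0.]])  # [T, B] with T = 3

# Prepend a row of ones along the time axis and drop the last discount:
mask = jnp.concatenate((jnp.ones((1, B)), discount[:-1]), axis=0)
print(mask[:, 0])  # [1. 1. 0.] -- the terminal step s2 keeps its loss, s3 is masked out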