Closed wcarvalho closed 2 years ago
As a note, when I add masking like above, performance degrades significantly.
One difference I see is that, without masking, Q-values seem to be lower. Right now I think that's because, without a mask, the loss encourages the Q-function to predict 0 when the data is all zeros, which potentially helps with over-estimation or exploding Q-values? I'm really confused by this. Thanks again!
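To make the hypothesis above concrete, here is a toy sketch (all names and values are made up, not from Acme): on padded steps the observation, reward, and discount are all zero, so an unmasked TD loss contains terms of the form `(Q(zero_obs) - 0)^2`, which act like an extra pull of those Q-values toward zero.

```python
import jax.numpy as jnp

# Hypothetical Q-values predicted on two padded (all-zero) steps.
q_padded = jnp.array([2.0, 2.0])

# On padded steps the TD target is r + gamma * Q' = 0 + 0 * Q' = 0.
target = jnp.zeros_like(q_padded)

# Without a mask, these terms are summed into the loss and behave
# like a regularizer shrinking Q(zero_obs) toward 0.
extra_loss = jnp.sum(jnp.square(q_padded - target))  # 8.0
```

This is only an illustration of why an unmasked loss could depress Q-values, not a claim about what the R2D2 learner actually does.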
So I found proper masking to be critical for getting successor-feature-based learning agents to generalize. I compute my mask as follows:
```python
mask = jnp.concatenate((jnp.ones((1, B)), data.discount[:-1]), axis=0)
```
I found that I need to prepend ones along the time axis to account for discount=0 when the episode is done. Consider a length-3 episode that terminates after the second state. This leads to:
```
states:  s1 s2 s3
done:     0  1  0
gammas:   1  0  0
mask:     1  1  0
```
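The mask construction above can be checked on this small example (time-major shape `[T, B]` with `B = 1`; `discount` stands in for `data.discount`):

```python
import jax.numpy as jnp

B = 1
# Discounts for s1, s2, s3: zero at the terminal transition and after.
discount = jnp.array([[1.0], [0.0], [0.0]])

# Prepend a row of ones so the first step is always valid, then shift
# the discounts by one step: a step is masked out only when the
# *previous* step already had discount == 0 (i.e. it is padding).
mask = jnp.concatenate((jnp.ones((1, B)), discount[:-1]), axis=0)
print(mask[:, 0])  # [1. 1. 0.]
```

Shifting by one step keeps the terminal transition itself in the loss (its target is just the final reward) while dropping the padded steps that follow.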
This is the R2D2 loss: https://github.com/deepmind/acme/blob/b9ba5953495c76b20b0d7a2e4b2e7301a828b75c/acme/agents/jax/r2d2/learning.py#L140
I believe `data.discounts` specifies episode boundaries. Inspecting the loss when `data.discounts` is 0, I found that there are non-zero loss values that are summed into the overall loss. How come those values are not masked out, e.g. with something like the following? Am I misunderstanding something?
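A minimal sketch of the kind of masking being asked about (this is not the actual Acme code; `td_errors` and `mask` are illustrative names with shape `[T, B]`):

```python
import jax.numpy as jnp

# Made-up per-step TD errors for a length-3, batch-1 sequence.
td_errors = jnp.array([[0.5], [-0.2], [0.7]])

# Hypothetical validity mask: 0 on steps past the episode boundary.
mask = jnp.array([[1.0], [1.0], [0.0]])

# Zero out contributions from invalid steps and normalize by the
# number of valid steps rather than the full sequence length.
masked_loss = jnp.sum(mask * jnp.square(td_errors)) / jnp.sum(mask)
```

Under this sketch, the step with discount 0 in its predecessor contributes nothing to the scalar loss.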