google-research / planet

Learning Latent Dynamics for Planning from Pixels
https://danijar.com/planet
Apache License 2.0

KL-divergence losses seem not to be divided by a normalizing factor (the multi-step prediction count D in the paper) #19

Closed: sehee382 closed this issue 5 years ago

sehee382 commented 5 years ago

The implementation seems to just sum up all the KL divergences of the zero-step prediction and the overshooting predictions without any normalizing term (like 1/D). Is this intended? Or am I missing something?

piojanu commented 5 years ago

Hi!

Could you point to the line numbers in the code? Also, why do you think summing is wrong?

piojanu commented 5 years ago

I've looked into the code and here you have it:

https://github.com/google-research/planet/blob/9cd9abed5b9a8831388f4d9da16e5604cfbd7c20/planet/training/utility.py#L183-L212

Lines 196 and 202 reduce-sum the KL-divergence losses and then divide them by the chunk lengths. Is that what you were looking for?

sehee382 commented 5 years ago

Yes, I'm asking about exactly that code! Let me explain my question in more detail. I just want to make sure that what I understand from the code about the final weight of each divergence is right.

The 'divergence' loss is computed twice: once for the zero-step divergences and once for the overshooting divergences. The step outputs would look like this (latent distribution outputs only, seq_length=4, overshooting_len=3):

zerostep:
--> [prior0, posterior0, prior1, posterior1, prior2, posterior2, prior3, posterior3]
overshooting:
--> 1-step pred [prior1_from_posterior0, prior2_from_posterior1, prior3_from_posterior2]
--> 2-step pred [prior2_from_posterior0, prior3_from_posterior1]
--> 3-step pred [prior3_from_posterior0]

The averaging is done twice, i.e. losses['zerostep_divergence'] = avg(zero-step divergences) and losses['overshooting_divergence'] = avg(overshooting divergences).

Finally, these two divergence losses are simply summed up here: https://github.com/google-research/planet/blob/9cd9abed5b9a8831388f4d9da16e5604cfbd7c20/planet/training/define_model.py#L104-L109

So when the default scale is chosen, the final loss weights for the zero-step divergences (n = seq_len = 4) and for the overshooting divergences (n = 6, i.e. O(seq_len * overshooting_len)) are the same, 1.0. This means that the actual weight of a single KLD(posterior_t || prior_t) term in the zero-step loss is bigger than that of a single KLD(posterior_t || prior_t_from_posterior_t_minus_k) term in the overshooting loss.

[image: equation from the paper]

It looks like the implementation has β1 = D, sum(β2, β3, ..., βD) = D, and β2 = β3 = ... = βD in terms of the original paper (if the divergence loss scale is 1.0), and it provides the 'scale' parameter. As I mentioned above, I just wanted to make sure this understanding is correct.
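A quick sketch of the counting above, assuming seq_length=4 and overshooting_len=3 as in the example. The helper names are hypothetical, not taken from the PlaNet codebase. Each overshooting term is written as (target step t, step t - k of the posterior the prior was rolled out from), in the same order as the listing above.

```python
def zero_step_terms(seq_len):
    # One KL(posterior_t || prior_t) term per time step t.
    return list(range(seq_len))


def overshooting_terms(seq_len, overshooting_len):
    # One KL(posterior_t || prior_t_from_posterior_{t-k}) term for every
    # k = 1..overshooting_len and every t that has a posterior k steps back.
    return [(t, t - k)
            for k in range(1, overshooting_len + 1)
            for t in range(k, seq_len)]


def per_term_weight(terms, scale=1.0):
    # Averaging a group of n terms and scaling by `scale` gives each term scale / n.
    return scale / len(terms)


zero = zero_step_terms(4)
over = overshooting_terms(4, 3)
print(len(zero), per_term_weight(zero))  # 4 0.25
print(len(over), per_term_weight(over))  # 6 0.16666666666666666
```

So with equal scales, a single zero-step KL term carries weight 1/4 while a single overshooting KL term carries 1/6.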

piojanu commented 5 years ago

@sehee382, tell me if I've correctly understood what you are thinking about: I can see that the zero-step loss is missing the 1/D scaling that is included in the paper (let's focus on "divergence"; "global divergence" is analogous).
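For reference, the latent overshooting bound in the paper, which is where the 1/D and the β_d weights come from, can be written roughly as follows (transcribed here, so please check it against the equation in the paper itself):

```latex
$$
\frac{1}{D}\sum_{d=1}^{D}\ln p_d(o_{1:T}) \;\ge\; \sum_{t=1}^{T}\Bigg(
  \mathbb{E}_{q(s_t \mid o_{\le t})}\big[\ln p(o_t \mid s_t)\big]
  \;-\; \frac{1}{D}\sum_{d=1}^{D}\beta_d\,
  \mathbb{E}_{p(s_{t-1} \mid s_{t-d})\,q(s_{t-d} \mid o_{\le t-d})}
  \Big[\mathrm{KL}\big[q(s_t \mid o_{\le t}) \,\big\|\, p(s_t \mid s_{t-1})\big]\Big]
\Bigg)
$$
```

The d = 1 term is the usual one-step KL (the code's zero-step divergence), and the d > 1 terms are the overshooting divergences.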

Previously, I thought that this was the 1/D scaling from the equation (line 196): https://github.com/google-research/planet/blob/9cd9abed5b9a8831388f4d9da16e5604cfbd7c20/planet/training/utility.py#L183-L212 But it's not. It's a reduce-mean over the time dimension for all the open-loop rollouts. (BTW, does this rescaling come from the expectation of the KL divergence in the equation? Shouldn't it be just a sum? I'll come back to that later.)

The scaling by the batch size and the maximum rollout length (1/D in the equation) happens on line 210 (I think the comment should say: "Average over the batch and normalize by the maximum OPEN LOOP PREDICTION DISTANCE."). This is because target, prior, posterior, and mask have the shape batch x time (chunk length) x number of rollouts (open-loop predictions) x ..., and we've already reduced the time dimension above. Now, for overshooting the number of rollouts equals D, BUT for the zero-step loss the number of rollouts equals 1. This comes from the slicing done on lines 76 and 90 here: https://github.com/google-research/planet/blob/9cd9abed5b9a8831388f4d9da16e5604cfbd7c20/planet/training/define_model.py#L74-L98

So, finally, this reduction on line 210... https://github.com/google-research/planet/blob/9cd9abed5b9a8831388f4d9da16e5604cfbd7c20/planet/training/utility.py#L183-L212 ...divides the overshooting loss by the extra rollout dimension (of length 'D'), but the zero-step loss WILL NOT be divided by the rollout length 'D'. Therefore, the zero-step losses are on a different scale ('D' times larger) than the overshooting losses.
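A minimal numpy sketch of the reduction described above, assuming a KL tensor of shape batch x time x rollouts that is first averaged over time and then averaged over batch and rollouts together. The function names and toy sizes are hypothetical, and masks are ignored, so this is not PlaNet's actual code. Putting a single KL value of 1.0 into one slot shows the weight that slot receives: a zero-step slot (rollout dimension of size 1) ends up D times heavier than an overshooting slot (rollout dimension of size D).

```python
import numpy as np


def reduce_divergence(kl):
    """Hypothetical reduction for a KL tensor of shape (batch, time, rollouts).

    Masks are ignored: the time dimension is averaged directly, then the
    batch and rollout dimensions are averaged together.
    """
    per_rollout = kl.mean(axis=1)  # shape (batch, rollouts)
    return per_rollout.mean()      # divides by batch * rollouts


def single_slot_weight(batch, time, rollouts):
    """Weight that one KL entry receives: 1.0 in one slot, zeros elsewhere."""
    kl = np.zeros((batch, time, rollouts))
    kl[0, 0, 0] = 1.0
    return reduce_divergence(kl)


B, T, D = 2, 5, 50
zero_step_w = single_slot_weight(B, T, 1)     # zero-step: rollout dim of size 1
overshooting_w = single_slot_weight(B, T, D)  # overshooting: rollout dim of size D
print(zero_step_w, overshooting_w, zero_step_w / overshooting_w)
```

Because the two losses are then simply summed, this per-slot ratio of D is the scale difference described above.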

@danijar
1. Are we correct? Is it a bug?
2. What about this mean over the time dimension (and over the spatial dimensions of, e.g., images)? It's a mean rather than a sum to normalize for different chunk lengths, right? It doesn't somehow come from the equation, right? It's just in the code (since scaling by a constant won't change the optimization problem), RIGHT⏩? 😄

piojanu commented 5 years ago

Ohhhh, I see! In the appendix of the paper:

For latent overshooting, we use D = 50 and set β1 = D and β>1 = 1.

So all the calculations are correct and the code is fine; we just don't explicitly multiply by β1 = D. @danijar, please confirm this. @sehee382, did my explanation help ground your understanding? I had to catch up with you, but I think we're on the same page now 😄
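Plugging the appendix setting into the KL part of the paper's objective shows why no explicit multiplication by D is needed. Writing KL_d as shorthand (my notation) for the d-step KL term, the d = 1 term ends up with weight 1 and every d > 1 term with weight 1/D, which has the same structure as the code: an unnormalized zero-step loss plus an overshooting loss normalized by the rollout count.

```latex
$$
\frac{1}{D}\sum_{d=1}^{D}\beta_d\,\mathrm{KL}_d\,\Big|_{\beta_1 = D,\ \beta_{d>1} = 1}
= \frac{1}{D}\Big(D\,\mathrm{KL}_1 + \sum_{d=2}^{D}\mathrm{KL}_d\Big)
= \mathrm{KL}_1 + \frac{1}{D}\sum_{d=2}^{D}\mathrm{KL}_d
$$
```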

EDIT: But I still think the paper is inaccurate. It says that D = 50, but the slicing here (line 90): https://github.com/google-research/planet/blob/9cd9abed5b9a8831388f4d9da16e5604cfbd7c20/planet/training/define_model.py#L87-L98 will retrieve D - 1 = 49 elements from the rollout dimension. (If I understand correctly, target, prior, posterior, and mask have D + 1 elements in the rollout dimension, and we slice from the second to the next-to-last element. In the config, overshooting is set to D - 1, but (D - 1) + 1 is then passed to the overshooting function, so it produces D future predictions plus the zero-step predictions, giving a rollout length of D + 1, which is 50 + 1 = 51 in the default case.) This means the overshooting loss is divided by 49 (in the reduce_mean on line 210 of planet/planet/training/utility.py). So the scaling is 1/49, i.e. 1/(D - 1), not 1/50, i.e. 1/D, as stated in the paper. Is THIS a bug, then? What do you think, @danijar @sehee382?
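A small sketch of the index arithmetic in the EDIT, assuming (as described, not copied from define_model.py) that the rollout dimension holds D + 1 entries and that the overshooting slice runs from the second to the next-to-last entry, written here as [1:-1]. The variable names are hypothetical.

```python
D = 50                                   # overshooting distance from the paper's appendix
config_overshooting = D - 1              # value in the config, per the description above
passed_to_fn = config_overshooting + 1   # D future predictions requested
rollouts = list(range(passed_to_fn + 1)) # D future predictions + zero-step = D + 1 entries

# Assumed slice: from the second to the next-to-last element.
overshooting_slots = rollouts[1:-1]
divisor = len(overshooting_slots)        # what reduce_mean would divide by

print(len(rollouts), divisor)  # 51 49
print(1 / divisor, 1 / D)      # per-term weight under each normalization
```

The two normalizations differ by a factor of D / (D - 1) = 50/49, roughly 2%.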

sehee382 commented 5 years ago

Wow! It was in the appendix of the paper :$ "β1 = D and β>1 = 1". That was it! Whether β>1 = 1 or β>1 = 49/50 is not a big deal to me. Next time I have a question, I'll read the paper carefully before asking. Anyway, thanks a lot!

danijar commented 5 years ago

Exactly, you've figured it out already :) Whether it's 1 or 0.98 should not make a difference.