edgarschnfld / CADA-VAE-PyTorch

Official implementation of the paper "Generalized Zero- and Few-Shot Learning via Aligned Variational Autoencoders" (CVPR 2019)

Some confusion about the warmup schedule #25

Closed Hanzy1996 closed 3 years ago

Hanzy1996 commented 3 years ago

According to the code, you scale the loss terms according to the warmup schedule. I am still confused about this operation. How does this strategy work, and why do we need it? Is there any reference related to this operation?

Much appreciation!

edgarschnfld commented 3 years ago

First, let me copy our explanation directly from the paper, from the section "Implementation details":

After individual VAEs learn to encode features of only their specific datatype for some epochs, we also start to compute cross- and distribution alignment losses. δ is increased from epoch 6 to epoch 22 by a rate of 0.54 per epoch, while γ is increased from epoch 21 to 75 by 0.044 per epoch. For the KL-divergence we use an annealing scheme [3], in which we increase the weight β of the KL-divergence by a rate of 0.0026 per epoch until epoch 90. A KL-annealing scheme serves the purpose of first letting the VAE learn "useful" representations before they are "smoothed" out, since the KL-divergence would otherwise be a very strong regularizer [3].

To put it in other words: our goal is to align the latent representations. For that, it is useful to let the autoencoders learn some representation in the first place. Then, after some epochs have passed, you bring those representations closer together by slowly increasing the losses that pull the latents towards each other (cross-reconstruction and distribution alignment). This helps from a practical standpoint, because the objectives that align the representations could otherwise "overshadow" the losses that make the autoencoders learn a useful representation.

I got this idea from Bowman et al. (2016), who anneal the weight of the KL-divergence of a VAE with a schedule, since the KL-divergence is too strong a constraint at the early stages of learning. We do the same here, for the same reason, but not only for the KL-divergence.

In this implementation, we define a start epoch, an end epoch, and the final factor of each weight (i.e. the factor we multiply the loss with). So if (start epoch, end epoch, final factor) == (6, 22, 8.13), then between epoch 6 and epoch 22 the weight increases linearly until it reaches 8.13 at epoch 22. Before the start epoch the weight is 0; after the end epoch it does not change anymore. Note: in the paper we specify the rate of increase per epoch, while in the code the final factor is specified.
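For concreteness, here is a minimal sketch of that piecewise-linear schedule. The function name, the loop, and the (6, 22, 8.13) values are purely illustrative and not the repository's exact code:

```python
# Illustrative sketch of the warm-up described above: a loss weight is 0 before
# start_epoch, grows linearly, and stays at final_factor from end_epoch onwards.

def warmup_factor(epoch, start_epoch, end_epoch, final_factor):
    """Piecewise-linear warm-up weight for the given epoch."""
    if epoch < start_epoch:
        return 0.0
    if epoch >= end_epoch:
        return final_factor
    progress = (epoch - start_epoch) / (end_epoch - start_epoch)
    return final_factor * progress


# Example with (start epoch, end epoch, final factor) == (6, 22, 8.13):
for epoch in (0, 6, 14, 22, 30):
    print(epoch, warmup_factor(epoch, 6, 22, 8.13))
# -> 0.0 before epoch 6, 4.065 halfway at epoch 14, 8.13 from epoch 22 onwards
```

The resulting factor is then multiplied with the corresponding loss term (cross-reconstruction, distribution alignment, or KL-divergence) when the total training loss is assembled.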

S. R. Bowman, L. Vilnis, O. Vinyals, A. Dai, R. Jozefowicz, and S. Bengio. Generating sentences from a continuous space. In CoNLL, pages 10–21, 2016.
Hanzy1996 commented 3 years ago

Thank you very much for your patience and the detailed explanation! I understand this strategy now.

By the way, thank you as well for your splendid work!