Closed: tankche1 closed this issue 2 years ago.
Hi @tankche1, yeah, this is normal behavior. It happens because of how we construct the target images early in training: we gradually interpolate the latent code that synthesizes the target images from the input image's code to the learned congealing vector over the first 150K gradient steps. This stabilizes training by letting the STN predict only small warps at the start. So the learning task actually gets harder for the STN in the early training steps, which produces the loss curve you're seeing. Once you hit ~100K to 150K iterations, the STN will start making a lot of progress.
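The annealing described above can be sketched as a simple linear interpolation. This is a minimal illustration, not the repo's actual code; the names `w_input`, `w_congealed`, and the `anneal_steps` default are assumptions for the example:

```python
def target_latent(w_input, w_congealed, step, anneal_steps=150_000):
    """Interpolate from the input image's latent code toward the learned
    congealing latent over the first `anneal_steps` gradient steps.

    Early in training the synthesized target is close to the input image,
    so the STN only needs to predict small warps; once `step` reaches
    `anneal_steps`, the target is the fully congealed image.
    """
    alpha = min(step / anneal_steps, 1.0)  # 0 at step 0, 1 after annealing
    # With torch tensors this would be torch.lerp(w_input, w_congealed, alpha).
    return (1.0 - alpha) * w_input + alpha * w_congealed
```

At `step=0` this returns the input's code unchanged, and for any `step >= anneal_steps` it returns the congealing vector, which is why the target images you see drift gradually toward the congealed version.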
You don’t have to train for the full 1.5M steps unless you want to squeeze out every last bit of performance. From my experience, you get most of the benefit after about 400K to 712K steps. Make sure you use one of the checkpoints that gets saved when the learning rate hits zero. Those seem to be the best.
Thanks! The training looks great now. One question: I find that both the transformed sample and the target sample (truncated sample) gradually move from the original image to the final congealed version.
If I put the original image (e.g., a white cat) into the STN and set the final congealed output (e.g., the head of a white cat) as the target in the perceptual loss, the STN cannot learn the transform.
Does this mean the congealing algorithm relies on gradual improvement from both the STN and the learned latent embedding? Also, is `t_ema` only used for visualization?
Yeah, the gradual transformation you see is a result of this gradual annealing of the target latent code. We have an ablation in Table 4 of the supplementary materials where we omit this annealing, and it drops PCK@0.1 of the cat model from 67% to 59%, so it definitely has a significant impact.
That's an interesting experiment you ran. I suspect some images can still be successfully congealed without gradual annealing (otherwise the ablation would probably be closer to 0% PCK :), but I don't have great intuition for which subset of images it helps with the most.
At the end of training, `t` is effectively discarded and `t_ema` is the final model used for everything (that's the reason we visualize it during training, since we care about `t_ema`'s final performance more than `t`'s). It's an exponential moving average of `t`'s parameters over training, which is a trick that a lot of generative models (DDPMs, GANs, etc.) use to improve performance.
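A minimal sketch of the EMA update, for intuition only; the function name and the `decay` value are illustrative assumptions, not the repo's exact code (real implementations do this in-place over `state_dict` tensors):

```python
def ema_update(t_ema_params, t_params, decay=0.999):
    """One EMA step: nudge each t_ema parameter a small fraction of the way
    toward the corresponding t parameter. Repeated every training step,
    t_ema becomes a smoothed average of t's parameter trajectory, which
    tends to generalize better than the raw final weights.
    """
    # decay is a typical value for generative models, not necessarily
    # the one used in this repository.
    return [decay * e + (1.0 - decay) * p
            for e, p in zip(t_ema_params, t_params)]
```

Because `t_ema` averages over many recent values of `t`, it is less sensitive to the noise of individual gradient steps, which is why it is the checkpoint worth keeping.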
As an aside, when you use your trained models at test time, I would recommend using `iters=3` when calling the STN (e.g., `stn(x, iters=3)`). The `iters` argument recursively applies the similarity STN on its own output, which helps a lot for harder datasets like LSUN. If you're using the testing scripts in the `applications` folder, you can specify this from the command line with `--iters 3`. The visualizations made during training all use `iters=1`, so they're a lower bound on performance.
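The recursive application is conceptually just feeding the STN's output back in as input. This is a hedged sketch of that idea, with `stn_forward` as a hypothetical stand-in for the model's forward pass; the actual implementation may compose the predicted warp matrices instead of resampling the image each pass:

```python
def apply_stn_recursive(stn_forward, x, iters=3):
    """Apply a similarity STN to its own output `iters` times.

    Each pass lets the model refine the alignment produced by the
    previous pass, which the maintainer reports helps on harder
    datasets such as LSUN.
    """
    for _ in range(iters):
        x = stn_forward(x)
    return x
```

With `iters=1` this reduces to a single forward pass, matching what the training-time visualizations use.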
I run the following script:
and find that the loss goes up and the transformed image learns almost nothing. Also, it takes 1.85 s/iter and needs 1,500,000 iters, which works out to roughly 770 hours. Is that normal?