marcoamonteiro / pi-GAN


Puzzled by Head Position #21

Closed RaymondJiangkw closed 2 years ago

RaymondJiangkw commented 3 years ago

Hi, @ericryanchan,

I am really curious about how you solved the head position problem. I see that the real images are not paired with ground-truth head positions; thus, the network must learn head positions in an unsupervised way.

After checking your code, I found that in every iteration you sample a head position and have the discriminator predict the head position of the rendered faces. The discriminator's prediction is then supervised by the sampled head position in a self-supervised way.

I am really puzzled by this mechanism and can't figure out why it works. Can you help me?

ericryanchan commented 3 years ago

Hi Raymond! We landed on this pose correction loss because we were looking for a way to force all heads into the same canonical pose. We wanted it so that when you ask for a front-facing image, you get a front-facing image, no matter the identity! This pose correction loss is entirely optional; with sufficiently low learning rates, head poses also seem to stay stable on their own. However, we noticed that when we tried to speed up convergence, we'd encounter pose drift.

So this pose correction loss works in two parts. In part one, we teach the discriminator to recognize poses. Whenever we feed the discriminator a generated image, we have it predict the pose and simply minimize the error between its prediction and the pose the image was actually rendered from. If all goes to plan, the discriminator will then be able to tell us the pose an image appears to have.
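In code, part one looks roughly like this (a simplified sketch rather than our exact implementation; it assumes the discriminator returns a realism logit along with a regressed (pitch, yaw), and names like `discriminator_pose_loss` are just for illustration):

```python
import torch
import torch.nn.functional as F

def discriminator_pose_loss(discriminator, generator, z, sampled_poses):
    # Render images at the poses that were sampled for the generator; detach
    # them so this loss only updates the discriminator.
    with torch.no_grad():
        fake_images = generator(z, sampled_poses)
    _, predicted_poses = discriminator(fake_images)
    # Teach the discriminator to predict the pose each image was rendered from,
    # so it learns what pose an image "appears" to have.
    return F.mse_loss(predicted_poses, sampled_poses)
```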

In part two, we update the generator based both on the quality of the image (as in standard GAN training) and on how closely the image appears to align with the correct pose. We have the ground-truth pose that was fed to the generator, and after part one we have a trained discriminator that can tell us the pose the image appears to have. Part two just minimizes the distance between the two, i.e., it makes sure our generated images' actual poses match the poses they appear to have.
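And part two, sketched in the same simplified style (the non-saturating adversarial term and the `pose_lambda` weight are illustrative choices here, not values taken from the paper):

```python
import torch
import torch.nn.functional as F

def generator_pose_loss(discriminator, generator, z, sampled_poses, pose_lambda=15.0):
    # Gradients now flow through the rendered images into the generator.
    fake_images = generator(z, sampled_poses)
    adv_logits, predicted_poses = discriminator(fake_images)
    # Standard non-saturating GAN term plus the pose-alignment penalty: the pose
    # the image appears to have should match the pose it was rendered from.
    adv_loss = F.softplus(-adv_logits).mean()
    pose_loss = F.mse_loss(predicted_poses, sampled_poses)
    return adv_loss + pose_lambda * pose_loss
```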

The nice thing about this strategy is that it doesn't require any actual poses from the real images and yet it still seems to stop (or at least slow) pose drift.

Hope this helps!

Eric

RaymondJiangkw commented 3 years ago

Hi, Eric! Thanks so much for your response! It really clarifies the mechanism, given that the SIREN has learned a good canonical 3D representation of the head.

However, I am still curious about how this strategy behaves at the early stage of training. Specifically, how does a Canonical Head "emerge" in the SIREN?

At the very beginning, the SIREN only generates noise, so it clearly doesn't make sense to teach the discriminator to predict a position for noise. As iterations go by, thanks to the GAN loss, the SIREN gradually learns to generate coarse, blurry, yet promising head images via volume rendering, and the discriminator gradually learns to "assign" a position to these generated head images.

Suppose these premature, numerically frontal head images (at v=pi/2 and h=pi/2) are semantically tilted at first. I believe the discriminator would then fail to correct them into semantically frontal faces, since it would already have built the connection between semantically tilted faces and the numerically frontal values (v=pi/2 and h=pi/2). However, in practice, this isn't the case: the semantically frontal faces match the numerically frontal poses up to small errors, which disappear as training goes on.

After checking the whole pipeline, I think it may be related to the position-sampling strategy and the dataset: the sampled vertical and horizontal positions are centered at pi/2, and the dataset (say, CelebA) has the property that most of its faces are frontal or close to frontal.

Thus, at the early stage of training, when we feed real images into the discriminator, it will judge frontal faces to be more realistic than tilted ones, so the SIREN is encouraged to generate frontal or near-frontal faces at the sampled positions. Since the sampled positions are centered around pi/2, the SIREN gradually learns to construct frontal faces, i.e., canonical faces, at the position pi/2. However, since we also actively sample other positions and require the SIREN to produce realistic images from different views, a 3D head model is formed instead of a flat plane.
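To make the assumption concrete, this is roughly the sampling I have in mind (a simplified illustration, not the repository's exact code; the standard deviations below are placeholder values):

```python
import math
import torch

def sample_camera_poses(batch_size, h_stddev=0.3, v_stddev=0.15):
    # Horizontal (yaw) and vertical (pitch) angles drawn from Gaussians centered
    # at pi/2, so most sampled cameras are frontal or near-frontal, matching the
    # pose statistics of a dataset like CelebA.
    h = torch.randn(batch_size, 1) * h_stddev + math.pi / 2
    v = torch.randn(batch_size, 1) * v_stddev + math.pi / 2
    return h, v
```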

This is just my premature thought. Could you tell me whether it is correct? If not, could you explain why a Canonical Head is guaranteed to "emerge" in the SIREN at the early stage, instead of a Tilted Head or a flat plane (i.e., a "mirror")? Thanks a lot!