eladrich / pix2vertex.pytorch

An official PyTorch port of the pix2vertex paper from ICCV2017
MIT License

Elaboration on how unlabelled real images were used during training #3

Closed: vikasTmz closed this issue 4 years ago

vikasTmz commented 4 years ago

The released network was trained on a combination of synthetic images and unlabelled real images for some extra robustness

As this uses the pix2pix pipeline, I'm curious about how unlabelled real images were incorporated during training. Did you fit the synthetic dataset to a real-life face dataset like CelebA, as done by Sengupta et al. in SfSNet?

Edit: It also looks like your new model is more robust to occlusions than the previous one.

eladrich commented 4 years ago

Sort of. I'm planning to write a full explanation soon, but to give the general idea 💡: putting unsupervised training aside, methods mainly use one of the following:

When comparing the two data domains, we consistently saw that training with synthetic data led to more accurate reconstruction of facial details, but resulted in noise around the edges and on occlusions.

To address that, we experimented with generating more realistic synthetic data. Specifically, we placed the synthetic faces onto real people and blended occlusions back in. That way we can still exploit the accuracy of synthetic data while training on more realistic images.
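The thread does not say how the blending was done, so the following is only a minimal sketch of one plausible way to composite a rendered face onto a real photo and then put real occluders back on top. It assumes the renderer provides a per-pixel face coverage mask (`face_alpha`) and that an occlusion mask (`occlusion_mask`, 1 where hair, hands, etc. should stay in front) is available for the real image; the function name `blend_synthetic_face` is hypothetical.

```python
import numpy as np

def blend_synthetic_face(real, synthetic, face_alpha, occlusion_mask):
    """Hypothetical helper: paste a rendered face onto a real photo, then restore occluders.

    real, synthetic: float arrays of shape (H, W, 3) with values in [0, 1].
    face_alpha: (H, W) array in [0, 1], coverage of the rendered face (assumed from the renderer).
    occlusion_mask: (H, W) array in [0, 1], 1 where a real occluder stays in front of the face.
    """
    a = face_alpha[..., None]      # broadcast the mask over the RGB channels
    o = occlusion_mask[..., None]
    # Standard alpha compositing: rendered face over the real background.
    composite = a * synthetic + (1.0 - a) * real
    # Put the real occluding pixels back on top of the composite.
    return o * real + (1.0 - o) * composite
```

With soft (non-binary) masks the same formulas give smooth transitions at the face boundary and around occluders, which is one reason to keep the masks as floats rather than booleans.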

While the blended data gave us some additional robustness, we found that it does not fully mimic real data and still produces some artifacts. This led us to try using the real data directly to get extra robustness to those patterns.

While rendering losses are a great way to learn geometry from real data without labels, we realized that all we really wanted from real data was to learn the face region, not the geometry. To achieve that, we simply added an auxiliary task of predicting a segmentation map for the head region (binary 0/1 values). While the geometry maps were trained only on synthetic images, the segmentation task was trained on both synthetic images and real images fitted with a simple 3DMM, allowing the model to see many real images during training. This simple scheme removed most of the artifacts we had encountered 🙃.
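To make the mixed-supervision idea concrete, here is a minimal NumPy sketch of a per-batch loss where the geometry term is computed only on synthetic samples while the head-segmentation term is computed on every sample. The choice of L1 for geometry, binary cross-entropy for segmentation, the `seg_weight` factor and the function names are all assumptions for illustration; the thread does not specify the actual losses. The 0/1 head masks for real images are assumed to be precomputed (e.g. from the 3DMM fit).

```python
import numpy as np

def bce_with_logits(logits, targets):
    # Numerically stable per-pixel binary cross-entropy on raw logits.
    return np.maximum(logits, 0) - logits * targets + np.log1p(np.exp(-np.abs(logits)))

def mixed_batch_loss(pred_geom, gt_geom, seg_logits, gt_seg, is_synthetic, seg_weight=1.0):
    """Hypothetical loss for a batch mixing synthetic and real images.

    pred_geom, gt_geom: (B, H, W) geometry maps; gt_geom is ignored for real samples.
    seg_logits, gt_seg: (B, H, W); gt_seg holds 0/1 head-region labels for every sample.
    is_synthetic: (B,) bool array, True where geometry ground truth exists.
    """
    # Segmentation term: synthetic and real samples both contribute.
    seg_loss = bce_with_logits(seg_logits, gt_seg).mean()

    # Geometry term: boolean indexing keeps only the synthetic samples,
    # so whatever is stored in gt_geom for real images never enters the loss.
    if is_synthetic.any():
        geom_loss = np.abs(pred_geom - gt_geom)[is_synthetic].mean()
    else:
        geom_loss = 0.0
    return geom_loss + seg_weight * seg_loss
```

In a PyTorch training loop the same masking would typically be applied to tensors so that gradients from the geometry head flow only from synthetic samples, while the shared encoder still receives segmentation gradients from real images.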


vikasTmz commented 4 years ago

Ah, neat! Thanks for the detailed explanation.