wpeebles / gangealing

Official PyTorch Implementation of "GAN-Supervised Dense Visual Alignment" (CVPR 2022 Oral, Best Paper Finalist)
https://www.wpeebles.com/gangealing
BSD 2-Clause "Simplified" License

model synthesizes features to fool the loss #28

Closed petercmh01 closed 2 years ago

petercmh01 commented 2 years ago

Hello:

I've been using the model on my own custom dataset for a while. When I visualize the congealing process on the test set and the propagated dense tracking, I noticed:

  1. the congealing process synthesizes features (e.g., it creates the animal's head on its tail side) instead of rotating or flipping the image to achieve the correct alignment
  2. in the dense tracking, the color scale flips (e.g., from head to tail on animals), which I think corresponds to the previous point

I read in the paper about using flow smoothness and flipping to avoid this issue, and I understand it can occur a lot. How exactly do flipping and flow smoothness help avoid it? What parameters can I adjust to make my model more robust? Does the model improve on this issue after 1 million iterations? For resource reasons I haven't been able to train that long yet, but I read in another issue that you mentioned it usually takes about that long for the model to improve.

I have tried the default training scripts for the 1-, 2-, and 4-head settings. I also tried increasing the --inject and --flow_size parameters and turning on the sample_from_full_resolution option, but I haven't made any real progress from these trials yet.

Thanks in advance, and it's really appreciated that you're consistently helping out :)

petercmh01 commented 2 years ago

Also, I wonder: when the target pose / representation is learned early in the training process, does it still improve / change later?

wpeebles commented 2 years ago

Hi @petercmh01. Yeah, this is a failure mode that can pop up during training. It's a result of the second STN generating high frequency flow fields to fool the perceptual loss. Here are a few tips that might be helpful for mitigating it.

(1) We have two regularizers that try to keep the highly expressive flow STN under control: --tv_weight and --flow_identity_weight. In my experience, the single best trick to combat the issue you're seeing is increasing the weight on the total variation regularizer (--tv_weight). When this value is set high, it really punishes the model for generating these high frequencies in the flow field. You might want to try increasing this value very aggressively, to something like 5000 or 10000, until the problem is mostly resolved. The downside of increasing it is that it will reduce the expressiveness of the flow STN (more work will be done by the similarity STN, which can only do coarse alignment), so it's a balancing act. (See the sketch after this list for what this penalty is doing.)

(2) As a more extreme version of (1), you can also try training a similarity-only STN and see how performance is there. You won't get a super accurate alignment, but it might be a good sanity check to make sure you get an accurate coarse alignment before adding on the flow STN.

(3) Training with reflection padding (--padding_mode reflection) can be less prone to issues than border padding.

(4) A good sanity check is looking at the visuals created for the learned target mode (template) during training. You want to make sure that whatever StyleGAN latent mode is being learned is "reachable" from most images in your dataset. For example, if you trained a unimodal GANgealing model on LSUN Horses, you might get a side-profile of a horse as the discovered latent template. It will be virtually impossible to align, e.g., images of the front of horses to this mode, and you'll get bad high frequency flows for those images instead. If the learned mode seems "bad" or unreasonable, then you can try to increase --ndirs a small amount to make the search space more expressive. Increasing --num_heads can sometimes also help (e.g., LSUN Cars with num_heads=1 works a lot worse than num_heads=4 in my experience) but it's not immune to this problem.

(5) One thing to keep in mind is that there will always be some images for which this will happen even if the model is generally quite good. For example, our LSUN Cats model generates high frequency flow fields when it gets an image without a cat's upper body visible or with significant out-of-plane rotation of the cat's face (we show some examples of this in the last figure of the paper's supplementary materials). The hope, of course, is that these failures by and large should only occur for "unalignable" images.
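To make tip (1) concrete, here's a minimal sketch of a total-variation-style smoothness penalty on a flow field. The tensor layout and the exact norm are assumptions for illustration, not necessarily how the repo implements --tv_weight:

import torch

def flow_total_variation(flow):
    # flow: (N, H, W, 2) flow field / sampling grid predicted by the flow STN
    # (this layout is an assumption of the sketch).
    # Penalizing differences between neighboring flow vectors discourages the
    # high-frequency warps that "synthesize" features to fool the perceptual loss.
    dh = (flow[:, 1:, :, :] - flow[:, :-1, :, :]).abs().mean()
    dw = (flow[:, :, 1:, :] - flow[:, :, :-1, :]).abs().mean()
    return dh + dw

# Conceptually: total_loss = perceptual_loss + tv_weight * flow_total_variation(flow),
# where tv_weight is the value you pass via --tv_weight.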

The above tips are all for training. There are a couple of things you can do at test-time that somewhat mitigate this problem, but in my experience most of the mileage comes from adjusting training hyperparameters.

We use a technique at test time called "recursive alignment" which is important to get the best results on hard datasets (LSUN). We talk about this a bit in our supplementary, but the idea is really simple: you just repeatedly feed the aligned image output by the similarity STN back into itself N times before forwarding to the flow STN (we set N=3 for LSUN). This can be done in the code simply by doing stn(x, iters=3) (or you can add --iters 3 on the command line if you're using your trained model with one of the scripts in the applications folder). All of the visuals generated during training use N=1. You commonly see the high frequency flows appear in the flow STN when the similarity STN has failed to coarsely align the image, and increasing N at test time can often help resolve this problem "for free." That said, if you're seeing high frequency flows for every single input GAN image during training, then this probably won't resolve the issue.
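For intuition, here's what that recursion amounts to when unrolled; the module names are placeholders, and in practice you'd just call stn(x, iters=3) as described above:

import torch.nn as nn

def recursive_align(similarity_stn: nn.Module, flow_stn: nn.Module, x, iters=3):
    # Conceptual sketch only: keep feeding the similarity STN's output back
    # into itself so the coarse pose gets locked in, then hand the result to
    # the flow STN for fine, per-pixel warping.
    aligned = x
    for _ in range(iters):
        aligned = similarity_stn(aligned)
    return flow_stn(aligned)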

Our flipping algorithm takes advantage of the tendency of the STN to "fail loudly" (i.e., it produces high frequency flows for both failure case images as well as unalignable images); basically, you can probe the flow STN with x and flip(x), and see which of the two yields the smoother flows. In general, a smooth flow --> no/few high frequencies --> STN is doing its job correctly without "cheating." Flipping is a test-time only operation that can significantly help address the issue without re-training (training visuals don't use flipping). But again, if all of your images are exhibiting high frequency flows, this won't solve the issue by itself. Also, it only helps models with asymmetric learned templates. So for example, flipping helps with LSUN Bicycles and CUB, but it won't really do anything for LSUN Cats and CelebA.
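A rough sketch of the probing idea (the STN returning a flow field directly and the smoothness measure are assumptions of this sketch, not the repo's exact interface):

import torch

@torch.no_grad()
def should_flip(flow_stn, x, flow_smoothness):
    # Probe the flow STN with the image and its horizontal mirror, and keep
    # whichever input yields the smoother flow field. flow_smoothness could be
    # something like the total-variation sketch above.
    flow = flow_stn(x)
    flow_mirrored = flow_stn(torch.flip(x, dims=[-1]))
    return flow_smoothness(flow_mirrored) < flow_smoothness(flow)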

Regarding training length, you almost certainly don't need to train beyond a million iterations unless you're trying to squeeze out every last drop of performance. I have seen high frequency flows get a bit better for "hard" images at around the ~700K iteration mark, but you can usually get a reasonable sense of how training is going by ~200K iterations, and the remaining hard images seem to work themselves out by ~700K.

For your question about the target latent, in my experience it converges very rapidly (usually after like 150K to 200K iterations) and doesn't change much afterwards. We continue training it jointly with the STN for the entirety of training out of simplicity, but I'm pretty sure you could get away with freezing it after some point if you wanted to. You'd probably get a nice training speed-up from this since backpropagating through the generator is quite costly.

wpeebles commented 2 years ago

The --inject hyperparameter basically trades off alignment precision for ease of optimization. If you have a small value for --inject, then optimization will be easy but alignment won't be precise. For 256x256 StyleGAN models, --inject 5 seems to work pretty well in general in my experiments. If you're using a lower resolution StyleGAN, you might want to decrease it (and similarly, if you're using a higher resolution model you probably want to increase it). My intuition is that decreasing --inject should make it less likely to get high frequency flows.

I never tried changing --flow_size, and I think the --sample_from_full_res argument likely won't affect things that much in general.

petercmh01 commented 2 years ago

Thanks a lot, I will give it a try! How do I use a similarity-only STN? Is it done by changing the --transform argument to just 'similarity' with no flow?

wpeebles commented 2 years ago

Yeah, that should do the trick. Also make sure to set --tv_weight 0 (and you might want to set --flow_identity_weight 0 too).

petercmh01 commented 2 years ago

Thanks a lot! I've now got very robust coarse alignment on my dataset using a tv_weight of 40000. For the dense tracking mask, should I use the average transformed image or the truncated image?

wpeebles commented 2 years ago

Awesome! I made all the masks using the average transformed image. You'll probably get similar results using the average truncated image though.

petercmh01 commented 2 years ago

Hey @wpeebles, I'm now trying to slowly work toward more accurate alignment based on the previous parameters that gave good coarse alignment. I think my latent target mode is "reachable" but it's missing a few features (for example, if I try to align cats, the target mode only has the face and is missing the cat's ears, or if I try to align birds, it's missing the bird's head). Would slightly adjusting the --ndirs or --inject parameters be helpful? I've seen the inject parameter affect it a bit in my previous experiments.

Thank you

wpeebles commented 2 years ago

It's possible changing --ndirs and --inject could give you a different template. But if you have a specific type of template in mind, it might be better to manually specify the latent instead of learning it. For example, you could generate a bunch of random samples from your StyleGAN and record the W-Space vectors that generated each sample. Then, you can initialize the learned latent as the W vector that generated the image closest to what you're looking for. Alternatively, you could find a real image that represents the template you're looking for and run a GAN inversion algorithm on it to find a good W vector. Given the W vector, I think you can manually assign it by adding the following snippet right after line 235 of train.py:

# my_w_vector should have shape (1, 512) (assuming your StyleGAN's W-Space is 512d)
encoded_w = pca.encode(my_w_vector)
ll.assign_coefficients(encoded_w)

If you try either of these methods, you'll want to add --ndirs 512 (or whatever the dimensionality of your StyleGAN's W-Space is) and --freeze_ll when training.
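In case it's useful, here's a rough sketch of the "pick the W vector of the closest random sample" route. The g.mapping / g.synthesis interface and the plain L2 distance are placeholders you'd adapt to your own StyleGAN code (a perceptual metric like LPIPS would probably work better than L2):

import torch

@torch.no_grad()
def pick_template_w(g, reference_image, num_samples=1000, w_dim=512, device='cuda'):
    # Sample random latents, render them, and keep the W vector whose image is
    # closest to a hand-picked reference image.
    best_w, best_dist = None, float('inf')
    for _ in range(num_samples):
        z = torch.randn(1, w_dim, device=device)
        w = g.mapping(z)            # (1, w_dim) W vector (placeholder API)
        img = g.synthesis(w)        # rendered sample (placeholder API)
        dist = (img - reference_image).pow(2).mean().item()
        if dist < best_dist:
            best_w, best_dist = w, dist
    return best_w  # use this as my_w_vector in the snippet above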

petercmh01 commented 2 years ago

Just to follow up: I tried this GAN inversion method https://github.com/abdulium/gan-inversion-stylegan2/blob/main/Task2_GANInversion_YOURNAMEHERE.ipynb and I've got my desired target mode. Thanks for your help again!