dome272 / MaskGIT-pytorch

PyTorch implementation of MaskGIT: Masked Generative Image Transformer (https://arxiv.org/pdf/2202.04200.pdf)
MIT License
398 stars, 34 forks

Issue about generated images #6

Open wzmsltw opened 2 years ago

wzmsltw commented 2 years ago

Hi

I have also been trying to reproduce MaskGIT recently. After training for 150 epochs on ImageNet, our model only reaches 8.4% accuracy on token classification. During sampling, we find that our model generates monochrome (nearly white) images. Have you run into a similar problem?

dome272 commented 2 years ago

Unfortunately, I have not been able to train a model yet, because my compute resources are allocated elsewhere at the moment. Did you use your own codebase? If so, did you publish it somewhere?

pabloppp commented 2 years ago

@wzmsltw I'm also building a custom model inspired by this paper on the CelebA dataset, and I see something similar. I think in my case it's still early in training. I get accuracies of around 40%, which makes sense: during training the masking follows a cosine schedule (as the paper says), so the expected fraction of tokens left unmasked is 1 - 2/π ≈ 0.363 (the area under cos(πr/2) on [0, 1] is 2/π ≈ 0.637). 40% therefore means the model does only slightly better than getting just the unmasked tokens right. That doesn't worry me too much: we don't want a perfect 100%, which would mean the model overfits; we want to be able to sample several plausible reconstructions from a masked image.
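A quick check of that figure, assuming the paper's cosine schedule where the fraction of masked tokens in a training step is cos(πr/2) with r drawn uniformly from [0, 1). On average 1 - 2/π ≈ 36.3% of the tokens stay visible, so if accuracy is counted over all positions, a model that only got the visible tokens right would already score about that much.

```python
import math
import random


def cosine_mask_ratio(r: float) -> float:
    """Fraction of tokens masked in a training step, for r in [0, 1)."""
    return math.cos(math.pi * r / 2)


def expected_unmasked_fraction(num_samples: int = 200_000, seed: int = 0) -> float:
    """Monte Carlo estimate of E[1 - cos(pi * r / 2)] for r ~ U[0, 1)."""
    rng = random.Random(seed)
    total = sum(1.0 - cosine_mask_ratio(rng.random()) for _ in range(num_samples))
    return total / num_samples


closed_form = 1.0 - 2.0 / math.pi  # integral of 1 - cos(pi * r / 2) over [0, 1]
estimate = expected_unmasked_fraction()
print(f"closed form: {closed_form:.4f}, Monte Carlo: {estimate:.4f}")
```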

What I find strange is what happens during sampling. At the first step, the model generates something promising, though very noisy: [image]

As the sampling process continues, the model seems to start generating something that looks like a face :D [image]

BUT at some point during sampling, the face starts to fade away... [image]

And in most cases we end up with an almost empty image :/ It looks like, as I decrease the temperature during sampling, the samples collapse to an empty background, or something along those lines... [image]

Maybe something similar is happening to you... I don't know whether this will be fixed with more training or whether it's a problem in the sampling procedure... wdyt?

dome272 commented 2 years ago

Did any of you find a solution? @pabloppp @wzmsltw

pabloppp commented 2 years ago

Unfortunately not. I have the feeling that the issue is with the sampling method, but it might also be related to how the tokens are masked during training :/

dome272 commented 2 years ago

May I ask whether you used my implementation or your own? If you used mine, it might be something implementation-specific. However, if we both get the same white images with different implementations, then there might indeed be something wrong with the overall sampling strategy.

pabloppp commented 2 years ago

I used my own implementation

LeeDoYup commented 2 years ago

I still cannot reproduce the results in the paper after 300 epochs of training on ImageNet. However, I find that temperature annealing is key to the performance of MaskGIT and to the diversity of the generated samples. Without temperature annealing, I got an FID of 32.87. With temperature annealing, which linearly decreases the temperature of the logits from 3.0 to 1.0, I got 20.26.
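A minimal sketch of linear temperature annealing on the token logits, assuming the temperature is updated once per decoding step and divides the logits before the softmax. The 8 steps and the 1024-entry codebook are illustrative values, not settings from this thread.

```python
import torch


def linear_temperature(step: int, num_steps: int, start: float = 3.0, end: float = 1.0) -> float:
    """Linearly interpolate the logit temperature from `start` (step 0) to `end` (last step)."""
    if num_steps <= 1:
        return end
    frac = step / (num_steps - 1)
    return start + (end - start) * frac


def sample_tokens(logits: torch.Tensor, temperature: float) -> torch.Tensor:
    """Sample one codebook index per position from temperature-scaled logits.

    logits: (num_positions, vocab_size)
    """
    probs = torch.softmax(logits / temperature, dim=-1)
    return torch.multinomial(probs, num_samples=1).squeeze(-1)


temps = [linear_temperature(t, 8) for t in range(8)]
logits = torch.randn(16, 1024)  # stand-in for the transformer's output
tokens = sample_tokens(logits, temps[0])
print([round(t, 3) for t in temps], tokens.shape)
```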

When I train MaskGIT on FFHQ, very simple (but high-quality) images are generated. However, due to the simplicity of the generated images, the recall of the trained model is very low and the FID is over 100.

I think many tricks are required to train MaskGIT, but the details are not described in the paper. In particular, temperature annealing is a very important trick for decreasing FID, but the authors did not describe the details... How can I trust the scores in the original paper...

dome272 commented 2 years ago

Oh, that is really interesting about the temperature annealing. I completely overlooked it. Without annealing, samples look like this: [image] And with annealing from 3 to 2, they look like this: [image]

(Trained for just 28 epochs on a Flickr landscape dataset.)

pabloppp commented 2 years ago

The samples I shared used temperature annealing as well, but I still don't get very good results.

dome272 commented 2 years ago

@pabloppp @LeeDoYup If you are interested, maybe we can make a group on Discord and report new findings. I would also be interested in your Transformer implementations; I suspect mine is quite simplistic. If you are up for it, you can add me on Discord: dome#8231

Also, the authors reference BEiT, which uses a slightly different training approach. Even though the authors clearly described their own training procedure, maybe BEiT's approach could lead to improvements. Have you tried anything like this?

pabloppp commented 2 years ago

I have not tried anything like BEiT; in fact, my architecture is pretty different from the one proposed in the paper. What I tried to follow as closely as possible were the losses, training schedules & sampling schedules.

LeeDoYup commented 2 years ago

I do not have a Discord account, but I will try to create one soon. @pabloppp What was your temperature annealing setting? In my case, on FFHQ, start_temperature is 3.0 or 5.0 and end_temperature is 1.0.

pabloppp commented 2 years ago

I have it as a parameter in my sampling function. I also tried different annealing schedules (linear, cosine, etc.) with very similar results, so I don't remember exactly which parameters I used for the samples I shared above. For example, for these samples I used a cosine-decaying temperature (following the same cosine curve as the number of sampled tokens) going from 1.5 to 0.3 over 16 steps.
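A sketch of that cosine-decaying temperature, assuming it follows cos(πr/2) with r = step / (num_steps - 1), so that the first step uses exactly 1.5 and the last exactly 0.3; the exact step indexing used for the samples isn't given in the thread.

```python
import math


def cosine_temperature(step: int, num_steps: int = 16, start: float = 1.5, end: float = 0.3) -> float:
    """Temperature that decays from `start` to `end` along cos(pi/2 * r)."""
    r = step / max(num_steps - 1, 1)
    return end + (start - end) * math.cos(math.pi / 2 * r)


schedule = [cosine_temperature(t) for t in range(16)]
print(f"first={schedule[0]:.2f}, middle={schedule[8]:.2f}, last={schedule[-1]:.2f}")
```

Compared with a linear decay, the cosine curve stays hot for longer at the start and cools quickly near the end.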

Here I'm showing the first output, the half-schedule output and the final output: [image] [image] [image]

pabloppp commented 2 years ago

It shouldn't be relevant, but my model is conditional, so I add an identity embedding to the input. My goal was to be able to control the generation to some degree, so that I can ask the model to generate a specific face instead of a random one, and also to help the model, since conditional generation is usually much easier for generative models.

The model seems to use that information to some extent, e.g. generating male/female faces depending on the reference image, but it also does a lot of random generation.

(The bottom image is a conditional sample generated from scratch.) [image]
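The thread doesn't say exactly how the identity embedding enters the network. One simple option is sketched below, assuming a fixed-size identity vector per image (`cond_dim`) that is projected and added to every token embedding; the class name and all sizes are hypothetical.

```python
import torch
from torch import nn


class ConditionedTokenEmbedding(nn.Module):
    """Token + position embeddings plus a per-image conditioning vector broadcast to every position."""

    def __init__(self, vocab_size: int, dim: int, cond_dim: int, num_positions: int):
        super().__init__()
        self.tok = nn.Embedding(vocab_size + 1, dim)  # +1 for the [MASK] token id
        self.pos = nn.Parameter(torch.zeros(1, num_positions, dim))
        self.cond_proj = nn.Linear(cond_dim, dim)

    def forward(self, tokens: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
        # tokens: (B, N) token ids; cond: (B, cond_dim) identity embedding
        x = self.tok(tokens) + self.pos
        return x + self.cond_proj(cond).unsqueeze(1)  # same condition added at every position


emb = ConditionedTokenEmbedding(vocab_size=1024, dim=64, cond_dim=512, num_positions=256)
out = emb(torch.randint(0, 1025, (2, 256)), torch.randn(2, 512))
print(out.shape)  # torch.Size([2, 256, 64])
```

Adding the condition to every position is the cheapest choice; prepending it as an extra token or using cross-attention are common alternatives.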

pabloppp commented 2 years ago

Hello! Small update: I just tried adding typical filtering to the sampling code. The results are still far from perfect, but I went from a very high percentage of plain-colored images to a considerable percentage of face-like results :D

Here's an example without typical filtering: [image]

And here are a couple of examples with typical filtering (with a mass of 0.2): [image] [image]

For the filtering I just adapted the code from the official repo: https://github.com/cimeister/typical-sampling/blob/3e676cfd88fa2e6a24f2bdc6f9f07fddb87827c2/src/transformers/generation_logits_process.py#L242-L272

Although the 'typical filter' was designed around how language works (letting the model pick from a large number of options when the expected information is high, and shrinking the pool when it is low), it seems to benefit image generation as well. It might even be related to the non-sequential nature of the sampling: at the beginning, when sampling the first tokens, the expected information is high, so the model can pick from a wider variety of options, while as the image takes shape the options shrink because we already have a general sketch of the image... Or it might be something completely different XD

Anyway, hopefully this is useful for someone, and maybe we could even reach @cimeister to ask if they thought of this for image generation 🤔

LeeDoYup commented 2 years ago

@pabloppp Oh, did you use typical sampling within the multinomial sampling that predicts the code at each position? I think it would help increase diversity, since typical sampling is known to resolve the degeneration problem in NLG.

pabloppp commented 2 years ago

Yes. Basically, before calling the multinomial sampling, I do what the TypicalLogitsWarper does: set the logits of the filtered-out tokens to -inf, so that the multinomial only samples from the filtered pool. I also keep the temperature decay and the schedule for the number of tokens sampled at each step unchanged.
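A standalone sketch of that order of operations (temperature, typical filter, multinomial), re-implemented here along the lines of the typical-sampling code linked above. `typical_filter` and `sample_with_typical_filter` are hypothetical helper names; the default mass of 0.2 is the value mentioned earlier in the thread.

```python
import torch
import torch.nn.functional as F


def typical_filter(logits: torch.Tensor, mass: float = 0.2, min_tokens_to_keep: int = 1) -> torch.Tensor:
    """Set logits outside the locally typical set to -inf.

    logits: (num_positions, vocab_size), already divided by the temperature.
    """
    log_p = F.log_softmax(logits, dim=-1)
    entropy = -(log_p * log_p.exp()).nansum(dim=-1, keepdim=True)
    # How far each token's surprisal (-log p) is from the expected surprisal (the entropy).
    shifted = (-log_p - entropy).abs()
    sorted_shifted, sorted_idx = torch.sort(shifted, dim=-1, descending=False)
    # Cumulative probability mass, visiting the most typical tokens first.
    cum_probs = logits.gather(-1, sorted_idx).softmax(dim=-1).cumsum(dim=-1)
    last_ind = (cum_probs < mass).sum(dim=-1).clamp(max=logits.size(-1) - 1)
    sorted_remove = sorted_shifted > sorted_shifted.gather(-1, last_ind.unsqueeze(-1))
    sorted_remove[..., :min_tokens_to_keep] = False
    # Map the removal mask from sorted order back to vocabulary order.
    remove = sorted_remove.scatter(-1, sorted_idx, sorted_remove)
    return logits.masked_fill(remove, float("-inf"))


def sample_with_typical_filter(logits: torch.Tensor, temperature: float = 1.0, mass: float = 0.2):
    """Temperature -> typical filter -> multinomial; returns token ids and their probabilities."""
    probs = typical_filter(logits / temperature, mass=mass).softmax(dim=-1)
    sampled = torch.multinomial(probs, num_samples=1)  # (num_positions, 1)
    scores = probs.gather(-1, sampled)                 # probability of each sampled token
    return sampled.squeeze(-1), scores.squeeze(-1)
```

The probability of each sampled token is returned as well, because a confidence-based schedule needs a score for every sampled token when deciding which ones to keep.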

cimeister commented 2 years ago

Wow very cool! Thanks for sharing, @pabloppp. We hadn't tried typical sampling yet for image generation but it seems like a promising direction!

LeeDoYup commented 2 years ago

The paper describes it as follows:

In practice, the masking tokens are randomly sampled with temperature annealing to encourage more diversity, and we will discuss its effect in 4.4.

Here, I am confused about whether they use temperature annealing (TA) to randomly select the masking positions, or use TA in the multinomial sampling at each position. When I randomly select the masking positions, I get a large performance improvement, although the performance reported in the paper is still not reproduced.

pabloppp commented 2 years ago

@LeeDoYup I'm pretty sure the temperature is applied to the logits before the softmax, thus affecting the multinomial sampling 🤔 But the things you mention are correlated: you change the temperature, so the probability of sampling some tokens changes; then you apply the multinomial and keep only a certain number of tokens, chosen by their scores, following the cosine schedule.
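One decoding step along those lines, sketched for a single image without batching: temperature-scaled softmax, one multinomial sample per position, then the lowest-scoring samples are re-masked so that the number of masked positions follows cos(π/2 · (step + 1) / num_steps). The function name and the clamp that forces at least one new token per step are assumptions, not something stated in the paper.

```python
import math
import torch


def decode_step(logits, tokens, is_masked, step, num_steps, temperature=1.0):
    """One confidence-based decoding step for a single image.

    logits:    (N, V) transformer output for the N token positions
    tokens:    (N,) current token ids (values at masked positions are ignored)
    is_masked: (N,) bool, True where the position is still masked
    Returns the updated tokens and the new mask.
    """
    n = tokens.numel()
    probs = torch.softmax(logits / temperature, dim=-1)
    sampled = torch.multinomial(probs, num_samples=1).squeeze(-1)
    scores = probs.gather(-1, sampled.unsqueeze(-1)).squeeze(-1)
    # Keep already-known tokens and give them infinite confidence so they are never re-masked.
    sampled = torch.where(is_masked, sampled, tokens)
    scores = torch.where(is_masked, scores, torch.full_like(scores, float("inf")))
    # Cosine schedule: number of positions that stay masked after this step.
    ratio = (step + 1) / num_steps
    num_still_masked = math.floor(n * math.cos(math.pi / 2 * ratio))
    # Assumption: always reveal at least one token per step.
    num_still_masked = max(0, min(num_still_masked, int(is_masked.sum()) - 1))
    new_masked = torch.zeros_like(is_masked)
    if num_still_masked > 0:
        lowest = torch.topk(scores, num_still_masked, largest=False).indices
        new_masked[lowest] = True
    return sampled, new_masked


torch.manual_seed(0)
N, V, T = 64, 32, 8
tokens = torch.zeros(N, dtype=torch.long)
is_masked = torch.ones(N, dtype=torch.bool)
for t in range(T):
    logits = torch.randn(N, V)  # stand-in for the transformer's output
    tokens, is_masked = decode_step(logits, tokens, is_masked, t, T)
    print(t, int(is_masked.sum()))
```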

LeeDoYup commented 2 years ago

@pabloppp When I reason about mask selection based on the algorithm, I also agree that TA is applied to the logits before the softmax. However, reading only the sentence above, it says that they add randomness to the "mask selection", not to the token sampling. So I am very confused.

When I use the random masking strategy, the performance on ImageNet improves a lot. For example, with n_decoding_step=8, I got precision=0.63 and recall=0.36 with linear temperature annealing (5.0 => 1.0). However, I got precision=0.64 and recall=0.58 when I randomly select the tokens to unmask and fix the temperature at 0.8 for all decoding steps. I think hyper-parameter tuning for MaskGIT is very exhausting...

dome272 commented 2 years ago

@LeeDoYup can you show a pseudo-code example for both of the cases you describe above? I'm a bit confused about which is which. What exactly do you mean by random masking strategy? What I have implemented is choosing the indices with the highest probability. Are you just selecting these tokens randomly?

LeeDoYup commented 2 years ago

@dome272 I use the random masking strategy as follows:

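import torch

if strategy == 'random':
    # masked_idxs_sorted: positions that are still masked, sorted by confidence
    candidates = masked_idxs_sorted
    # Ignore the confidence order and draw n_newly_unmasked positions uniformly at random
    subset = torch.randperm(candidates.shape[0])[:n_newly_unmasked]
    newly_unmasked_idxs = masked_idxs_sorted[subset]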

masked_idxs_sorted: the indices that are still masked at the current decoding step, sorted by their confidence.
n_newly_unmasked: the number of additional tokens to unmask at the current decoding step.

That is, I randomly select the positions of the tokens to unmask.
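For comparison, both selection rules discussed here can be sketched side by side, assuming `confidence` holds the probability of the token sampled at each still-masked position. `choose_positions_to_unmask` is a hypothetical helper, not code from either implementation.

```python
import torch


def choose_positions_to_unmask(masked_idxs: torch.Tensor, confidence: torch.Tensor,
                               n_newly_unmasked: int, strategy: str = "confidence") -> torch.Tensor:
    """Pick which currently masked positions get revealed at this decoding step.

    masked_idxs: (M,) long tensor of positions that are still masked
    confidence:  (M,) score of the token sampled at each of those positions
    """
    if strategy == "confidence":
        # Reveal the most confident predictions first.
        order = torch.argsort(confidence, descending=True)
        return masked_idxs[order[:n_newly_unmasked]]
    if strategy == "random":
        # Ignore confidence and reveal a uniform random subset.
        subset = torch.randperm(masked_idxs.numel())[:n_newly_unmasked]
        return masked_idxs[subset]
    raise ValueError(f"unknown strategy: {strategy}")


masked_idxs = torch.tensor([3, 7, 9, 12])
confidence = torch.tensor([0.1, 0.9, 0.5, 0.3])
print(choose_positions_to_unmask(masked_idxs, confidence, 2, "confidence"))  # tensor([7, 9])
print(choose_positions_to_unmask(masked_idxs, confidence, 2, "random"))
```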

pabloppp commented 2 years ago

In a way, it makes sense to choose them randomly: during training you mask tokens randomly, so the model is used to reconstructing random missing tokens, not necessarily starting with the highest-scoring ones and ending with the lowest-scoring ones. That said, in the paper they very explicitly say that they sample the highest-scoring tokens.

[Screenshot 2022-03-29 at 10.43.29]

At each iteration, the model predicts all tokens simultaneously but only keeps the most confident ones.

I'm pretty sure that when they say "In practice, the masking tokens are randomly sampled with temperature annealing to encourage more diversity", they mean that they don't just pick the token with the highest score, but sample from a multinomial distribution adjusted by the temperature. This multinomial sampling introduces randomness (the higher the temperature, the more randomness).
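A small illustration of that last point: the entropy of softmax(logits / T) grows with the temperature T, so a higher temperature spreads the multinomial over more tokens. The logits below are arbitrary example values.

```python
import torch


def softmax_entropy(logits: torch.Tensor, temperature: float) -> float:
    """Entropy (in nats) of softmax(logits / temperature)."""
    log_p = torch.log_softmax(logits / temperature, dim=-1)
    return float(-(log_p.exp() * log_p).sum())


logits = torch.tensor([4.0, 2.0, 1.0, 0.5, 0.0])
for t in (0.3, 1.0, 3.0):
    print(f"temperature={t}: entropy={softmax_entropy(logits, t):.3f} nats")
```

As T approaches 0 the distribution approaches argmax (entropy near 0); as T grows it approaches uniform (entropy near log 5 ≈ 1.609 nats here).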

Anyway, if you've found that sampling randomly instead of taking the highest scores helps, it's worth a try 🙇

BTW, what exactly are you doing to get recall & precision from random sampling? 🤔 What do you compare your outputs to in order to compute the metrics?

LeeDoYup commented 2 years ago

@pabloppp I totally agree with you. However, the most problematic fact is that the TA hyper-parameters are not described in the paper, which makes the results hard to reproduce.

To evaluate recall & precision on ImageNet, I generated 50K samples and used the protocol in this repository.

dome272 commented 2 years ago

@pabloppp Could you show a code sample of where you apply typical sampling?

  import torch
  import transformers  # top_k_top_p_filtering ships with older transformers releases

  logits = logits / temperature  # temperature-scale the logits
  filtered_logits = transformers.top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
  probs = torch.nn.functional.softmax(filtered_logits, dim=-1)
  sample = torch.multinomial(probs, 1)  # one sample per row

That's the standard top-k/top-p sampling. At which point do you call the TypicalLogitsWarper, and with which input?

pabloppp commented 2 years ago

I don't think you're doing it right. You're supposed to first sample a token for every masked position, then pick the top-k with the highest scores; otherwise, you don't really know the score of the token you sampled.

I don't use the TypicalLogitsWarper class directly, but I use the same implementation. This is the core of my sampling implementation:

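import torch

# Inside the model's sampling method (self is the transformer)
logits, _ = self(x, c, mask)
probs = logits.div(temp)  # temperature-scaled logits; the softmax is applied further down
probs_flat = probs.permute(0, 2, 3, 1).reshape(-1, probs.size(1))  # (B*H*W, vocab)
if typical_filtering:
    # Log-probabilities, probabilities and entropy of each position's distribution
    probs_flat_norm = torch.nn.functional.log_softmax(probs_flat, dim=-1)
    probs_flat_norm_p = torch.exp(probs_flat_norm)
    entropy = -(probs_flat_norm * probs_flat_norm_p).nansum(-1, keepdim=True)

    # Sort tokens by how far their surprisal is from the entropy (most typical first)
    probs_flat_shifted = torch.abs((-probs_flat_norm) - entropy)
    probs_flat_sorted, probs_flat_indices = torch.sort(probs_flat_shifted, descending=False)
    probs_flat_cumsum = probs_flat.gather(-1, probs_flat_indices).softmax(dim=-1).cumsum(dim=-1)

    # Keep the smallest typical set whose probability mass reaches typical_mass
    last_ind = (probs_flat_cumsum < typical_mass).sum(dim=-1)
    last_ind = last_ind.clamp(max=probs_flat_sorted.size(-1) - 1)
    sorted_indices_to_remove = probs_flat_sorted > probs_flat_sorted.gather(1, last_ind.view(-1, 1))
    if typical_min_tokens > 1:
        sorted_indices_to_remove[..., :typical_min_tokens] = 0
    # Map the removal mask back to vocabulary order and drop those tokens
    indices_to_remove = sorted_indices_to_remove.scatter(1, probs_flat_indices, sorted_indices_to_remove)
    probs_flat = probs_flat.masked_fill(indices_to_remove, -float("Inf"))
probs_flat = probs_flat.softmax(dim=-1)
# Sample one token per position and keep its probability as the confidence score
sample_indices = torch.multinomial(probs_flat, num_samples=1)
sample_scores = torch.gather(probs_flat, 1, sample_indices)
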
dome272 commented 2 years ago

Hey guys. Today I reached out to the authors to ask whether they could help us with our problem, and the first author replied that they are planning to release the code next week (or so). And since typical sampling is probably not used in their paper, using it might give an even bigger boost in performance.

pabloppp commented 2 years ago

I guess the official repo is out: https://github.com/google-research/maskgit (although it seems to be in JAX). Let's find out (and please share here 🙏) what was missing in our implementations 🙇

dome272 commented 2 years ago

Yeah, please report back if anyone finds the missing piece...

pabloppp commented 2 years ago

Small finding: it seems @LeeDoYup was on the right track: https://github.com/google-research/maskgit/blob/cf615d448642942ddebaa7af1d1ed06a05720a91/maskgit/libml/parallel_decode.py#L127 They sample tokens completely randomly and then keep the ones with the highest probability, but they add noise to the probabilities, and that noise is linearly decayed.

So step 0 basically samples tokens at random, the next step is also random but less so, etc... :/

I guess we should try this, but it seems like a lot of randomness XD
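A PyTorch sketch of that re-masking rule, written from my reading of the linked `parallel_decode.py` rather than copied from it: I'm assuming the noise is Gumbel noise added to the log-probability of each sampled token, scaled by an annealed temperature, and that the `mask_len` positions with the lowest noisy confidence stay masked. Already-decoded positions need a confidence of `+inf` so they are never re-masked.

```python
import torch


def mask_by_random_topk(mask_len: torch.Tensor, probs: torch.Tensor, temperature: float) -> torch.Tensor:
    """Return True for positions that should stay masked.

    mask_len:    (B, 1) number of positions to keep masked per image
    probs:       (B, N) probability of the token sampled at each position;
                 use float("inf") for positions that are already decoded
    temperature: scale of the Gumbel noise, annealed towards 0 over the steps
    """
    uniform = torch.rand_like(probs).clamp_(1e-20, 1.0)
    gumbel = -torch.log(-torch.log(uniform))  # standard Gumbel(0, 1) noise
    confidence = torch.log(probs) + temperature * gumbel
    sorted_confidence, _ = torch.sort(confidence, dim=-1)
    # The mask_len-th lowest confidence is the cut-off; everything below it stays masked.
    cut_off = sorted_confidence.gather(-1, mask_len.long())
    return confidence < cut_off


# With temperature 0 the rule is purely confidence-based:
probs = torch.tensor([[0.9, 0.1, 0.5, 0.3, float("inf")]])
keep_masked = mask_by_random_topk(torch.tensor([[2]]), probs, temperature=0.0)
print(keep_masked)  # positions 1 and 3 (the two least confident) stay masked
```

A high temperature early on lets low-probability tokens survive and confident ones get re-masked, which is where the extra randomness comes from; as the temperature decays, the rule converges to plain confidence-based selection.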

LeeDoYup commented 2 years ago

Yes, when I look at https://github.com/google-research/maskgit/blob/cf615d448642942ddebaa7af1d1ed06a05720a91/maskgit/libml/parallel_decode.py#L49-L56, I conclude that the paper was wrong. Mixing in randomness is the key to the algorithm.....

dome272 commented 2 years ago

If any of you has translated the sampling code to PyTorch, could you post it here?

pabloppp commented 2 years ago

Hi, just wanted to share a small update and see whether anyone here is getting interesting new results. My model conditioned on facial identity keeps getting better the longer I train it. It's still veeeeery far from what it should be able to produce based on the paper, but the training feels very very very slow and requires a lot of iterations (I'm running my experiments on Google Colab soooo... not ideal XD). I tried to replicate the official sampling schedule as closely as possible, but ended up making some tweaks because I feel they work slightly better in my case:

[image]

If I sample multiple images using the same identity embedding, I get something like this: the identity is not very well preserved, but there are clearly common traits that the model is able to reproduce.

[Screenshot 2022-04-17 at 22.03.53]

I started doing some tests on a V100 with a much larger, more diverse unlabeled dataset of nature images, in a sort of autoencoder-ish way (image embedding in -> image out), but I only managed to get noise. Maybe it's just a matter of training longer on more powerful machines? Maybe there's some way to accelerate training and avoid having to do the equivalent of 300 epochs on ImageNet, as they claim in the paper? Idk...

Has anyone else managed to get something?

GuoxingY commented 2 years ago

Thanks for sharing. I have also tried to train a transformer-based model on the COCO dataset for image generation, but got worse results. Could you share some nature images generated by your model trained on the nature-image dataset? I wonder whether the number of training iterations is the key to training a good model, or whether I am missing some details in my implementation.

pabloppp commented 2 years ago

Sure. Depending on how I do the sampling (more/less temperature), I get results like these two (this is about 900k iterations with batch size 6, on a dataset of 270k images):

[Screenshot 2022-04-18 at 23.16.33]

[image]

I feel like the model is starting to learn which tokens are more common in this sort of image, but it's still very random and produces "not completely random" noise.

With 300 epochs of ImageNet-1k (about 1.28M training images), the model sees roughly 384M images, so my training (900k iterations × batch size 6 ≈ 5.4M images) is only about 1.4% of what they did in the paper, and it still took me several days on a V100 😓
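The sample-count arithmetic, assuming the paper's ImageNet results use the ImageNet-1k training split (1,281,167 images) and counting one image per batch element:

```python
IMAGENET_1K_TRAIN_IMAGES = 1_281_167  # ImageNet-1k (ILSVRC 2012) training split

paper_images_seen = 300 * IMAGENET_1K_TRAIN_IMAGES  # 300 epochs
my_images_seen = 900_000 * 6                        # iterations x batch size

print(f"paper: {paper_images_seen:,} images, mine: {my_images_seen:,} images")
# paper: 384,350,100 images, mine: 5,400,000 images
print(f"fraction: {my_images_seen / paper_images_seen:.2%}")
# fraction: 1.40%
```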

LeeDoYup commented 2 years ago

I finished training a 200M-parameter model on ImageNet for 300 epochs. When I use the released technique, I get FID=21. When I use temperature scaling when predicting tokens (not the mask), I get FID=11~12, which still does not reproduce the paper's result.

By the way, I think they do not use temperature annealing to randomly select the positions to unmask, since the code below fixes the temperature parameter to a scalar (=4.5). Is that right...? https://github.com/google-research/maskgit/blob/cf615d448642942ddebaa7af1d1ed06a05720a91/maskgit/libml/parallel_decode.py#L158-L159

pabloppp commented 2 years ago

It's 4.5 * (1 - ratio) and ratio goes up from 0 to 1 with the sampling step, so the temperature gets annealed to 0
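Written out for an 8-step schedule, assuming ratio = (step + 1) / num_steps for a 0-indexed step (so the noise temperature reaches exactly 0 at the last step):

```python
def choice_temperature(step: int, num_steps: int, base: float = 4.5) -> float:
    """Noise temperature for re-masking at a given 0-indexed decoding step."""
    ratio = (step + 1) / num_steps  # goes from 1/num_steps up to 1
    return base * (1.0 - ratio)


print([choice_temperature(t, 8) for t in range(8)])
# [3.9375, 3.375, 2.8125, 2.25, 1.6875, 1.125, 0.5625, 0.0]
```

So the 4.5 is only the starting scale; the effective temperature shrinks linearly with every step.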

LeeDoYup commented 2 years ago

@pabloppp Oh, thank you I will try it !

GuoxingY commented 2 years ago

@pabloppp so maybe training the model longer can give better results, like the ones @LeeDoYup shared? Although the quantitative results are not as good as the paper states, I think the quality of images generated at FID=21 is much better than the images we got. Could you share some generated images here, @LeeDoYup?

LeeDoYup commented 2 years ago

@GuoxingY They are generated ImageNet images. I will share some in this thread soon.

dome272 commented 2 years ago

I trained the described transformer (172M params) for 1000 epochs on a dataset of 8k landscape images with a batch size of 100. Loss: [image]

Samples (1. original, 2. reconstructed, 3. inpainted bottom half, 4. newly sampled image): [image] [image] [image] [image]

More samples: [image] [image] [image] [image] [image] [image] [image] [image]

I'll update my code and upload some checkpoints soon if anyone is interested. I tried to follow the paper as closely as possible. Note: the samples are somewhat cherry-picked.

pabloppp commented 2 years ago

How much of the image do you mask for the reconstructed image?

dome272 commented 2 years ago

> How much of the image do you mask for the reconstructed image?

The wording might have been a bit misleading in that context. The reconstruction is just encoding and decoding the image and has nothing to do with the transformer. I only included it for my own understanding.

zhuqiangLu commented 1 year ago

> I'll update my code and upload some checkpoints soon if anyone is interested.

Hi, has the code been updated?