CompVis / taming-transformers

Taming Transformers for High-Resolution Image Synthesis
https://arxiv.org/abs/2012.09841
MIT License

how do we guarantee a reasonable conditional generation when training transformer? #101

Open lukun199 opened 3 years ago

lukun199 commented 3 years ago

Hello, thanks for the awesome code. I ran into a problem while trying to understand how the transformer learns in the third stage.

In the segmentation- and depth-conditioned generation tasks, we train the transformer with F.cross_entropy(logits.reshape(-1, logits.size(-1)), target.reshape(-1)) in https://github.com/CompVis/taming-transformers/blob/9d17ea64b820f7633ea6b8823e1f78729447cb57/taming/models/cond_transformer.py#L286, where target and logits are defined in https://github.com/CompVis/taming-transformers/blob/9d17ea64b820f7633ea6b8823e1f78729447cb57/taming/models/cond_transformer.py#L90-L104. So we learn z_indices from cz_indices = torch.cat((c_indices, a_indices), dim=1). I just wonder why the network does not collapse to simply memorizing the z_indices?
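For what it's worth, here is a minimal, hypothetical sketch (plain Python, single sequence, toy token values) of how the training pair referenced above is built. The key point is that the transformer consumes cz_indices[:-1] while the loss targets are z_indices, and the causal attention mask means the logit for target[i] can only depend on the condition plus the *previous* image tokens, never on z_indices[i] itself:

```python
# Toy values for illustration only; in the repo these come from the VQGAN encoders.
c_indices = [7, 8]           # condition tokens (e.g. from the segmentation-mask VQGAN)
z_indices = [3, 1, 4, 1]     # image tokens to be predicted

# Mirrors cz_indices = torch.cat((c_indices, a_indices), dim=1)
cz_indices = c_indices + z_indices

# The transformer input drops the last token; the targets are the image tokens.
transformer_input = cz_indices[:-1]   # [7, 8, 3, 1, 4]
target = z_indices                    # [3, 1, 4, 1]

# Because of the causal mask, the prediction for target[i] sees only this context:
contexts = []
for i in range(len(target)):
    visible = transformer_input[: len(c_indices) + i]
    contexts.append(visible)

# contexts[0] == [7, 8]          -> predict 3 from the condition alone
# contexts[3] == [7, 8, 3, 1, 4] -> predict the final token 1
```

So even though z_indices appears in the input, each position is predicted without access to its own target token, which is why the model cannot trivially copy it.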

I find in the colab notebook that even when the z_indices are chosen randomly, the model can still behave well given a proper c_indices (in that case, the c_indices come from the segmentation mask). But I am curious how the model learns under such relatively weak supervision?

IceClear commented 2 years ago

@lukun199 Hi, have you figured that out? I am also curious about this part of the code.

ZhuXiyue commented 1 year ago

The GPT tries to predict input[i] based on input[:i] (i.e., all earlier tokens) without ever looking at input[i]. At test time it predicts the i-th output using the condition and prediction[:i].
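That test-time loop can be sketched as follows. This is a hypothetical illustration, not the repo's code: fake_logits is a stand-in for the real transformer forward pass, and greedy argmax replaces the top-k sampling the repo actually uses. The only property it demonstrates is that each new token is produced from the condition plus previously generated tokens:

```python
import random

VOCAB_SIZE = 16  # toy codebook size, for illustration

def fake_logits(context, vocab_size=VOCAB_SIZE):
    # Stand-in for the transformer: any function of the visible context only.
    rng = random.Random(sum(context))
    return [rng.random() for _ in range(vocab_size)]

def autoregressive_sample(c_indices, steps, vocab_size=VOCAB_SIZE):
    # Start from the condition tokens alone; z_indices are not needed.
    seq = list(c_indices)
    for _ in range(steps):
        logits = fake_logits(seq, vocab_size)
        # Greedy pick for determinism; the repo samples with top-k filtering.
        nxt = max(range(vocab_size), key=logits.__getitem__)
        seq.append(nxt)
    return seq[len(c_indices):]  # the generated image tokens

generated = autoregressive_sample([7, 8], steps=4)
```

This also explains the notebook observation above: at sampling time the ground-truth z_indices are never consulted, so a well-trained model produces plausible tokens from the condition alone.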