facebookresearch / DiT

Official PyTorch Implementation of "Scalable Diffusion Models with Transformers"

When training for CFG, why only utilize half of the input? #54

Open jinge170 opened 1 year ago

jinge170 commented 1 year ago

Hello, thank you for sharing your code. When training for classifier-free guidance, why do you only utilize half of the input? Can you briefly describe how that works? Thanks!!

```python
half = x[: len(x) // 2]
combined = torch.cat([half, half], dim=0)
```

marikgoldstein commented 6 months ago

Hi,

The lines of code that you mention only happen during sampling, not training.

During training, the label y is sometimes masked.
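A minimal sketch of that masking step, assuming the "null" label is the extra index `num_classes` and a hypothetical dropout probability (the repo's exact implementation may differ):

```python
import torch

def drop_labels(y, num_classes, drop_prob):
    # With probability drop_prob, replace a label with the "null"
    # index num_classes, so the model also learns an unconditional path.
    drop = torch.rand(y.shape[0]) < drop_prob
    return torch.where(drop, torch.full_like(y, num_classes), y)
```

At `drop_prob=0.0` the labels pass through unchanged; at `1.0` every label becomes the null index.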

During sampling, for CFG, people combine model(xt, t, y) and model(xt, t, mask) to mix the conditional and unconditional predictions. This requires two model evaluations per datapoint. One way to implement that is to have a batch that looks like this, for two datapoints:

[(x1, y)]
[(x2, y)]
[(x1, mask)]
[(x2, mask)]
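Building that duplicated batch can look like the following sketch (the null-label index `mask` is an assumed value, not the repo's exact constant):

```python
import torch

x = torch.randn(2, 4)      # two sampling states x1, x2
y = torch.tensor([3, 7])   # their class labels
mask = 10                  # assumed "null" label index, e.g. num_classes

batch_x = torch.cat([x, x], dim=0)                         # [x1, x2, x1, x2]
batch_y = torch.cat([y, torch.full_like(y, mask)], dim=0)  # [y1, y2, mask, mask]
```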

Then you pass all of them into the model, which computes

[x1's conditional output]
[x2's conditional output]
[x1's unconditional output]
[x2's unconditional output]

Then you can combine the 1st and 3rd rows, and the 2nd and 4th rows. Finally, this produces only 2, not 4, updated sampling states; to keep the batch size fixed you keep two and throw away the duplicate two.
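Putting it together, here is a sketch of the pattern behind the quoted lines. It is simplified: the model is assumed to output a single tensor, whereas DiT's real `forward_with_cfg` also handles the learned-sigma channels separately.

```python
import torch

def forward_with_cfg(model, x, t, y, guidance_scale):
    # x's two halves are meant to be copies of the same noisy samples,
    # so only the first half is used and then duplicated -- this
    # guarantees the conditional and unconditional passes see
    # identical inputs.
    half = x[: len(x) // 2]
    combined = torch.cat([half, half], dim=0)
    eps = model(combined, t, y)
    cond_eps, uncond_eps = eps.chunk(2, dim=0)
    # Mix the conditional and unconditional predictions.
    half_eps = uncond_eps + guidance_scale * (cond_eps - uncond_eps)
    # Duplicate the guided result so the output matches x's batch size;
    # this is the "keep two, throw away two" step.
    return torch.cat([half_eps, half_eps], dim=0)
```

With `guidance_scale=1` this reduces to the purely conditional prediction.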