Open jinge170 opened 1 year ago
Hi,
The lines of code that you mention only happen during sampling, not training.
During training, the label y is sometimes masked.
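A minimal sketch of that training-time label masking (the function name, `null_label` token, and drop probability are illustrative assumptions, not the repo's actual API):

```python
import torch

def mask_labels(y, null_label, p_uncond=0.1):
    # With probability p_uncond, replace a label with the null/mask token
    # so the model also learns the unconditional distribution.
    drop = torch.rand(y.shape[0]) < p_uncond
    return torch.where(drop, torch.full_like(y, null_label), y)
```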
During sampling, for CFG, people combine model(xt, t, y) and model(xt, t, mask) to mix the conditional and unconditional model. This requires two model evaluations per datapoint. One way to implement that is to have a batch that looks like, for two datapoints:
[(x1, y)] [(x2, y)] [(x1, mask)] [(x2, mask)]
Then you pass all of them into the model, which computes
[x1's conditional output] [x2's conditional output] [x1's unconditional output] [x2's unconditional output]
Then you can combine the 1st and 3rd rows, and the 2nd and 4th rows. Since this produces only 2 (not 4) updated sampling states, you keep two and throw away two.
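The batching-and-combining above can be sketched as follows (the `model` signature, `null_label` token, and guidance formula are assumptions for illustration, not the repo's exact code):

```python
import torch

def cfg_forward(model, x, t, y, null_label, guidance_scale):
    # Duplicate the batch: first half keeps the real labels (conditional),
    # second half gets the null/mask label (unconditional).
    combined_x = torch.cat([x, x], dim=0)
    combined_t = torch.cat([t, t], dim=0)
    combined_y = torch.cat([y, torch.full_like(y, null_label)], dim=0)

    # One forward pass computes both branches for every datapoint.
    out = model(combined_x, combined_t, combined_y)

    # Split back: rows 1..B are conditional, rows B+1..2B unconditional.
    cond, uncond = out.chunk(2, dim=0)

    # Classifier-free guidance: mix the two predictions. This yields only
    # B outputs for a 2B-sized batch, matching the "keep two, throw away
    # two" description above.
    return uncond + guidance_scale * (cond - uncond)
```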
Hello, thank you for sharing your code. When training for classifier-free guidance, why is only half of the input utilized? Can you briefly describe how that works? Thanks!!
```python
half = x[: len(x) // 2]
combined = torch.cat([half, half], dim=0)
```