willisma / SiT

Official PyTorch Implementation of "SiT: Exploring Flow and Diffusion-based Generative Models with Scalable Interpolant Transformers"
https://scalable-interpolant.github.io/
MIT License

Some questions about forward_with_cfg #23

Open maxin-cn opened 5 days ago

maxin-cn commented 5 days ago

Thank you for the fantastic work and for making the code open-source. I have a small question regarding the forward_with_cfg function.

def forward_with_cfg(self, x, t, y, cfg_scale):
    """
    Forward pass of SiT, but also batches the unconditional forward pass for classifier-free guidance.
    """
    # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
    half = x[: len(x) // 2]
    combined = torch.cat([half, half], dim=0)
    model_out = self.forward(combined, t, y)
    # For exact reproducibility reasons, we apply classifier-free guidance on only
    # three channels by default. The standard approach to cfg applies it to all channels.
    # This can be done by uncommenting the following line and commenting-out the line following that.
    # eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:]
    eps, rest = model_out[:, :3], model_out[:, 3:]
    cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
    half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
    eps = torch.cat([half_eps, half_eps], dim=0)
    return torch.cat([eps, rest], dim=1)
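To make the packed-batch arithmetic concrete, here is a standalone sketch of the split-and-guide step. The tensor shapes and values are made up for illustration; `s` plays the role of `cfg_scale`:

```python
import torch

s = 4.0
cond_eps = torch.full((2, 3, 8, 8), 1.0)    # predictions under the class labels
uncond_eps = torch.full((2, 3, 8, 8), 0.2)  # predictions under the null label
eps = torch.cat([cond_eps, uncond_eps], dim=0)  # packed batch of 4, [cond | uncond]

c, u = torch.split(eps, len(eps) // 2, dim=0)   # unpack the two halves
guided = u + s * (c - u)                        # 0.2 + 4.0 * (1.0 - 0.2) ≈ 3.4
print(guided.shape)
```

So a single forward pass produces both predictions, and the guidance formula pushes the unconditional prediction toward the conditional one by a factor of `cfg_scale`.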

Why do we need these two lines:

half = x[: len(x) // 2]
combined = torch.cat([half, half], dim=0)

instead of simply setting combined = x:

def forward_with_cfg(self, x, t, y, cfg_scale):
    """
    Forward pass of SiT, but also batches the unconditional forward pass for classifier-free guidance.
    """
    # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
    combined = x
    model_out = self.forward(combined, t, y)
    # For exact reproducibility reasons, we apply classifier-free guidance on only
    # three channels by default. The standard approach to cfg applies it to all channels.
    # This can be done by uncommenting the following line and commenting-out the line following that.
    # eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:]
    eps, rest = model_out[:, :3], model_out[:, 3:]
    cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
    half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
    eps = torch.cat([half_eps, half_eps], dim=0)
    return torch.cat([eps, rest], dim=1)

When I made this change, the model was no longer able to generate images correctly. As I understand it, the two approaches:

half = x[: len(x) // 2]
combined = torch.cat([half, half], dim=0)

and combined = x should be equivalent. Could you please help me understand where the issue arises when changing it to combined = x?
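One hedged observation: the two versions are only equivalent if the two halves of `x` are identical at every sampling step. The sampler duplicates the latent once up front, but the tensor returned by `forward_with_cfg` only forces the guided channels to match across halves; the remaining channels (`rest`, plus any `eps` channels past the first three) still differ between the conditional and unconditional halves. If the solver update consumes those channels, the two halves of `x` can drift apart after the first step, after which `combined = x` feeds different latents to the conditional and unconditional passes. Re-slicing `half` and duplicating it discards the second half and guarantees both passes see the same latent. A minimal sketch of the divergence (`fake_forward` and the shapes are invented for illustration, not SiT's actual model):

```python
import torch

# Hypothetical stand-in for self.forward: it passes the first 3 channels
# through unchanged but makes the remaining channel depend on the label y,
# so the conditional and unconditional halves of the batch disagree there.
def fake_forward(x, y):
    rest = y.view(-1, 1, 1, 1).float().expand(-1, 1, x.shape[2], x.shape[3])
    return torch.cat([x[:, :3], rest], dim=1)

torch.manual_seed(0)
half = torch.randn(2, 4, 8, 8)           # one latent per sample
x = torch.cat([half, half], dim=0)       # duplicated, as the sampler sets it up
y = torch.tensor([5, 7, 1000, 1000])     # class labels followed by null labels

out = fake_forward(x, y)
# The first 3 channels match across halves, but the "rest" channel does not,
# so a solver step that consumes the full output drives the halves apart:
print(torch.equal(out[:2, :3], out[2:, :3]))  # True
print(torch.equal(out[:2, 3:], out[2:, 3:]))  # False
```

Under this reading, `combined = x` happens to work only when the sampler never lets the halves diverge, which may explain why some setups see identical results.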

Thank you very much!

xmhGit commented 3 days ago

In my case, both versions give exactly the same results.