sihyun-yu / REPA

Official PyTorch Implementation of Representation Alignment for Generation: Training Diffusion Transformers Is Easier Than You Think
https://sihyun.me/REPA
MIT License

How does the projection to the encoder space work? #7

Closed: OmerMachluf closed this issue 20 hours ago

OmerMachluf commented 20 hours ago

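Thanks for this great work and for releasing an implementation so quickly! I see that your loss.py computes the projection loss between z_j and z_tilde_j as follows:

```python
# zs: target encoder features, zs_tilde: projected diffusion-transformer features
# (one entry per projector). proj_loss must already be initialized (e.g. to 0.).
bsz = zs[0].shape[0]
for i, (z, z_tilde) in enumerate(zip(zs, zs_tilde)):
    for j, (z_j, z_tilde_j) in enumerate(zip(z, z_tilde)):  # loop over samples
        # L2-normalize each token along the feature dimension
        z_tilde_j = torch.nn.functional.normalize(z_tilde_j, dim=-1)
        z_j = torch.nn.functional.normalize(z_j, dim=-1)
        # negative cosine similarity per token, averaged by mean_flat (from loss.py)
        proj_loss += mean_flat(-(z_j * z_tilde_j).sum(dim=-1))
proj_loss /= (len(zs) * bsz)
```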
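For reference, the double loop takes, for each projector and each sample, the negative cosine similarity between corresponding tokens, averages it over the tokens, and finally divides by len(zs) * bsz, so the result is the mean negative cosine similarity. The minimal sketch below, which is not the repository's code, puts the loop next to a vectorized equivalent. It assumes every z and z_tilde is a [B, T, D] tensor of identical shape; the function names are hypothetical, and `.mean()` stands in for `mean_flat`. Note that the elementwise product `z_j * z_tilde_j` only works if the token counts T match, which is exactly what the next paragraph asks about.

```python
import torch
import torch.nn.functional as F


def proj_loss_loop(zs, zs_tilde):
    """Loop form following the loss.py excerpt; .mean() replaces mean_flat."""
    proj_loss = 0.0
    bsz = zs[0].shape[0]
    for z, z_tilde in zip(zs, zs_tilde):          # one pair per projector
        for z_j, z_tilde_j in zip(z, z_tilde):    # one pair per sample: [T, D]
            z_tilde_j = F.normalize(z_tilde_j, dim=-1)
            z_j = F.normalize(z_j, dim=-1)
            # negative cosine similarity per token, averaged over the T tokens
            proj_loss += (-(z_j * z_tilde_j).sum(dim=-1)).mean()
    return proj_loss / (len(zs) * bsz)


def proj_loss_vectorized(zs, zs_tilde):
    """Same value: mean negative cosine similarity over projectors, samples, tokens."""
    per_projector = [
        -F.cosine_similarity(z, z_tilde, dim=-1).mean()  # [B, T] -> scalar
        for z, z_tilde in zip(zs, zs_tilde)
    ]
    return torch.stack(per_projector).mean()


torch.manual_seed(0)
zs = [torch.randn(4, 256, 768)]        # e.g. DINOv2 patch tokens
zs_tilde = [torch.randn(4, 256, 768)]  # projected hidden states with the same shape
print(torch.allclose(proj_loss_loop(zs, zs_tilde),
                     proj_loss_vectorized(zs, zs_tilde), atol=1e-6))
```

Perfectly aligned features (z_tilde equal to z) drive the loss to its minimum of -1.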

You are multiplying z_j * z_tilde_j, which are derived from z and z_tilde. For DINOv2, z has shape [B, 256, 768], since we use x_norm_patchtokens. However, z_tilde is ultimately derived from the latent, and the MLP projectors keep the sequence length unchanged, so these shapes might not be compatible for elementwise multiplication or broadcasting. What am I missing?
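The concern can be checked directly: an elementwise product only works when the token axes agree, or when one of them has size 1, in which case broadcasting silently repeats it. A small sketch with illustrative shapes (B=2, D=768 as for the DINOv2 patch tokens mentioned above); `can_multiply` is a hypothetical helper, and the 196-token tensor is just an example of a mismatch.

```python
import torch


def can_multiply(a: torch.Tensor, b: torch.Tensor) -> bool:
    """Return True if a * b succeeds (possibly via broadcasting)."""
    try:
        _ = a * b
        return True
    except RuntimeError:  # raised when sizes differ at a non-singleton dimension
        return False


B, D = 2, 768
z = torch.randn(B, 256, D)             # encoder patch tokens, e.g. [B, 256, 768]
z_tilde_same = torch.randn(B, 256, D)  # projector output with a matching token count
z_tilde_196 = torch.randn(B, 196, D)   # hypothetical output with 196 tokens
z_tilde_one = torch.randn(B, 1, D)     # token axis of size 1: broadcasts silently

print(can_multiply(z, z_tilde_same))  # True
print(can_multiply(z, z_tilde_196))   # False: 256 vs 196 at dim 1
print(can_multiply(z, z_tilde_one))   # True, but almost certainly a bug
```

Because a token axis of size 1 broadcasts without error, an explicit `assert z.shape == z_tilde.shape` before the loss is a stricter guard than relying on a runtime error.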

sihyun-yu commented 20 hours ago

Hi, thanks for your interest. Our diffusion transformer always uses a sequence length of 256 (the input is a 32x32 latent, patchified with a patch size of 2, giving 16x16 = 256 tokens). The feature dimension is mapped to the encoder's through the MLP projector, so z_tilde and z should always have the same shape. If you want to use our implementation with datasets at a different resolution or with a different model patch size, you can adjust the output shape of the pretrained encoder (e.g. DINO) by interpolating its positional embeddings (see here).
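The token-count arithmetic behind this reply can be written out explicitly. This is a minimal sketch assuming square inputs and non-overlapping ViT-style patches; the helper names are hypothetical and the encoder patch sizes (14 and 16) are illustrative assumptions. Feeding a patch-16 encoder a 256x256 input, i.e. a resolution other than the one it was pretrained at, is the kind of case where its positional embeddings would need to be interpolated to the new 16x16 grid.

```python
import math


def num_patch_tokens(input_res: int, patch_size: int) -> int:
    """Tokens produced by patchifying a square input of side input_res."""
    if input_res % patch_size != 0:
        raise ValueError("input_res must be divisible by patch_size")
    side = input_res // patch_size
    return side * side


def input_res_for_tokens(num_tokens: int, patch_size: int) -> int:
    """Square input side a ViT-style encoder needs to emit num_tokens patch tokens."""
    side = math.isqrt(num_tokens)
    if side * side != num_tokens:
        raise ValueError("num_tokens must be a perfect square")
    return side * patch_size


# Diffusion transformer side: 32x32 latent, patch size 2 (from the reply above)
print(num_patch_tokens(32, 2))        # 256
# Hypothetical larger latent: 64x64 with the same patch size
print(num_patch_tokens(64, 2))        # 1024
# Encoder side (assumed patch sizes): input side needed to emit 256 tokens
print(input_res_for_tokens(256, 14))  # 224
print(input_res_for_tokens(256, 16))  # 256
```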

OmerMachluf commented 11 hours ago

Thanks for the quick response! I see, that makes sense. Do you think interpolating the projector output to the right sequence length would work as well?

sihyun-yu commented 11 hours ago

Yes, we already interpolated some encoder outputs, since those encoders produce a sequence length of 196.
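The interpolation discussed in the last two comments can be done by viewing the patch tokens as a 2D grid and resizing that grid. Below is a minimal sketch, not the REPA implementation, assuming square token grids, row-major token order, and patch tokens only (any CLS or register tokens removed first); `resample_tokens` is a hypothetical helper. For example, 196 tokens (a 14x14 grid, which a patch-16 encoder produces at 224x224) become 256 tokens (16x16). The same function could be applied to encoder outputs, as described in the reply, or to projector outputs, as the question suggests.

```python
import math

import torch
import torch.nn.functional as F


def resample_tokens(x: torch.Tensor, num_out: int, mode: str = "bilinear") -> torch.Tensor:
    """Resample patch tokens [B, N_in, D] to [B, N_out, D] on their 2D grid.

    mode should be an interpolating mode such as "bilinear" or "bicubic".
    """
    b, n_in, d = x.shape
    h_in, h_out = math.isqrt(n_in), math.isqrt(num_out)
    if h_in * h_in != n_in or h_out * h_out != num_out:
        raise ValueError("token counts must be perfect squares (square grids)")
    # Row-major tokens -> feature map: [B, N, D] -> [B, D, H, W]
    grid = x.transpose(1, 2).reshape(b, d, h_in, h_in)
    grid = F.interpolate(grid, size=(h_out, h_out), mode=mode, align_corners=False)
    # Feature map -> token sequence: [B, D, H', W'] -> [B, N_out, D]
    return grid.flatten(2).transpose(1, 2)


enc_feats = torch.randn(4, 196, 768)       # e.g. 14x14 encoder patch tokens
aligned = resample_tokens(enc_feats, 256)  # resized to a 16x16 grid
print(aligned.shape)                       # torch.Size([4, 256, 768])
```

Bilinear resampling leaves a constant feature map unchanged and is the identity when the sizes already match; resampling on the 2D grid, rather than along the flat sequence, keeps spatially neighboring tokens together.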