Closed by OmerMachluf 20 hours ago
Hi, thanks for your interest. Our diffusion transformer always uses a sequence length of 256 (the input is a 32x32 latent, patchified with a patch size of 2). The latent dimension is adjusted through an MLP, so z_tilde and z should always have the same dimension. If you want to use our implementation with datasets of a different resolution or with a different model patch size, you can adjust the output shape of the pretrained encoders (e.g. DINO) by using positional embedding interpolation (see here).
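The sequence length of 256 follows directly from the latent size and the patch size. A minimal sketch of that arithmetic, assuming a square latent and non-overlapping square patches (the helper name `num_tokens` and the 64x64 case are hypothetical, for illustration only):

```python
def num_tokens(latent_size: int, patch_size: int) -> int:
    """Number of tokens after patchifying a square latent (illustrative helper)."""
    if latent_size % patch_size != 0:
        raise ValueError("latent_size must be divisible by patch_size")
    side = latent_size // patch_size  # patches along one edge
    return side * side                # tokens in the flattened patch grid


print(num_tokens(32, 2))  # 256, the setting described in this thread
print(num_tokens(64, 2))  # 1024, a hypothetical larger latent
```

Any change to the latent size or the patch size changes the token count, which is why the encoder output then has to be resized to match.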
Thanks for the quick response! I see, that makes sense. Do you believe interpolating the projector output to the right sequence length would work as well?
Yes, we already interpolated some encoder outputs because those encoders use a sequence length of 196.
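As a rough illustration of what resizing a token sequence from 196 to 256 can look like, here is a sketch that reshapes the tokens into a 14x14 grid and bilinearly interpolates to 16x16. This is not necessarily how the repository does it: the function name `resize_tokens`, the bilinear mode, and the row-major square-grid layout are all assumptions.

```python
import math

import torch
import torch.nn.functional as F


def resize_tokens(tokens: torch.Tensor, target_len: int) -> torch.Tensor:
    """Resize [B, N, D] patch tokens to [B, target_len, D].

    Assumes both N and target_len are perfect squares and that tokens are
    laid out row-major on a square grid (hypothetical helper, not repo code).
    """
    b, n, d = tokens.shape
    src = math.isqrt(n)
    dst = math.isqrt(target_len)
    if src * src != n or dst * dst != target_len:
        raise ValueError("sequence lengths must be perfect squares")
    # [B, N, D] -> [B, D, src, src] so the spatial dims come last
    grid = tokens.transpose(1, 2).reshape(b, d, src, src)
    grid = F.interpolate(grid, size=(dst, dst), mode="bilinear", align_corners=False)
    # [B, D, dst, dst] -> [B, dst*dst, D]
    return grid.flatten(2).transpose(1, 2)


z = torch.randn(2, 196, 768)         # e.g. an encoder with 14x14 patch tokens
print(resize_tokens(z, 256).shape)   # torch.Size([2, 256, 768])
```

The same helper could in principle be applied to a projector output instead of an encoder output, as long as both sides end up with the same [B, N, D] shape before the loss.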
Thanks for this great work and for releasing an implementation so quickly! I'm seeing in your loss.py that you calculate the mean between z_j and z_tilde_j with:

    bsz = zs[0].shape[0]
    for i, (z, z_tilde) in enumerate(zip(zs, zs_tilde)):
        for j, (z_j, z_tilde_j) in enumerate(zip(z, z_tilde)):
            z_tilde_j = torch.nn.functional.normalize(z_tilde_j, dim=-1)
            z_j = torch.nn.functional.normalize(z_j, dim=-1)
            proj_loss += mean_flat(-(z_j * z_tilde_j).sum(dim=-1))
    proj_loss /= (len(zs) * bsz)

You are multiplying z_j * z_tilde_j, which are derived from z and z_tilde. For DINOv2, z would have shape [B, 256, 768], since we're using x_norm_patchtokens. However, given that z_tilde is ultimately derived from the latent dimensions, and given that our MLP projectors keep the sequence length through the projections, these shapes might not be compatible for multiplication or broadcasting. What am I missing?
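To make the shape requirement concrete, here is a vectorized sketch of the same negative cosine similarity with an explicit shape check. It assumes that `mean_flat` in the quoted loop averages over the token dimension, in which case the two forms should agree; the function name `projection_loss` is hypothetical and not part of the repository.

```python
import torch
import torch.nn.functional as F


def projection_loss(zs, zs_tilde):
    """Mean negative cosine similarity over encoders, batch, and tokens.

    zs, zs_tilde: lists of [B, N, D] tensors, one pair per encoder.
    Illustrative sketch, assuming mean_flat averages over tokens.
    """
    total = 0.0
    for z, z_tilde in zip(zs, zs_tilde):
        # The elementwise product needs identical [B, N, D] shapes.
        if z.shape != z_tilde.shape:
            raise ValueError(f"shape mismatch: {tuple(z.shape)} vs {tuple(z_tilde.shape)}")
        z = F.normalize(z, dim=-1)
        z_tilde = F.normalize(z_tilde, dim=-1)
        # Cosine similarity per token, then mean over batch and tokens.
        total = total + (-(z * z_tilde).sum(dim=-1)).mean()
    return total / len(zs)


z = torch.randn(2, 256, 768)
print(projection_loss([z], [z]))  # close to -1 for identical inputs
```

With a 196-token encoder output and a 256-token projector output, the check raises instead of silently failing inside the product, which is the situation the interpolation step is meant to avoid.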