LTH14 / mar

PyTorch implementation of MAR+DiffLoss https://arxiv.org/abs/2406.11838
MIT License

Scaling the diffusion MLP for learning higher-dimensional VAE features #55

Open zythenoob opened 2 months ago

zythenoob commented 2 months ago

In the paper, the VAE latent dimension is 8 or 16 and the experiments cover MLPs of 6-12 blocks. I experimented with an 8-block MLP learning 64- and 1024-dim audio data: while the model struggled to learn the 1024-dim target (correct sound but noisy), it performed fine on 64-dim.

Although you mentioned that the size of the MLP contributes only marginally to the final performance, I wonder whether, when the model is learning a higher-dimensional target, scaling the MLP can lead to a drastic change in performance? Or alternatively, what design choices in the diffusion part other than the MLP size may benefit high-dimensional learning? Thank you!

LTH14 commented 2 months ago

Yes, your observation is correct -- if the target data is too high-dimensional, it would be hard for a simple MLP to model it. In our paper, the token dimension is always 16 (4x4 for KL-8 and 16 for KL-16). If the dimension of the target is too large (for reference, a 16x16x3 pixel patch is 768 dimensions), you might need to design a more powerful head for the DiffLoss to model it (for example, model a 16x16x3 patch using a convnet).
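To make that concrete, here is a minimal sketch of what such a conv head could look like; the class name, layer sizes, and FiLM-style conditioning are illustrative assumptions, not code from this repo:

```python
import torch
import torch.nn as nn

class ConvDenoiseHead(nn.Module):
    """Illustrative conv head for a 16x16x3 patch target.

    Plays the same role as the MLP head in DiffLoss: given the noisy patch x_t,
    a timestep embedding, and the AR condition z, predict the noise eps.
    """

    def __init__(self, in_ch=3, cond_dim=768, hidden=128):
        super().__init__()
        self.in_proj = nn.Conv2d(in_ch, hidden, kernel_size=3, padding=1)
        # FiLM-style conditioning: map (t_emb + z) to a per-channel scale and shift
        self.cond_proj = nn.Linear(cond_dim, 2 * hidden)
        self.blocks = nn.Sequential(
            nn.Conv2d(hidden, hidden, kernel_size=3, padding=1), nn.SiLU(),
            nn.Conv2d(hidden, hidden, kernel_size=3, padding=1), nn.SiLU(),
        )
        self.out_proj = nn.Conv2d(hidden, in_ch, kernel_size=3, padding=1)

    def forward(self, x_t, t_emb, z):
        # x_t: (B, 3, 16, 16); t_emb, z: (B, cond_dim)
        h = self.in_proj(x_t)
        scale, shift = self.cond_proj(t_emb + z).chunk(2, dim=1)
        h = h * (1 + scale[..., None, None]) + shift[..., None, None]
        h = self.blocks(h)
        return self.out_proj(h)  # predicted noise, same shape as x_t
```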

MikeWangWZHL commented 2 months ago

Hi! I observe a potentially similar issue: I am trying to use this diffusion loss to model a 128-dim latent space (from a custom continuous tokenizer, with simple AR modeling). I find that during inference sampling, the estimation of "pred_xstart" results in very large values. Do you have any insight into why?

LTH14 commented 2 months ago

@MikeWangWZHL I once tried 64-dim and it worked fine -- I never tried 128-dim, however. If pred_xstart is very large, I would suggest either training for longer or using 1000 steps during inference to see whether the problem persists.
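For reference, the diffusion code here is adopted from DiT, so something along these lines should switch to the full 1000-step DDPM chain; the exact import path and argument name may differ in your copy of the code, so treat this as a sketch:

```python
# Sketch, assuming the DiT-style create_diffusion helper used in this repo:
# an empty timestep_respacing string means no respacing, i.e. the sampler
# runs the full 1000-step DDPM chain instead of e.g. 100 respaced steps.
from diffusion import create_diffusion

gen_diffusion = create_diffusion(timestep_respacing="")  # or "1000"
```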

MikeWangWZHL commented 2 months ago

Thanks for the quick reply! One thing I find is that the beta of gen_diffusion at the last timestep (e.g., t=99) is 9.99989999e-01, which is not clamped to, say, 0.999; this contributes to a very large coefficient in _predict_xstart_from_eps. Is this expected? Thanks in advance!
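For context, if I read the guided-diffusion code correctly, the estimate in question is the standard DDPM one:

$$
\hat{x}_0 = \sqrt{\tfrac{1}{\bar\alpha_t}}\, x_t \;-\; \sqrt{\tfrac{1}{\bar\alpha_t} - 1}\;\hat\epsilon_\theta(x_t, t),
\qquad \bar\alpha_t = \prod_{s \le t}(1 - \beta_s),
$$

so a final beta of roughly 0.99999 pushes $\bar\alpha_t$ toward zero and both coefficients blow up, which is presumably why a slightly noisy $\hat\epsilon$ at the last step turns into a huge pred_xstart.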

LTH14 commented 2 months ago

We directly adopted the diffusion code from DiT, so I didn't look into this carefully. You could try clamping it to see whether it improves stability.
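A quick back-of-the-envelope check of how much clamping the final beta would shrink the eps coefficient in _predict_xstart_from_eps; the cumulative alpha before the last step is a made-up placeholder, not a value from the actual schedule:

```python
import numpy as np

# Toy illustration only: assume the cumulative alpha up to step T-1 is 1e-3
# (an arbitrary placeholder, not taken from the actual cosine schedule).
def last_step_eps_coeff(final_beta, alphas_cumprod_prev=1e-3):
    alphas_cumprod_T = alphas_cumprod_prev * (1.0 - final_beta)
    # sqrt(1/alpha_bar_T - 1): the factor multiplying eps in the x0 estimate
    return np.sqrt(1.0 / alphas_cumprod_T - 1.0)

print(last_step_eps_coeff(0.99999))  # ~1e4: eps errors are amplified enormously
print(last_step_eps_coeff(0.999))    # ~1e3: an order of magnitude smaller
```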

Paulmzr commented 4 weeks ago

> Hi! I observe a potentially similar issue: I am trying to use this diffusion loss to model a 128-dim latent space (from a custom continuous tokenizer, with simple AR modeling). I find that during inference sampling, the estimation of "pred_xstart" results in very large values. Do you have any insight into why?

I ran into this as well. I also tried applying the "clip" option to avoid the large x_0 prediction, but generation still failed.

darkliang commented 4 weeks ago

In my case, when x_start is very large, using 1000 DDPM inference steps helps.