LTH14 / mar

PyTorch implementation of MAR+DiffLoss https://arxiv.org/abs/2406.11838
MIT License

Question on the Value of the Training Loss for DiffLoss with the MAR and Causal Methods #20

Open · bugWholesaler opened 2 months ago

bugWholesaler commented 2 months ago

Thanks for your great work! I am currently working on a project that involves DiffLoss, and I am curious about the convergence behavior of the training loss. Specifically, how far does the training loss ultimately converge for the MAR and causal methods, and which of the two converges faster? Again, awesome work! I look forward to your response!

LTH14 commented 2 months ago

Thanks for your interest! Here is the 800 epochs training loss curve on ImageNet, MAR-L:

[image: 800-epoch training loss curve, MAR-L on ImageNet]

Similar to DiT, our training loss will never really converge: training for longer will keep improving the performance. However, at 400 epochs the performance (FID) is typically already quite good (< 2). In our experience, the MAR method achieves much better performance than the Causal method on ImageNet.

yuhuUSTC commented 2 months ago

Thanks for the great work! I got a similar loss convergence curve when training the model myself. I find that the averaged DiffLoss quickly converges to about 0.25 and stops decreasing further, yet the generation quality keeps improving as training goes on. This confuses me. On one hand, the training loss seems to stop converging while the generation gets better and better. On the other hand, to my understanding, an averaged DiffLoss of 0.25 is still pretty large given the l2 loss implementation in DiffLoss: it means the averaged per-pixel distance to the Gaussian noise is 0.5, which is far from optimal. I look forward to your response!

LTH14 commented 2 months ago

Thanks for your interest. This repo should achieve a loss curve similar to the one above (around 0.33), though different data can result in slightly different loss values. This kind of quickly converging loss curve is commonly observed for diffusion losses (e.g., Figure 13 in DiT). However, after the initial quick drop, the loss still decreases steadily, as shown in the figure above, and thus keeps improving the generation performance.

The absolute loss value can be affected by many factors: tokenizer, dataset, model capacity, noise schedule, etc., so it can vary a lot. For example, DiT's loss is around 0.15, which corresponds to a pixel distance of around 0.4. This is because the denoising function is very hard to learn, especially when the noise level is high. Moreover, since we use a very large masking ratio (randomly sampled between 0.7 and 1.0) during MAR's training, our loss is even larger than DiT's.
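For intuition on the numbers discussed above, here is a minimal sketch (not the repo's actual DiffLoss implementation) of how an l2/MSE denoising loss maps to a per-element distance: the square root of the loss is the RMS distance between the predicted and the true Gaussian noise, so a loss of about 0.25 corresponds to a distance of about 0.5, and 0.15 to about 0.4. The shapes below are placeholders.

```python
import torch

# Hypothetical shapes; the real DiffLoss operates on tokenizer latents per masked token.
eps = torch.randn(1024, 16)                    # true Gaussian noise added in the forward process
eps_pred = eps + 0.5 * torch.randn_like(eps)   # a predictor whose per-element error has std ~0.5

# Standard epsilon-prediction MSE objective, averaged over all elements.
loss = torch.mean((eps_pred - eps) ** 2)

# The RMS per-element distance is the square root of the MSE:
# loss ~0.25 -> distance ~0.5; loss ~0.15 -> distance ~0.4.
print(f"loss = {loss.item():.3f}, RMS distance = {loss.sqrt().item():.3f}")
```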

yuhuUSTC commented 2 months ago

Thanks for the answer.

Robootx commented 2 months ago

How about the loss of the causal method?

LTH14 commented 2 months ago

@Robootx here is the loss for random order causal method (with teacher-forcing language modeling loss):

[image: training loss curve, random order causal method]

Robootx commented 2 months ago

> @Robootx here is the loss for random order causal method (with teacher-forcing language modeling loss): [image]

thank you very much

Robootx commented 2 months ago

Could you please show me some images generated by a causal method?

zythenoob commented 2 months ago

I wonder how the model performs at different training stages, e.g., how many training steps it takes before it can generate the shape of an object?

zhuhr925 commented 2 months ago

> Thanks for your interest! Here is the 800 epochs training loss curve on ImageNet, MAR-L: [image] Similar to DiT, our training loss will never really converge: training for longer will keep improving the performance. However, at 400 epochs the performance (FID) is typically already quite good (< 2). In our experience, the MAR method achieves much better performance than the Causal method on ImageNet.

Thanks for the loss curve. Could you also share the loss curve of MAR-Huge? Thanks!

LTH14 commented 2 months ago

@Juhywcy

[image: MAR-Huge training loss curve]

zhuhr925 commented 2 months ago

> @Juhywcy [image]

thanks!

zhuhr925 commented 2 months ago

Could you provide the learning rate details of the MAR-Huge training? It would help with reproducing the results.

LTH14 commented 2 months ago

@Juhywcy the learning rate schedule and value are the same for all models.

zhuhr925 commented 2 months ago

> @Juhywcy the learning rate schedule and value are the same for all models.

Thanks for your reply! Constant and 1e-4?

LTH14 commented 2 months ago

> > @Juhywcy the learning rate schedule and value are the same for all models.
>
> Thanks for your reply! Constant and 1e-4?

Yes -- also with linear warmup.
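A rough sketch of "constant learning rate with linear warmup": the warmup length and optimizer below are illustrative placeholders, not necessarily the repo's exact settings.

```python
import torch

model = torch.nn.Linear(16, 16)     # placeholder model
base_lr = 1e-4                      # base learning rate, before any batch-size scaling
warmup_epochs = 100                 # hypothetical value; check the repo's training args for the real one
optimizer = torch.optim.AdamW(model.parameters(), lr=base_lr)

def adjust_learning_rate(optimizer, epoch):
    """Linear warmup to base_lr, then hold it constant."""
    if epoch < warmup_epochs:
        lr = base_lr * (epoch + 1) / warmup_epochs
    else:
        lr = base_lr
    for group in optimizer.param_groups:
        group["lr"] = lr
    return lr
```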

zhuhr925 commented 2 months ago

> > > @Juhywcy the learning rate schedule and value are the same for all models.
> >
> > Thanks for your reply! Constant and 1e-4?
>
> Yes -- also with linear warmup.

Thanks for your fast reply! Have a good day!

poppuppy commented 1 month ago

Thanks for your great work! Could you also provide the loss curve when training MAR with a cross-entropy loss? Thank you, and I look forward to your reply.

LingweiMeng commented 3 weeks ago

> > > @Juhywcy the learning rate schedule and value are the same for all models.
> >
> > Thanks for your reply! Constant and 1e-4?
>
> Yes -- also with linear warmup.

I just want to confirm: in the paper, the lr is 8e-4. Why is it 1e-4 here? Thank you. :)

LTH14 commented 3 weeks ago

@LingweiMeng we scale the final learning rate according to the total batch size divided by 256. 1e-4 is the "base learning rate" before scaling.
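In other words, a sketch of the scaling rule (the batch size below is illustrative; it is simply the value that recovers the 8e-4 mentioned above from a base learning rate of 1e-4):

```python
base_lr = 1e-4            # "base learning rate", defined for a batch size of 256
total_batch_size = 2048   # illustrative; e.g. batch size per GPU * number of GPUs

effective_lr = base_lr * total_batch_size / 256
print(effective_lr)       # 0.0008, i.e. the 8e-4 mentioned in the thread
```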

LingweiMeng commented 3 weeks ago

Thank you.
