LTH14 / rcg

PyTorch implementation of RCG https://arxiv.org/abs/2312.03701
MIT License

Could you please provide some training log file? #15

Closed fungtion closed 8 months ago

fungtion commented 8 months ago

When training the RDM, I have no idea whether the model is converging, so could you please provide some logs for reference?

LTH14 commented 8 months ago

Thanks for your interest. I have the same issue; in my experience it is a common problem when training diffusion models. As shown in this log, the training loss barely changes from the beginning.

[image: RDM training loss curve]

However, as shown in Table 5(c), the FID converges around 200 epochs (200k steps).

On ImageNet, training RDM for 100-200 epochs should converge. I would suggest training it for a sufficiently long time and saving checkpoints at fixed intervals, so that you can later evaluate RDM's performance at different epochs (we use a constant learning rate for training RDM, so you don't need to run separate trainings for different epoch counts). Note that the pixel generator's training does not depend on the pre-trained RDM; we include the RDM only for evaluation. So you can train one pixel generator, and then combine that trained pixel generator with different RDM checkpoints to check the performance.
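Below is a minimal sketch of that checkpointing setup, assuming a generic PyTorch training loop (this is not the actual rcg training script; the placeholder model, optimizer, and file names are only for illustration). The idea is to save checkpoints at a fixed epoch interval so each one can later be paired with the same trained pixel generator for evaluation.

```python
import torch
import torch.nn as nn

model = nn.Linear(256, 256)  # placeholder for the RDM; swap in the real model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)  # constant LR, no decay schedule
save_every = 10  # fixed checkpoint interval, in epochs

for epoch in range(1, 201):
    # ... one epoch of RDM training would run here ...
    if epoch % save_every == 0:
        torch.save(
            {"epoch": epoch,
             "model": model.state_dict(),
             "optimizer": optimizer.state_dict()},
            f"rdm_checkpoint_epoch{epoch:03d}.pth",
        )
# Later: load each checkpoint, pair it with the single trained pixel generator,
# and compare FID across epochs.
```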

fungtion commented 8 months ago

It is very detailed, thanks.

mapengsen commented 4 months ago

When I trained the RDM on my own dataset (1.3 million 256×256 images), the loss remained almost unchanged, hovering around 0.0065 (trained for 120 epochs).

Then I evaluated with the trained pixel generator and found that the generated images were not good.

Could you help me analyze the problem? Thank you very much.

[image: training loss curve]

[image: generated samples]

LTH14 commented 4 months ago

Your training loss is unexpectedly low. On ImageNet, the training loss should look like this:

[image: ImageNet RDM training loss curve]

where 1k steps correspond to 1 epoch on ImageNet. One thing that might be worth checking is the pre-trained MoCo-v3: it is pre-trained on ImageNet, so it might not be able to extract meaningful representations if your data is very different from ImageNet (for example, medical data). You also need to make sure your data's value range follows the ImageNet input.
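As a concrete check on the input range, here is a minimal preprocessing sketch under the assumption that the pre-trained MoCo-v3 encoder expects images normalized with the standard ImageNet statistics; verify it against the transforms actually used in the rcg data pipeline before relying on it.

```python
from torchvision import transforms

# Standard ImageNet-style preprocessing (assumed, not copied from rcg):
imagenet_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(256),
    transforms.ToTensor(),  # maps pixel values to [0, 1]
    transforms.Normalize(mean=[0.485, 0.456, 0.406],  # ImageNet channel means
                         std=[0.229, 0.224, 0.225]),  # ImageNet channel stds
])
```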

mapengsen commented 4 months ago

Thank you for your enthusiastic reply. I can see that the RDM loss in lines 675-692 is not simply an L2 loss between the target and the model output (https://github.com/LTH14/rcg/blob/main/rdm/models/diffusion/ddpm.py#L675). Did you use some loss-weighting technique in this part?

for example, this operation: `loss = loss_simple / torch.exp(logvar_t) + logvar_t`

I want to learn more about it.

Thanks.

LTH14 commented 4 months ago

Yes, we use the p_losses function to compute the diffusion loss, which is not exactly an L2 loss (but is quite close). This is the same as the LDM implementation: https://github.com/CompVis/latent-diffusion/blob/main/ldm/models/diffusion/ddpm.py#L1012. Since original_elbo_weight=0, there is no VLB loss, but the operation you mentioned above is included.
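For reference, here is a simplified, self-contained paraphrase of that loss structure (not a verbatim copy of rcg's or LDM's p_losses): the elementwise MSE is reweighted per timestep by a log-variance term, and the VLB term is scaled by original_elbo_weight, which is 0 in RCG, so it drops out.

```python
import torch

def diffusion_loss(model_output, target, t, logvar, lvlb_weights,
                   original_elbo_weight=0.0, l_simple_weight=1.0):
    # plain per-sample MSE, averaged over all non-batch dimensions
    loss_simple = ((model_output - target) ** 2).mean(dim=list(range(1, target.ndim)))

    # per-timestep log-variance reweighting: loss_simple / exp(logvar_t) + logvar_t
    logvar_t = logvar[t]
    loss = l_simple_weight * (loss_simple / torch.exp(logvar_t) + logvar_t).mean()

    # VLB term; with original_elbo_weight = 0 (as in RCG) it contributes nothing
    loss_vlb = (lvlb_weights[t] * loss_simple).mean()
    return loss + original_elbo_weight * loss_vlb

# toy usage with made-up shapes, just to show the call pattern
pred, tgt = torch.randn(4, 256), torch.randn(4, 256)
t = torch.randint(0, 1000, (4,))
print(diffusion_loss(pred, tgt, t, logvar=torch.zeros(1000), lvlb_weights=torch.ones(1000)))
```

With logvar fixed at zero, the weighted term reduces to the plain MSE, which is why the loss is "quite close" to an L2 loss in practice.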