nv-tlabs / LION

Latent Point Diffusion Models for 3D Shape Generation

How do I know when to stop Stage 2 training? #29

Open yjcaimeow opened 1 year ago

yjcaimeow commented 1 year ago

Hi @ZENGXH,

Is there any indicator for judging whether Stage 2 has been trained enough? I'm using a different dataset, so I don't know at which epoch/iteration to stop.

Best regards, Yingjie

ZENGXH commented 1 year ago

Hi,

If you want a strict indicator, you may need to save the model every ~2k epochs and run the evaluation code on each saved checkpoint to get its 1-NNA score. When the 1-NNA score stops improving, the model is trained enough. Note that the 1-NNA score is somewhat noisy: you will see it fluctuate even after the model has converged (this is normal).

Another (less strict) approach is to visually inspect the sampled images and check whether the samples stop improving.

In my experience, the model usually starts converging at around 6k-10k epochs, and I usually train it up to 20k epochs to make sure it is fully converged.

albertotono commented 1 year ago

Would you recommend running the 1-NNA evaluation during training? And in general, how long does the 1-NNA evaluation take for you?

[This information is valid only for the diffusion part. 2nd step of the training]

ZENGXH commented 1 year ago

Running the full 1-NNA evaluation on ~600 samples takes about 1 hour on a single GPU. Most of the time is spent computing the EMD score, followed by sampling. If you want to evaluate the 1-NNA score during training, it would be better to:
1) use 1-NNA-CD only and turn off 1-NNA-EMD, since computing CD is fast;
2) perhaps reduce the number of samples and reference shapes, i.e., create a smaller evaluation set, though this comes with higher variance in the results;
3) use the DDIM sampler to generate samples, though this will cause some reduction in evaluation performance.
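To make the 1-NNA-CD option concrete, here is a minimal pure-Python sketch of the metric: a leave-one-out 1-nearest-neighbor classifier over the union of generated and reference point clouds, using Chamfer distance (CD) as the distance. It is not LION's evaluation code. The names `chamfer_distance`, `one_nna` and `subsample` are hypothetical, and the CD convention used here (mean squared nearest-neighbor distance in both directions, summed) is an assumption, since implementations differ. A real evaluation would batch this on the GPU. The O(n²) loops are only for illustration. DDIM sampling (option 3) is not shown.

```python
import random
from typing import List, Sequence, Tuple

Point = Tuple[float, float, float]
Cloud = Sequence[Point]


def sq_dist(a: Point, b: Point) -> float:
    """Squared Euclidean distance between two 3D points."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def chamfer_distance(p: Cloud, q: Cloud) -> float:
    """Symmetric CD (assumed convention): mean squared NN distance p->q plus q->p."""
    p_to_q = sum(min(sq_dist(a, b) for b in q) for a in p) / len(p)
    q_to_p = sum(min(sq_dist(b, a) for a in p) for b in q) / len(q)
    return p_to_q + q_to_p


def one_nna(generated: List[Cloud], reference: List[Cloud], dist=chamfer_distance) -> float:
    """Leave-one-out 1-NN classification accuracy (%) between the two sets.

    ~50% means generated and reference shapes are indistinguishable;
    values near 100% mean the generator is easy to tell apart (worse).
    """
    clouds = list(generated) + list(reference)
    labels = [0] * len(generated) + [1] * len(reference)
    n = len(clouds)
    # Symmetric pairwise distance matrix.
    d = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1, n):
            d[i][j] = d[j][i] = dist(clouds[i], clouds[j])
    correct = 0
    for i in range(n):
        # Nearest neighbor of cloud i, excluding itself (leave-one-out).
        nn = min((j for j in range(n) if j != i), key=lambda j: d[i][j])
        correct += labels[nn] == labels[i]
    return 100.0 * correct / n


def subsample(clouds: List[Cloud], k: int, seed: int = 0) -> List[Cloud]:
    """Smaller evaluation set: faster to score, but the 1-NNA estimate is noisier."""
    return random.Random(seed).sample(clouds, min(k, len(clouds)))


# Toy example: generated shapes near the origin, reference shapes far away.
gen = [[(0, 0, 0), (0.1, 0, 0)], [(0, 0.1, 0), (0, 0, 0.1)]]
ref = [[(5, 5, 5), (5.1, 5, 5)], [(5, 5.1, 5), (5, 5, 5.1)]]
print(one_nna(gen, ref))  # 100.0: every shape's nearest neighbor is from its own set
```

Note that with this definition, samples that exactly copy the reference shapes push the score toward 0%, so values far below 50% also indicate a problem.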