mobaidoctor / med-ddpm

GNU General Public License v3.0

The quality of the generated images fluctuates with training #23

Closed jngw21 closed 3 weeks ago

jngw21 commented 5 months ago

Hi,

I currently have a dataset of 400 images, each with a corresponding mask of dimensions 128x128x64. I set the batch size to 8, used a learning rate of 1e-5, and ran the training for 100,000 iterations, maintaining all other parameters consistent with the original settings. During training, I observed that starting from around 20,000 iterations, the network occasionally produces high-quality samples. However, the quality of the output is inconsistent; for instance, the sample from iteration 20,000 might be of high quality, whereas the one from 21,000 could be low quality, and the sample from 22,000 might be high quality again. Furthermore, using the model weights from iterations that generated high-quality samples results in significant variability during testing; some results are excellent while others are barely better than noise. This fluctuation in quality begins around 20,000 iterations and persists until the end of the training period. I have experimented with both unconditional and conditional training approaches, but both exhibited similar patterns. Have you encountered a similar situation before, and do you have any potential solutions?

Thanks for your time, I really appreciate it.

mobaidoctor commented 4 months ago

Hi, thank you for your inquiry and interest in our work. We apologize for the delay in responding. Due to an ongoing intensive project, we are fully occupied with daily tasks and unable to address other matters at this time. This situation may continue until the end of June. We will do our best to respond promptly here, but please allow 3-5 days for replies to future inquiries. Thank you for your understanding.

In our experiments, we prepared our training dataset by excluding all distorted and low-quality images, retaining only high-quality ones. As a result, our model achieved fast convergence and performed well with 250 timesteps. However, in another experiment involving MRI to CT translation, where the dataset comprised low-quality images, the model struggled to converge quickly. During training, it frequently produced noisy images within the first 500,000 iterations, indicating that the model had not yet converged. Subsequently, we trained it for almost 2,000,000 iterations with three different learning rates. With such a dataset of low-quality images, it might require more than 1,000,000 iterations to converge.
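To make the idea of training in stages with several learning rates concrete, here is a minimal sketch of a piecewise-constant schedule. The stage boundaries and learning-rate values are hypothetical placeholders; the thread only states that three learning rates were used over almost 2,000,000 iterations, not which ones or when they changed.

```python
from bisect import bisect_right

# Hypothetical values for illustration only: the thread does not give the
# actual learning rates or the iterations at which they were switched.
STAGE_STARTS = [0, 500_000, 1_000_000]   # first iteration of each stage
STAGE_LRS = [1e-4, 1e-5, 1e-6]           # learning rate used in each stage


def staged_lr(step, starts=STAGE_STARTS, lrs=STAGE_LRS):
    """Return the learning rate for `step` under a piecewise-constant schedule."""
    if step < 0:
        raise ValueError("step must be non-negative")
    if len(starts) != len(lrs) or starts[0] != 0 or list(starts) != sorted(starts):
        raise ValueError("starts must begin at 0, be sorted, and match lrs in length")
    # bisect_right counts how many stage starts are <= step;
    # subtracting 1 gives the index of the active stage.
    return lrs[bisect_right(starts, step) - 1]
```

In a training loop, the returned value could be assigned to the optimizer's learning rate at the start of each iteration, or whenever a new stage begins.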

We are currently developing the next version of our proposed method, Med-DDPM-v2.0, to improve our model's convergence capability. We aim to introduce it in the literature this coming summer. In the meantime, if you have any further questions or need clarification, please don't hesitate to reach out. Thank you.

jngw21 commented 4 months ago

Thank you for your feedback.

  1. Could you please explain how you determine when the model has converged? I've noticed that the training loss decreases rapidly during the initial phase of training but then continues to oscillate around a small value. I couldn't find a criterion to confirm convergence.
  2. Additionally, how do you adjust the learning rate during the training process?
  3. The 3D-DDPM was trained for 100,000 iterations, but the figures in your response were 500,000, 2,000,000, and 1,000,000, which is more than 10 times as many. Could there be a mistake in these figures, or does the model really need over 10 times as long to converge?

Thanks in advance for your help. I really appreciate it.

mobaidoctor commented 4 months ago

@jngw21 Hi, we evaluated our model's convergence by sampling images in addition to examining the training curve. If the model consistently generates at least 5 noise-free images over consecutive sampling steps during first-stage training, we consider it the best model for fine-tuning and train it further at a lower learning rate. Regarding your question about adjusting the learning rate: we start training with a higher learning rate and then fine-tune the best model with a lower one.

For our 3D-DDPM, we used the same dataset as for Med-DDPM, but applied brain extraction and trained on the brain-extracted data.

In your case, we highly recommend cleaning your dataset so that it includes only high-quality images without distortion or artifacts; training on high-quality images can help your model converge faster. However, we acknowledge that our current method is sensitive to dataset quality: if the dataset is diverse and contains low-quality images, the convergence time increases significantly. We are currently working on making our method more robust and less dependent on dataset quality. Thank you.
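One possible reading of the sampling-based criterion above can be sketched as a small helper. This is an illustration under assumptions, not the authors' code: the function name, the `window` parameter, and the idea of recording a per-checkpoint count of noise-free samples (judged, for example, by visual inspection) are all hypothetical. Only the threshold of 5 noise-free images comes from the reply.

```python
def first_stable_checkpoint(clean_counts, min_clean=5, window=3):
    """Find the first checkpoint that completes a stable run.

    clean_counts: list of (iteration, n_noise_free_samples) tuples,
        ordered by iteration.
    min_clean: minimum noise-free samples a checkpoint needs to count
        as good (5, per the reply above).
    window: number of consecutive good checkpoints required
        (hypothetical; the thread does not specify this).

    Returns the iteration of the checkpoint that completes the first run
    of `window` consecutive good checkpoints, or None if no run exists.
    """
    run = 0
    for iteration, n_clean in clean_counts:
        # Extend the current run on a good checkpoint, reset it otherwise.
        run = run + 1 if n_clean >= min_clean else 0
        if run >= window:
            return iteration
    return None


# Made-up counts mimicking the fluctuation described in the issue:
history = [(20_000, 6), (21_000, 1), (22_000, 7), (23_000, 5), (24_000, 8)]
candidate = first_stable_checkpoint(history)  # 24000 with the defaults
```

Requiring several good checkpoints in a row, rather than a single good sample, is meant to filter out the alternating good and bad iterations reported at the top of the thread. The selected checkpoint would then be the starting point for fine-tuning at a lower learning rate.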