FirasGit / medicaldiffusion

Medical Diffusion: This repository contains the code to our paper Medical Diffusion: Denoising Diffusion Probabilistic Models for 3D Medical Image Synthesis

Nan issue #8

Open CaiwenXu opened 1 year ago

CaiwenXu commented 1 year ago

Hi, many thanks for your excellent work! I have a problem when training the VQ-GAN: the loss suddenly becomes NaN. Do you know why this happens? I used the LIDC dataset.

benearnthof commented 1 year ago

I'm currently having the same problem. I used the exact configs provided here and still had no luck: training is very unstable. The model also suffers from mode collapse once the discriminator starts training.

benearnthof commented 1 year ago

I believe this problem may stem from the accumulate_grad_batches parameter. I trained one run successfully for more than 50000 steps, but trying to replicate that training with accumulate_grad_batches > 1 runs into the NaN problem. @CWX-student, can you confirm this, or do you have any other info on your end?
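For context on the parameter: with accumulate_grad_batches = N, gradients from N consecutive batches are combined before a single optimizer step. The toy sketch below is plain Python with one scalar "parameter" and a hypothetical helper `accumulate_steps`. It averages per-batch gradients in groups of N and records the first non-finite gradient. It does not reproduce Lightning's internals; it only illustrates one consequence worth keeping in mind while debugging: a single NaN batch poisons the entire accumulated update, including the contributions of the healthy batches in its group.

```python
import math
from typing import Iterable, List, Optional, Tuple


def accumulate_steps(
    grads: Iterable[float], accumulate_grad_batches: int
) -> Tuple[List[float], Optional[int]]:
    """Toy model of gradient accumulation for a single scalar parameter.

    Per-batch gradients are averaged in groups of `accumulate_grad_batches`
    and one update is emitted per complete group (a trailing partial group is
    dropped). Also returns the index of the first non-finite gradient, or None.
    """
    if accumulate_grad_batches < 1:
        raise ValueError("accumulate_grad_batches must be >= 1")
    updates: List[float] = []
    buffer, count = 0.0, 0
    first_bad: Optional[int] = None
    for i, g in enumerate(grads):
        if first_bad is None and not math.isfinite(g):
            first_bad = i  # remember where things first went wrong
        buffer += g / accumulate_grad_batches  # a NaN/inf here taints the whole group
        count += 1
        if count == accumulate_grad_batches:  # the simulated "optimizer step"
            updates.append(buffer)
            buffer, count = 0.0, 0
    return updates, first_bad


updates, first_bad = accumulate_steps([0.1, 0.2, float("nan"), 0.3], 2)
print(updates, first_bad)
```

Here the first update is finite (about 0.15), the second is NaN even though it also contains the healthy gradient 0.3, and index 2 is reported as the first bad micro-batch. Logging that index (or checking the loss for finiteness every step) makes it easier to tell whether the NaN originates in one batch or builds up gradually.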

benearnthof commented 1 year ago

Update: setting the precision parameter in the config to 32 or higher (i.e. avoiding 16-bit precision) seems to alleviate this problem. https://discuss.pytorch.org/t/distributed-training-gives-nan-loss-but-single-gpu-training-is-fine/63664/6
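One plausible reason 16-bit precision produces NaNs is the narrow range of IEEE 754 half precision. Its largest finite value is 65504, so larger intermediate values overflow to inf, and expressions such as inf - inf or 0 * inf then yield NaN. The sketch below is a generic, standard-library illustration with hypothetical helpers `to_fp16` and `fp16_mul`, which round values through Python's `struct` half-precision format. It is not a diagnosis of this repository's code, and under PyTorch autocast many operations still run in float32.

```python
import math
import struct

FP16_MAX = 65504.0  # largest finite IEEE 754 half-precision value


def to_fp16(x: float) -> float:
    """Round a float to IEEE 754 half precision, overflowing to +/-inf.

    struct's "e" format raises OverflowError for finite values that are too
    large for half precision, so that case is mapped to a signed infinity,
    which is what IEEE arithmetic would produce.
    """
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        return math.copysign(math.inf, x)


def fp16_mul(a: float, b: float) -> float:
    """Multiply as if both operands and the result were half precision."""
    return to_fp16(to_fp16(a) * to_fp16(b))


error = 300.0                  # a moderately large residual
sq16 = fp16_mul(error, error)  # 90000 > FP16_MAX, so it overflows to inf
sq32 = error * error           # 90000.0 is representable in float32/float64
print(sq16, sq32)
print(sq16 - sq16)             # inf - inf is undefined, giving nan
```

The first print shows `inf` next to `90000.0`, and the second shows `nan`: a value that is harmless at 32-bit precision becomes inf at 16-bit precision, and one more operation turns it into NaN, which then spreads through the loss and gradients.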