samb-t / x2ct-vqvae


Training X-Ray VQ-VAE crashes #1

Closed: max-3l closed this issue 10 months ago

max-3l commented 10 months ago

Hi,

Thank you for your great research and for releasing the source code!

While trying to reproduce the results, I ran into problems training the X-Ray VQ-VAE with the configuration default_xray_vqgan_config.py. I pre-processed the LIDC-IDRI dataset following the CT pre-processing of the x2ct-gan repository (without removing the bed, as the masks are not publicly available) and generated the 2D projections following the instructions from the MedNeRF repository, as described in your README.

I start training with the command python3 train_xray_vqgan.py --config configs/default_xray_vqgan_config.py. Everything works fine until step 30001, where the discriminator starts training. At that step I immediately get the following error:

/home/user/miniconda3/envs/x2ct/lib/python3.9/site-packages/torch/autograd/__init__.py:197: UserWarning: Error detected in ConvolutionBackward0. Traceback of forward call that caused the error:
  File "/home/user/x2ct-vqvae/./train_xray_vqgan.py", line 229, in <module>
    app.run(main)
  File "/home/user/miniconda3/envs/x2ct/lib/python3.9/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/home/user/miniconda3/envs/x2ct/lib/python3.9/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/home/user/x2ct-vqvae/./train_xray_vqgan.py", line 226, in main
    train(H, vqgan, vqgan_ema, train_loader, test_loader, optim, d_optim, start_step, scaler=scaler, d_scaler=d_scaler, **train_kwargs)
  File "/home/user/x2ct-vqvae/./train_xray_vqgan.py", line 69, in train
    x_hat, stats = vqgan.train_iter_together(x, global_step, scaler=scaler)
  File "/home/user/x2ct-vqvae/models/vqgan_2d.py", line 73, in train_iter_together
    logits_fake = self.disc(x_hat)
  File "/home/user/miniconda3/envs/x2ct/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1212, in _call_impl
    result = forward_call(*input, **kwargs)
  File "/home/user/x2ct-vqvae/models/ada_2d.py", line 78, in forward
    pred: torch.Tensor = self.discriminator(images)
  File "/home/user/miniconda3/envs/x2ct/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/user/x2ct-vqvae/models/vqgan_2d.py", line 404, in forward
    return self.main(x)
  File "/home/user/miniconda3/envs/x2ct/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/user/miniconda3/envs/x2ct/lib/python3.9/site-packages/torch/nn/modules/container.py", line 204, in forward
    input = module(input)
  File "/home/user/miniconda3/envs/x2ct/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/user/miniconda3/envs/x2ct/lib/python3.9/site-packages/torch/nn/modules/conv.py", line 463, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "/home/user/miniconda3/envs/x2ct/lib/python3.9/site-packages/torch/nn/modules/conv.py", line 459, in _conv_forward
    return F.conv2d(input, weight, bias, self.stride,
  File "/home/user/miniconda3/envs/x2ct/lib/python3.9/site-packages/torch/fx/traceback.py", line 57, in format_stack
    return traceback.format_stack()
 (Triggered internally at ../torch/csrc/autograd/python_anomaly_mode.cpp:114.)
  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
Traceback (most recent call last):
  File "/home/user/x2ct-vqvae/./train_xray_vqgan.py", line 229, in <module>
    app.run(main)
  File "/home/user/miniconda3/envs/x2ct/lib/python3.9/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/home/user/miniconda3/envs/x2ct/lib/python3.9/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/home/user/x2ct-vqvae/./train_xray_vqgan.py", line 226, in main
    train(H, vqgan, vqgan_ema, train_loader, test_loader, optim, d_optim, start_step, scaler=scaler, d_scaler=d_scaler, **train_kwargs)
  File "/home/user/x2ct-vqvae/./train_xray_vqgan.py", line 73, in train
    update_model_weights(optim, stats['loss'], amp=H.train.amp, scaler=scaler)
  File "/home/user/x2ct-vqvae/./train_xray_vqgan.py", line 39, in update_model_weights
    scaler.scale(loss).backward(retain_graph=retain_graph)
  File "/home/user/miniconda3/envs/x2ct/lib/python3.9/site-packages/torch/_tensor.py", line 488, in backward
    torch.autograd.backward(
  File "/home/user/miniconda3/envs/x2ct/lib/python3.9/site-packages/torch/autograd/__init__.py", line 197, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: Function 'ConvolutionBackward0' returned nan values in its 0th output.
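The crash coincides with the step at which the discriminator is switched on. A common way for VQGAN-style training code to do this is to gate the adversarial term on the global step. The sketch below is hypothetical: the function generator_loss and the parameters disc_start_step and disc_weight are illustrative names, not taken from this repository, and only the value 30001 comes from the report. Losses are modelled as plain floats.

```python
from typing import Callable


def generator_loss(
    recon_loss: float,
    adversarial_loss: Callable[[], float],
    global_step: int,
    disc_start_step: int = 30001,  # hypothetical name; value taken from the report
    disc_weight: float = 1.0,      # hypothetical weight of the GAN term
) -> float:
    """Step-gated generator objective for a VQGAN-style autoencoder (sketch)."""
    if global_step < disc_start_step:
        # Warm-up phase: reconstruction only. The discriminator is never
        # evaluated, so it contributes nothing to the backward pass.
        return recon_loss
    # From disc_start_step on, the generator also receives gradients through
    # the discriminator's output (cf. logits_fake = self.disc(x_hat) above).
    return recon_loss + disc_weight * adversarial_loss()


if __name__ == "__main__":
    print(generator_loss(1.0, lambda: 0.5, global_step=100))    # 1.0
    print(generator_loss(1.0, lambda: 0.5, global_step=30001))  # 1.5
```

With gating of this kind, any numerical instability introduced only by the adversarial path would stay hidden until the threshold step, which is consistent with the error appearing exactly at step 30001.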

I set up the environment with Python 3.9.18 and PyTorch 1.13.1+cu117 on an NVIDIA A100.

The error persists when I retry training with a different random seed. As far as I understand, the gradients overflow during the backward pass because of the AMP gradient scaling (scaler.scale(loss).backward in update_model_weights). Did you use default_xray_vqgan_config.py as provided to produce your published results?
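A minimal pure-Python sketch of dynamic loss scaling, assuming the documented defaults of torch.cuda.amp.GradScaler (init_scale=2**16, growth_factor=2.0, backoff_factor=0.5, growth_interval=2000). The class LossScaler and the helper to_fp16 are hypothetical. Gradients are modelled as plain floats rounded to float16 with the stdlib struct "e" format. The sketch shows how a gradient that is harmless in fp32 can overflow to inf once it is multiplied by the scale. The real GradScaler is designed to survive such overflows by skipping the optimizer step and lowering the scale. However, the python_anomaly_mode.cpp line in the traceback suggests that autograd anomaly detection was enabled, and anomaly detection raises as soon as a backward function returns NaN, before the scaler can skip the step.

```python
import math
import struct

FP16_MAX = 65504.0  # largest finite float16 value


def to_fp16(x: float) -> float:
    """Round a Python float to float16, returning +/-inf on overflow."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        return math.copysign(math.inf, x)


class LossScaler:
    """Sketch of dynamic loss scaling that mirrors GradScaler's documented
    defaults (an illustration, not PyTorch's implementation)."""

    def __init__(self, init_scale=2.0**16, growth_factor=2.0,
                 backoff_factor=0.5, growth_interval=2000):
        self.scale = init_scale
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval
        self._good_steps = 0

    def step(self, grads):
        """Scale grads into fp16 and decide whether the update is applied.

        Returns the unscaled gradients, or None if the step is skipped."""
        scaled = [to_fp16(g * self.scale) for g in grads]
        if not all(math.isfinite(g) for g in scaled):
            # Overflow (inf/NaN): skip this update and shrink the scale.
            self.scale *= self.backoff_factor
            self._good_steps = 0
            return None
        # Unscale with the scale that was actually used for this step.
        unscaled = [g / self.scale for g in scaled]
        self._good_steps += 1
        if self._good_steps == self.growth_interval:
            # A long run of finite steps: try a larger scale.
            self.scale *= self.growth_factor
            self._good_steps = 0
        return unscaled


if __name__ == "__main__":
    scaler = LossScaler()
    print(scaler.step([1e-3]))  # 1e-3 * 65536 fits in fp16: step applied
    print(scaler.step([10.0]))  # 10 * 65536 = 655360 > FP16_MAX: None (skipped)
    print(scaler.scale)         # 32768.0 after the backoff
```

Once a value has overflowed to inf, later arithmetic such as inf - inf produces NaN, which is the condition that anomaly detection reports as "returned nan values".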

abrilcf commented 10 months ago

Hi! Thank you for your interest in our work. Apologies, I forgot that some changes had been made to the config files. I have pushed them now. I'll upload the weights of our models shortly too ;)