Hi,
thank you for your great research and for releasing the source code!
When trying to reproduce the results, I ran into problems training the X-Ray VQ-VAE with the configuration default_xray_vqgan_config.py. I pre-processed the LIDC-IDRI dataset following the CT pre-processing of the x2ct-gan repository (without removing the bed, as the masks are not publicly available) and generated the 2D projections using the instructions from the MedNeRF repository, as stated in your README.
After starting the training with the command python3 train_xray_vqgan.py --config configs/default_xray_vqgan_config.py, everything works fine until step 30001, where the discriminator starts to train. I immediately get the following error:
/home/user/miniconda3/envs/x2ct/lib/python3.9/site-packages/torch/autograd/__init__.py:197: UserWarning: Error detected in ConvolutionBackward0. Traceback of forward call that caused the error:
  File "/home/user/x2ct-vqvae/./train_xray_vqgan.py", line 229, in <module>
    app.run(main)
  File "/home/user/miniconda3/envs/x2ct/lib/python3.9/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/home/user/miniconda3/envs/x2ct/lib/python3.9/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/home/user/x2ct-vqvae/./train_xray_vqgan.py", line 226, in main
    train(H, vqgan, vqgan_ema, train_loader, test_loader, optim, d_optim, start_step, scaler=scaler, d_scaler=d_scaler, **train_kwargs)
  File "/home/user/x2ct-vqvae/./train_xray_vqgan.py", line 69, in train
    x_hat, stats = vqgan.train_iter_together(x, global_step, scaler=scaler)
  File "/home/user/x2ct-vqvae/models/vqgan_2d.py", line 73, in train_iter_together
    logits_fake = self.disc(x_hat)
  File "/home/user/miniconda3/envs/x2ct/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1212, in _call_impl
    result = forward_call(*input, **kwargs)
  File "/home/user/x2ct-vqvae/models/ada_2d.py", line 78, in forward
    pred: torch.Tensor = self.discriminator(images)
  File "/home/user/miniconda3/envs/x2ct/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/user/x2ct-vqvae/models/vqgan_2d.py", line 404, in forward
    return self.main(x)
  File "/home/user/miniconda3/envs/x2ct/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/user/miniconda3/envs/x2ct/lib/python3.9/site-packages/torch/nn/modules/container.py", line 204, in forward
    input = module(input)
  File "/home/user/miniconda3/envs/x2ct/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/user/miniconda3/envs/x2ct/lib/python3.9/site-packages/torch/nn/modules/conv.py", line 463, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "/home/user/miniconda3/envs/x2ct/lib/python3.9/site-packages/torch/nn/modules/conv.py", line 459, in _conv_forward
    return F.conv2d(input, weight, bias, self.stride,
  File "/home/user/miniconda3/envs/x2ct/lib/python3.9/site-packages/torch/fx/traceback.py", line 57, in format_stack
    return traceback.format_stack()
 (Triggered internally at ../torch/csrc/autograd/python_anomaly_mode.cpp:114.)
  Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
Traceback (most recent call last):
  File "/home/user/x2ct-vqvae/./train_xray_vqgan.py", line 229, in <module>
    app.run(main)
  File "/home/user/miniconda3/envs/x2ct/lib/python3.9/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/home/user/miniconda3/envs/x2ct/lib/python3.9/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/home/user/x2ct-vqvae/./train_xray_vqgan.py", line 226, in main
    train(H, vqgan, vqgan_ema, train_loader, test_loader, optim, d_optim, start_step, scaler=scaler, d_scaler=d_scaler, **train_kwargs)
  File "/home/user/x2ct-vqvae/./train_xray_vqgan.py", line 73, in train
    update_model_weights(optim, stats['loss'], amp=H.train.amp, scaler=scaler)
  File "/home/user/x2ct-vqvae/./train_xray_vqgan.py", line 39, in update_model_weights
    scaler.scale(loss).backward(retain_graph=retain_graph)
  File "/home/user/miniconda3/envs/x2ct/lib/python3.9/site-packages/torch/_tensor.py", line 488, in backward
    torch.autograd.backward(
  File "/home/user/miniconda3/envs/x2ct/lib/python3.9/site-packages/torch/autograd/__init__.py", line 197, in backward
    Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
RuntimeError: Function 'ConvolutionBackward0' returned nan values in its 0th output.
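For context, the python_anomaly_mode.cpp reference in the warning indicates that autograd anomaly detection is active in this run. In that mode, backward raises as soon as a backward function returns NaN instead of leaving the NaN in the gradients. This may matter with mixed precision, because GradScaler is designed to skip optimizer steps whose gradients contain inf or NaN, and it never gets that chance if backward raises first. A minimal sketch of the behaviour on a toy CPU tensor, unrelated to the repository's code (the function name nan_gradient is made up for illustration):

```python
import math
import warnings

import torch


def nan_gradient(detect_anomaly: bool) -> float:
    """Backpropagate through sqrt(x) * 0 at x = 0, whose gradient is 0 / 0 = NaN."""
    x = torch.tensor([0.0], requires_grad=True)
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")  # anomaly mode emits warnings of its own
        with torch.autograd.set_detect_anomaly(detect_anomaly):
            (torch.sqrt(x) * 0.0).sum().backward()
    return x.grad.item()


# Without anomaly detection, the NaN silently ends up in x.grad.
print(math.isnan(nan_gradient(detect_anomaly=False)))

# With anomaly detection, the same backward pass raises a RuntimeError.
try:
    nan_gradient(detect_anomaly=True)
except RuntimeError as err:
    print("raised:", str(err).splitlines()[0])
```

Whether the NaN in the actual run is a one-off overflow that the scaler would have skipped, or a persistent instability once the discriminator is active, cannot be told from the log alone.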
I set up the environment with Python 3.9.18 and PyTorch 1.13.1+cu117 on an NVIDIA A100.
The error persists when I retry training with a different random state. From what I understand, the gradient computation overflows because of the gradient scaling. Did you use default_xray_vqgan_config.py exactly as given to produce your published results?
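To illustrate the kind of overflow suspected here: float16 can only represent finite values up to 65504, so a value multiplied by a large loss scale can become inf in half precision, and later arithmetic such as inf - inf or 0 * inf turns it into NaN. A minimal sketch, assuming a scale of 2**16 (the default initial scale of torch.cuda.amp.GradScaler; the scale actually in effect at step 30001 is not known from the log) and a hypothetical helper name:

```python
import torch


def overflows_in_fp16(value: float, scale: float) -> bool:
    """Return True if value * scale is no longer finite once cast to float16."""
    scaled = torch.tensor(value * scale, dtype=torch.float32)
    return not bool(torch.isfinite(scaled.to(torch.float16)))


# Largest finite float16 value.
print(torch.finfo(torch.float16).max)

# A unit-sized value at the assumed default initial scale, then at a smaller one.
print(overflows_in_fp16(1.0, 2.0 ** 16))
print(overflows_in_fp16(1.0, 2.0 ** 10))
```

This only shows the arithmetic; it does not establish that the scaler is the cause of the NaN in the discriminator's convolution.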
Hi! Thank you for your interest in our work. Apologies, I forgot that there were some changes to the config files. I have pushed the changes. I'll upload the weights of our models shortly too ;)