nannau / DoWnGAN

PyTorch/MLflow implementation of Wasserstein Generative Adversarial Networks with Gradient Penalty (WGAN-GP) to perform single image super resolution (SISR) to downscale climate fields.
GNU General Public License v3.0

Investigate batch size issue in gradient penalty #17

Open nannau opened 1 year ago

nannau commented 1 year ago

There is a perplexing batch_size issue when computing the gradient penalty with batch sizes 48 and 50 (and, I assume, others).

nannau commented 1 year ago

See the following error message:

Loading region into memory...

Coarse data shape: torch.Size([15704, 7, 16, 16])

Fine data shape: torch.Size([15704, 2, 128, 128])

Network dimensions:

Fine:  128 x 2

Coarse:  16 x 7

Enter the experiment name you wish to add the preceding training run to.

Select number from list or press n for new experiment:

0 : Initial test

1 : Batch size = 64, alpha = 500

Input number here: n

Enter new descriptive experiment name Batch size = 48, alpha = 1

You entered Batch size = 48, alpha = 1. Happy? (Y/n) Y

Describe the specifics and purpose of this training run:

You entered . Happy? (Y/n)

Tracking URI: /home/acannon/data/mlflow_experiments

================================================================================

Traceback (most recent call last):

 File "/home/acannon/DoWnGAN/DoWnGAN/GAN/train.py", line 38, in <module>

   train()

 File "/home/acannon/DoWnGAN/DoWnGAN/GAN/train.py", line 28, in train

   trainer.train(

 File "/home/acannon/DoWnGAN/DoWnGAN/GAN/wasserstein.py", line 189, in train

  self._train_epoch(dataloader, testdataloader, epoch)

 File "/home/acannon/DoWnGAN/DoWnGAN/GAN/wasserstein.py", line 134, in _train_epoch

  self._critic_train_iteration(coarse, fine)

 File "/home/acannon/DoWnGAN/DoWnGAN/GAN/wasserstein.py", line 40, in _critic_train_iteration

  gradient_penalty = hp.gp_lambda * self._gp(fine, fake, self.C)

 File "/home/acannon/DoWnGAN/DoWnGAN/GAN/wasserstein.py", line 110, in _gp

   gradients = gradients.view(hp.batch_size, -1).to(config.device)

RuntimeError: shape '[48, -1]' is invalid for input of size 262144

Offending line is here https://github.com/nannau/DoWnGAN/blob/7629bd401b52605ec6af19d518196cec2ac0a704/DoWnGAN/GAN/wasserstein.py#L110

I think it might have something to do with torch.Tensor.view being given an incompatible target shape. Note that the target shape is fixed by hp.batch_size, while the gradients tensor's actual leading dimension comes from whatever batch _gp receives: https://github.com/nannau/DoWnGAN/blob/7629bd401b52605ec6af19d518196cec2ac0a704/DoWnGAN/GAN/wasserstein.py#L88
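If that mismatch is the cause, one possible fix (a sketch, not the repository's actual code) is to size the view from the tensor itself, e.g. `gradients.view(gradients.size(0), -1)`, so a short final batch still reshapes cleanly. The shape arithmetic that change relies on:

```python
# Hypothetical helper: flatten a (batch, C, H, W)-style shape per sample,
# taking the batch dimension from the shape itself rather than a config value.
def per_sample_flatten_shape(shape):
    """Return (batch, features) for a (batch, C, H, W)-style shape."""
    batch = shape[0]
    features = 1
    for dim in shape[1:]:
        features *= dim
    return (batch, features)

# A short final batch of 8 samples with fine fields of shape (2, 128, 128):
print(per_sample_flatten_shape((8, 2, 128, 128)))  # → (8, 32768)
```

This mirrors what `gradients.view(gradients.size(0), -1)` would do in PyTorch: 8 × 32768 = 262144, exactly the "input of size 262144" in the traceback.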

nannau commented 1 year ago

See https://pytorch.org/docs/stable/generated/torch.Tensor.view.html for more information. My guess is that, for some reason, view() happens to be compatible with some batch sizes but not others. I'm not sure what the exact conditions are, though.
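A likely condition, sketched below under the assumption that the dataloader keeps the final partial batch (i.e. no drop_last=True): with 15704 samples, the last batch is smaller than the configured batch size, and view(batch_size, -1) only succeeds when that last batch's element count happens to be divisible by the configured batch size.

```python
# Assuming 15704 training samples and fine fields of shape (2, 128, 128),
# as printed in the log above.
n_samples = 15704
per_sample = 2 * 128 * 128            # 32768 elements per fine sample

for batch_size in (48, 50, 64):
    last = n_samples % batch_size     # samples left in the final partial batch
    numel = last * per_sample         # elements in that batch's gradients tensor
    ok = numel % batch_size == 0      # can view(batch_size, -1) reshape it?
    print(batch_size, last, numel, ok)
```

For batch size 48 the last batch holds 8 samples (262144 elements, not divisible by 48), reproducing the exact error above; 50 fails the same way. For 64 the last batch of 24 samples gives 786432 elements, which is divisible by 64, so view() succeeds without error, though the resulting rows would silently mix samples. If this arithmetic is right, the fix is to reshape from the tensor's own batch dimension or drop the final partial batch.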