pytorch / examples

A set of examples around pytorch in Vision, Text, Reinforcement Learning, etc.
https://pytorch.org/examples
BSD 3-Clause "New" or "Revised" License

DCGAN C++ warning after PyTorch update #819

Open UltraCoderRU opened 4 years ago

UltraCoderRU commented 4 years ago

After a PyTorch update, I get the following message while running the sample DCGAN C++ application:

[W Resize.cpp:19] Warning: An output with one or more elements was resized since it had shape [64, 1, 1, 1], which does not match the required output shape [64, 1, 1, 64].This behavior is deprecated, and in a future PyTorch release outputs will not be resized unless they have zero elements. You can explicitly reuse an out tensor t by resizing it, inplace, to zero elements with t.resize_(0). (function resize_output)

The message comes from calls to Tensor::backward(). It seems this warning is a result of the changes in 2f840b166. I don't know whether it's a PyTorch issue or a sample-application issue.
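
Looking at the shapes in the message, my best guess (a sketch based only on the warning text, not the actual sample code) is that a [64, 1, 1, 1] discriminator output is being broadcast against a flat [64] label tensor, which expands to the [64, 1, 1, 64] shape the warning names:

    #include <torch/torch.h>
    #include <iostream>

    int main() {
      torch::Tensor output = torch::rand({64, 1, 1, 1}); // discriminator-style output
      torch::Tensor labels = torch::rand({64});          // flat per-sample labels
      // Broadcasting aligns trailing dimensions, so the size-64 label axis
      // lines up with the last size-1 output axis and both expand to 64.
      std::cout << (output - labels).sizes() << '\n';    // prints [64, 1, 1, 64]
    }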

Neathle commented 3 years ago

I seem to be experiencing the exact same error. Is there a way to solve this problem?

liubamboo commented 3 years ago

I get the same warning too.

phoenixmy commented 3 years ago

It seems related to the value of kBatchSize. According to the warning, the expected output shape is [64, 1, 1, 64], and the trailing 64 should stand for the batch size. I changed kBatchSize from 64 to 1 and the warning disappeared, but I'm not familiar with ATen; could someone tell me the root cause?
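
One hedged explanation: if a [64, 1, 1, 1] output really is being broadcast against [64] labels, then a batch size of 1 only masks the mismatch rather than fixing it, because [1, 1, 1, 1] against [1] broadcasts back to [1, 1, 1, 1] and the shapes happen to coincide. A quick check of that assumption:

    #include <torch/torch.h>
    #include <iostream>

    int main() {
      // With batch size 1 the broadcast result matches the output shape,
      // so nothing needs to be resized and the warning stays silent.
      std::cout << (torch::rand({1, 1, 1, 1}) - torch::rand({1})).sizes() << '\n'; // [1, 1, 1, 1]
    }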

a-pushkin commented 3 years ago

Setting the batch size to 1 will work, but it will significantly slow down GPU training. Instead, reshaping the network outputs (which come out with bogus trailing dimensions of size one) allows training with batches larger than one:

    for (torch::data::Example<>& batch : *data_loader) {
      const auto batch_size = batch.data.size(0);

      // Train discriminator with real images.
      discriminator->zero_grad();
      torch::Tensor real_images = batch.data.to(device);
      torch::Tensor real_labels =
          torch::empty(batch_size, device).uniform_(0.8, 1.0);
      // Flatten the [batch_size, 1, 1, 1] output to [batch_size] so it
      // matches the labels exactly and no broadcasting or resizing occurs.
      torch::Tensor real_output =
          discriminator->forward(real_images).reshape({batch_size});
      torch::Tensor d_loss_real =
          torch::binary_cross_entropy(real_output, real_labels);
      d_loss_real.backward();

      // Train discriminator with fake images. detach() keeps the generator
      // out of this backward pass.
      torch::Tensor noise =
          torch::randn({batch_size, kNoiseSize, 1, 1}, device);
      torch::Tensor fake_images = generator->forward(noise);
      torch::Tensor fake_labels = torch::zeros(batch_size, device);
      torch::Tensor fake_output =
          discriminator->forward(fake_images.detach()).reshape({batch_size});
      torch::Tensor d_loss_fake =
          torch::binary_cross_entropy(fake_output, fake_labels);
      d_loss_fake.backward();

      torch::Tensor d_loss = d_loss_real + d_loss_fake;
      discriminator_optimizer.step();

      // Train generator: same fake images, but now labeled as real.
      generator->zero_grad();
      fake_labels.fill_(1);
      fake_output = discriminator->forward(fake_images).reshape({batch_size});
      torch::Tensor g_loss =
          torch::binary_cross_entropy(fake_output, fake_labels);
      g_loss.backward();
      generator_optimizer.step();

      batch_index++;
      if (batch_index % kLogInterval == 0) {
        std::printf("\r[%2ld/%2ld][%3ld/%3ld] D_loss: %.4f | G_loss: %.4f\n",
                    epoch, kNumberOfEpochs, batch_index, batches_per_epoch,
                    d_loss.item<float>(), g_loss.item<float>());
      }
    }

Call .reshape({batch_size}) on the results of the forward calls.

fulltopic commented 3 years ago

@a-pushkin Reshape works, thank you.

Using reshape(fake_labels.sizes()) would be better, in case batch.data.size(0) != batch_size (e.g., when a fixed kBatchSize is used and the final batch comes out smaller).
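
A sketch of that variant (reusing the variable names from the loop above; not a tested patch):

    // Reshape to the label tensor's own shape so the two always agree,
    // even if a final batch is smaller than the nominal batch size.
    torch::Tensor real_output =
        discriminator->forward(real_images).reshape(real_labels.sizes());
    torch::Tensor fake_output =
        discriminator->forward(fake_images.detach()).reshape(fake_labels.sizes());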

msaroufim commented 2 years ago

@a-pushkin would you like to make a PR with your fix?