pytorch / examples

A set of examples around pytorch in Vision, Text, Reinforcement Learning, etc.
https://pytorch.org/examples
BSD 3-Clause "New" or "Revised" License

Potential speedup for DCGAN #170

Closed harveyslash closed 7 years ago

harveyslash commented 7 years ago

In the DCGAN example, why is backward called twice while training the discriminator? First it's called on the real images, then on the fake images. Instead, wouldn't something like totalError = real_loss + fake_loss, followed by a single totalError.backward(), save one whole backprop? Does doing it the way I suggested change anything qualitatively?

soumith commented 7 years ago

No, this will still call backward of the two separate graphs separately.

harveyslash commented 7 years ago

Aren't the real loss and the fake loss both losses of the discriminator? How does the second graph (the generator) come into play here?

Kaixhin commented 7 years ago

I quickly drew up the computation graph to show why this is already as efficient as possible. First, D takes a real image and calculates one forward/backward pass on that (D is trained to classify this as real). Then, D takes a fake image (which is "detached" from G to prevent gradients going back into G) and calculates one forward/backward pass on that (D is trained to classify this as fake). Then we update D using these gradients. So whether or not you add these first two losses, they come from completely different forward passes, so as Soumith said, there are completely different backward passes to be made and there is no computation to be saved.

Finally, we take the fake image (still attached to G) and calculate one forward/backward pass on that (this time G is trained so that D classifies it as real). Only G is updated with these gradients.

[Image: "gan", the computation graph described above]
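The argument in this comment can be checked numerically without any framework. The sketch below is a toy, not the DCGAN example's code: it assumes a hypothetical one-layer logistic "discriminator" D(x) = sigmoid(w . x), hand-derived binary cross-entropy gradients standing in for autograd, and made-up real and fake batches. It compares the gradients accumulated from two separate backward computations with a finite-difference gradient of the summed loss.

```python
import math

def sigmoid(t):
    return 1.0 / (1.0 + math.exp(-t))

def predict(w, x):
    # Forward pass of the toy discriminator (hypothetical stand-in for D)
    return sigmoid(sum(wi * xi for wi, xi in zip(w, x)))

def bce_loss(w, batch, label):
    # Mean binary cross-entropy of the toy discriminator on one batch
    total = 0.0
    for x in batch:
        p = predict(w, x)
        total -= label * math.log(p) + (1.0 - label) * math.log(1.0 - p)
    return total / len(batch)

def bce_grad(w, batch, label):
    # Hand-derived gradient of bce_loss with respect to w:
    # d(loss)/dw = mean over the batch of (p - label) * x
    grad = [0.0] * len(w)
    for x in batch:
        p = predict(w, x)                  # forward pass for this sample
        for i, xi in enumerate(x):         # backward pass for this sample
            grad[i] += (p - label) * xi / len(batch)
    return grad

w = [0.5, -0.3]                            # made-up weights
real = [[1.0, 2.0], [0.5, 1.5]]            # made-up "real" batch
fake = [[-1.0, 0.2], [0.3, -0.7]]          # made-up "fake" batch

# Variant A: two backward computations, gradients accumulated
g_real = bce_grad(w, real, 1.0)
g_fake = bce_grad(w, fake, 0.0)
accumulated = [a + b for a, b in zip(g_real, g_fake)]

# Variant B: gradient of the summed loss, via central finite differences
def total_loss(weights):
    return bce_loss(weights, real, 1.0) + bce_loss(weights, fake, 0.0)

h = 1e-6
summed = []
for i in range(len(w)):
    up = list(w)
    up[i] += h
    down = list(w)
    down[i] -= h
    summed.append((total_loss(up) - total_loss(down)) / (2 * h))

print(accumulated)
print(summed)
```

Both variants agree up to finite-difference error, and both have to push every real and every fake sample through the discriminator. Summing the losses only changes where the addition happens, not how much work the backward passes do.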

baldassarreFe commented 7 years ago

Hi @soumith, thanks for the dcgan example!

I have a couple of questions regarding line 239:

errD = errD_real + errD_fake

Kaixhin commented 7 years ago

> Why do we sum the losses instead of taking the mean?

It is possible to do what you suggested, but summing more accurately matches the discriminator update step in the literature (see Algorithm 1 in the original paper). The discriminator is trained to recognise real data as real and fake data as fake.
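For reference, the discriminator step of Algorithm 1 in the GAN paper (Goodfellow et al., 2014) can be written as below. The mapping onto errD_real and errD_fake is a sketch that assumes the example's criterion is binary cross-entropy averaged over a minibatch of size m, with label 1 for real and 0 for fake.

```latex
% Discriminator step: ascend the stochastic gradient over a minibatch of size m
\nabla_{\theta_d} \frac{1}{m} \sum_{i=1}^{m}
  \left[ \log D\!\left(x^{(i)}\right)
       + \log\!\left(1 - D\!\left(G\!\left(z^{(i)}\right)\right)\right) \right]

% With mean-reduced binary cross-entropy (an assumption about the criterion):
\mathrm{errD\_real} = -\frac{1}{m} \sum_{i=1}^{m} \log D\!\left(x^{(i)}\right),
\qquad
\mathrm{errD\_fake} = -\frac{1}{m} \sum_{i=1}^{m}
  \log\!\left(1 - D\!\left(G\!\left(z^{(i)}\right)\right)\right)

% Hence errD_real + errD_fake is the negative of the bracketed objective,
% and minimising the sum is the same as ascending the gradient above.
```

Under these assumptions the plain sum already carries the 1/m factor from the algorithm, whereas taking the mean of the two losses would introduce an extra factor of 1/2.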

> Why do we operate on the Variables

Extracting .data and creating a new Variable breaks the computation graph, so you wouldn't be able to backpropagate through the networks. The general advice is not to over-optimise for memory. You may want to read this for a reasonable introduction to the graph in PyTorch.

baldassarreFe commented 7 years ago

> The discriminator is trained to recognize real data as real and fake data as fake.

Yes, I see that. In fact, the discriminator is trained using the gradient from two losses:

> It is possible to do what you suggested, but summing more accurately matches the discriminator update step in the literature

I went back to check the paper and I see what you mean. The code reflects what the authors do:

1 / #batch * (loss_on_#batch_real_images + loss_on_#batch_fakes)

I'm actually surprised, because I'd have intuitively done this:

1 / (2 * #batch) * (loss_on_#batch_real_images + loss_on_#batch_fakes)

Do you happen to have an explanation for their choice, or some resource to point me to?

Anyway, the result is basically the same: the only difference is that the computed gradient is doubled, and with an adaptive learning rate this does not affect training.
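The claim about adaptive learning rates can be illustrated with a small stand-alone sketch. It assumes an Adam-style update written out by hand (not a library call), a made-up one-parameter quadratic loss, and illustrative hyperparameters. The same loss is optimised once as-is and once doubled, first with the Adam-style rule and then with plain SGD for contrast.

```python
import math

def adam_run(scale, steps=50, lr=0.0002, beta1=0.5, beta2=0.999, eps=1e-8):
    # Hand-written Adam-style update on the toy loss scale * (theta - 3)^2
    theta, m, v = 1.0, 0.0, 0.0
    for t in range(1, steps + 1):
        g = scale * 2.0 * (theta - 3.0)        # gradient of the scaled loss
        m = beta1 * m + (1 - beta1) * g        # first-moment estimate
        v = beta2 * v + (1 - beta2) * g * g    # second-moment estimate
        m_hat = m / (1 - beta1 ** t)           # bias correction
        v_hat = v / (1 - beta2 ** t)
        theta -= lr * m_hat / (math.sqrt(v_hat) + eps)
    return theta

def sgd_run(scale, steps=50, lr=0.0002):
    # Plain gradient descent on the same toy loss, for contrast
    theta = 1.0
    for _ in range(steps):
        theta -= lr * scale * 2.0 * (theta - 3.0)
    return theta

adam_1, adam_2 = adam_run(1.0), adam_run(2.0)
sgd_1, sgd_2 = sgd_run(1.0), sgd_run(2.0)

print("Adam:", adam_1, adam_2)   # nearly identical trajectories
print("SGD: ", sgd_1, sgd_2)     # doubled gradient means doubled steps
```

In the Adam-style rule the gradient scale cancels between the first-moment estimate and the square root of the second-moment estimate, so only the small eps term sees the factor of 2. With plain SGD the factor acts like a doubled learning rate.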

> Extracting .data and creating a new Variable breaks the computation graph, so you wouldn't be able to backpropagate through the networks

Yes, I understand how the graph works. However, at line 239 the two losses have already been computed, and the gradients from their backward passes have already been accumulated in the weights' gradients. At that point we don't need the sum of the losses to be part of the graph: we never call errD.backward(), and we are only interested in its value. Anyway, I'm positive that creating a new Variable does not add much overhead compared to a Tensor, even if I don't know the implementation details.

Kaixhin commented 7 years ago

Exactly, it doesn't really make a difference, so at this point you'd probably have to ask the original authors.

Sorry, I just assumed it was kept as a Variable for a reason like backpropagation, but in this case, yes, you could just work with tensors instead of Variables.

apaszke commented 7 years ago

Still, Variable overhead is fairly small and we're aggressively reducing it. Removing one addition from Variable-land will have absolutely no effect on the run time.