tamarott / SinGAN

Official pytorch implementation of the paper: "SinGAN: Learning a Generative Model from a Single Natural Image"
https://tamarott.github.io/SinGAN.htm

Possible one-line solution for Runtime error (variables modified in-place) #148

Open williantrevizan opened 3 years ago

williantrevizan commented 3 years ago

Hi, thanks for the repository and this amazing work!

I opened this issue because it may provide a solution for the runtime error reported by cdjameson in another thread, which occurs on newer versions of torch ('one of the variables needed for gradient computation has been modified by an inplace operation...'). It seems more straightforward than the solution Clefspear99 is proposing as a pull request.

The problem occurs in the function train_single_scale() in training.py. This function consists essentially of two sequential loops: one that optimizes the discriminator D, and another that optimizes the generator G. At the end of the first loop, a fake image is generated by the generator. As soon as the second loop starts, this fake image is passed through the discriminator, which produces a patch discrimination map that is then used to compute the loss errG. The call errG.backward() computes the gradients used to optimize netG's weights via optimizerG.step().

The first time through this second loop everything runs smoothly and the optimizer changes netG's weights in place. The second time through, however, the same fake image is used to compute the loss, i.e. a fake image that was generated with a previous set of netG weights. When backward() is called, the computational graph therefore points back to netG weights in their original version, from before the optimization step. Newer versions of torch catch this inconsistency, which seems to be why the error occurs.
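A minimal, self-contained reproduction of this pattern (toy two-layer networks standing in for netG and netD; this is not the SinGAN code, just the same structure):

import torch
import torch.nn as nn

netG = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4))
netD = nn.Linear(4, 1)
optimizerG = torch.optim.Adam(netG.parameters(), lr=1e-3)

noise = torch.randn(1, 4)
fake = netG(noise.detach())            # built once, with netG's current weights

for j in range(3):                     # mimics the Gsteps loop
    netG.zero_grad()
    errG = -netD(fake).mean()
    errG.backward(retain_graph=True)   # second pass: the retained graph holds
    optimizerG.step()                  # saved tensors this in-place step changed

# On recent torch this raises at j == 1: 'one of the variables needed for
# gradient computation has been modified by an inplace operation'.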

So, instead of downgrading torch, a simple solution would be to add the line,

fake = netG(noise.detach(), prev.detach())

right at the beginning of the second loop, so the fake image is always recomputed with the current weights.
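Applied to the toy reproduction above, the change looks like this; rebuilding fake on every pass gives a fresh graph that matches the current weights, so backward() no longer complains:

for j in range(3):
    fake = netG(noise.detach())        # recompute with the current netG weights
    netG.zero_grad()
    errG = -netD(fake).mean()
    errG.backward()                    # fresh graph each pass, no stale saved tensors
    optimizerG.step()

(In SinGAN itself the generator also takes the previous-scale image, hence the prev.detach() in the proposed line.)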

tamarott, I think this might solve the problem. If you agree, I will submit a pull request with this modification.

tamarott commented 3 years ago

This is a possible solution, but pay attention: it changes the optimization process and therefore might affect performance. So the results won't necessarily be identical to the original version.

williantrevizan commented 3 years ago

You are right, I'll pay attention to that! I ran a few tests with the application I'm working on, and it seems to be doing fine with this modification, but I haven't stress-tested it very much.

About the optimization process: when I first read your paper and code, it made sense to me that, conceptually, the fake image should be recomputed at every step of that loop (the one optimizing G). What actually seems to happen, however, is that the adversarial loss is kept fixed (because the same fake image is used for all three steps) and only the reconstruction loss is updated inside the loop, as sketched below. Is there a reason why that should work better?
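Continuing the toy example from above, the original loop has roughly the following shape (z_rec, real, alpha, and mse are stand-ins for the repository's Z_opt, real image, alpha, and its nn.MSELoss() instance; note this is exactly the pattern that errors on newer torch):

z_rec = torch.randn(1, 4)              # fixed reconstruction noise
real = torch.randn(1, 4)               # stand-in for the training image
alpha, mse = 10.0, nn.MSELoss()

fake = netG(noise.detach())            # computed once, before the loop
for j in range(3):
    netG.zero_grad()
    errG = -netD(fake).mean()          # adversarial term: frozen fake, all 3 steps
    errG.backward(retain_graph=True)
    rec = alpha * mse(netG(z_rec), real)  # reconstruction term: fresh forward pass
    rec.backward()
    optimizerG.step()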

tamarott commented 3 years ago

We found it to work better empirically, but other solutions might also work. Just be careful and make sure performance stays the same.

williantrevizan commented 3 years ago

Nice, thanks a lot!!

ariel415el commented 2 years ago

Thanks @williantrevizan, your fix worked for me.

JasonBournePark commented 2 years ago

It works well for me too. You saved me time! Thanks a lot!

WZLHQ commented 1 year ago

Thanks, you really saved me time!

jethrolam commented 1 year ago

Thank you @williantrevizan! Confirmed that this solution works on torch==1.12.0.