naver-ai / StyleMapGAN

Official PyTorch implementation of StyleMapGAN (CVPR 2021)
https://www.youtube.com/watch?v=qCapNyRA_Ng

RuntimeError: Trying to backward through the graph a second time #16

Open mingo-x opened 3 years ago

mingo-x commented 3 years ago

Hi there, thanks so much for sharing the code, really amazing work!

When I tried to train the model myself, I ran into the following error at line 383 of train.py, in the call (w_rec_loss * args.lambda_w_rec_loss).backward():

RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

I read through the model definition and the implementation looks correct to me, so I don't understand why this error is thrown. Do you have any idea what could have gone wrong?
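The error itself is not specific to StyleMapGAN: it appears whenever two losses share part of an autograd graph and backward() is called on each in turn without retain_graph=True. A minimal sketch that reproduces it, where x, h and the helper second_backward_fails are hypothetical stand-ins (h plays the role of a generator output shared by both losses):

```python
import torch


def second_backward_fails() -> bool:
    """Hypothetical helper: True if a second backward() through a shared subgraph raises."""
    x = torch.randn(3, requires_grad=True)
    h = x.exp()  # shared intermediate; exp() saves its output for the backward pass
    loss_a = h.sum()
    loss_b = (h * 2).sum()

    loss_a.backward()  # frees the graph's saved tensors, including exp()'s output
    try:
        loss_b.backward()  # needs exp()'s saved output again, which is already freed
    except RuntimeError as err:
        print(err)  # "Trying to backward through the graph a second time ..."
        return True
    return False


if __name__ == "__main__":
    print(second_backward_fails())
```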

In the meantime, to unblock myself, I modified the code slightly to call backward() once on the sum of g_loss and w_rec_loss (g_w_rec_loss in the snippet below). Does this change make sense to you? Why did you separate the two backward passes in the first place?

        adv_loss, w_rec_loss, stylecode = model(None, "G")
        adv_loss = adv_loss.mean()
        w_rec_loss = w_rec_loss.mean()
        g_loss = adv_loss * args.lambda_adv_loss

        g_optim.zero_grad()
        e_optim.zero_grad()
        # Sum both weighted losses so a single backward() traverses the shared graph once.
        g_w_rec_loss = g_loss + w_rec_loss * args.lambda_w_rec_loss
        g_w_rec_loss.backward()
        gather_grad(
            g_module.parameters(), world_size
        )  # Explicitly synchronize Generator parameters. There is a gradient sync bug in G.
        g_optim.step()
        e_optim.step()
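Summing the losses before a single backward() gives the same gradients as two separate backward() calls, because autograd accumulates into .grad additively; the two-call version only works if the first call keeps the graph alive with retain_graph=True. A minimal sketch with a hypothetical helper combined_vs_separate and toy tensors in place of the real losses:

```python
import torch


def combined_vs_separate(lam: float = 0.5):
    """Hypothetical helper: compare one backward on a summed loss with two backward calls."""
    x = torch.tensor([0.1, -0.4, 0.7], requires_grad=True)

    # One backward() on the weighted sum: the shared graph is traversed once.
    h = x.exp()  # shared subgraph (stand-in for the generator output)
    (h.sum() + lam * h.pow(2).sum()).backward()
    combined = x.grad.clone()

    # Two backward() calls: the first must keep the graph alive for the second.
    x.grad = None
    h = x.exp()
    h.sum().backward(retain_graph=True)
    (lam * h.pow(2).sum()).backward()  # accumulates into the existing x.grad
    separate = x.grad.clone()
    return combined, separate
```

Equal gradients do not imply equal updates, though: which optimizer steps consume those gradients still decides which modules the second loss ends up changing.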

Thanks in advance for your help!

blandocs commented 3 years ago

Hi mingo-x, you can combine the two losses (g_loss and w_rec_loss). I don't think it makes a big difference.

However, be aware that in the original version w_rec_loss only affects the encoder, not the generator. With your modification, w_rec_loss also contributes to the generator update.
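To keep the behavior described above, where w_rec_loss updates only the encoder, one option is to tell each backward() call which parameters to write gradients into. A minimal sketch, assuming PyTorch 1.8 or later (where backward() accepts an inputs argument); build, split_grads, the nn.Linear stand-ins for G and E, and the placeholder losses are all hypothetical, not the repository's code:

```python
import torch
import torch.nn as nn


def build(lam: float = 0.5):
    """Hypothetical toy G -> E pipeline with placeholder losses (not StyleMapGAN's)."""
    torch.manual_seed(0)
    G = nn.Linear(4, 4)  # stand-in for the generator
    E = nn.Linear(4, 4)  # stand-in for the encoder
    w = torch.randn(8, 4)
    fake = G(w)  # subgraph shared by both losses
    adv_loss = fake.pow(2).mean()  # placeholder adversarial loss (depends on G only)
    w_rec_loss = (E(fake) - w).pow(2).mean() * lam  # placeholder, already weighted
    return G, E, adv_loss, w_rec_loss


def split_grads():
    """Route adv_loss gradients to G only and w_rec_loss gradients to E only."""
    G, E, adv_loss, w_rec_loss = build()
    # Generator parameters receive only the adversarial gradient.
    adv_loss.backward(retain_graph=True, inputs=list(G.parameters()))
    # Encoder parameters receive only the reconstruction gradient.
    w_rec_loss.backward(inputs=list(E.parameters()))
    return G.weight.grad.clone(), E.weight.grad.clone()
```

The retain_graph=True on the first call is a conservative choice so the second call can never hit freed buffers; it costs memory until both losses have been processed.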

Lastly, if you didn't modify any of the training code, I don't know why the RuntimeError occurs. Please check your PyTorch version.

songquanpeng commented 2 years ago

I also ran into this error. @mingo-x, have you solved it?

songquanpeng commented 2 years ago

PyTorch 1.8 works for me; version 1.10 does not.

vicentowang commented 2 years ago

@songquanpeng Passing retain_graph=True to backward() solved it for me.