MStypulkowski / BiGAN

PyTorch implementation of BiGAN and Conditional BiGAN

About a RuntimeError #1

Open YiningWang2 opened 3 years ago

YiningWang2 commented 3 years ago

Hello, when I ran the code, I got this error at loss_EG.backward(). I'd really like to solve it; can you help me? Best wishes, Cici.

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [1024, 1]], which is output 0 of TBackward, is at version 6; expected version 5 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).
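A minimal sketch that reproduces the same kind of error. It assumes the failing tensor is the transposed weight of D's final linear layer: a [1024, 1] tensor that is "output 0 of TBackward" fits an nn.Linear(1024, 1), but the thread does not confirm this. The toy D, the input x and the placeholder losses are hypothetical stand-ins. What matters is the ordering: optimizer_D.step() updates D's weights in place before loss_EG.backward() has used them. The run is wrapped in torch.autograd.detect_anomaly(), as the error's hint suggests, so PyTorch also prints a traceback of the forward operation involved.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in for D's last layer; nn.Linear(1024, 1) would give the
# [1024, 1] transposed weight seen in the thread, 4 -> 1 keeps this small.
D = nn.Linear(4, 1)
optimizer_D = torch.optim.SGD(D.parameters(), lr=0.1)


def original_ordering():
    x = torch.randn(8, 4, requires_grad=True)  # stands in for G(z) or E(X)
    out = D(x)                  # saves D.weight.t() for the gradient w.r.t. x
    loss_D = out.mean()         # placeholder for D_loss(DG, DE)
    loss_EG = -out.mean()       # placeholder for EG_loss(DG, DE)
    optimizer_D.zero_grad()
    loss_D.backward(retain_graph=True)
    optimizer_D.step()          # updates D.weight in place
    loss_EG.backward()          # needs the pre-step D.weight -> RuntimeError


message = ""
try:
    # Anomaly mode records forward tracebacks (slow; for debugging only).
    with torch.autograd.detect_anomaly():
        original_ordering()
except RuntimeError as err:
    message = str(err)
print(message.splitlines()[0] if message else "no error")
```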

DAYceng commented 2 years ago

I have the same problem. Have you solved it?

Shockblack commented 2 years ago

Same issue here. If optimizer_D.step() is moved to between loss_EG.backward() and optimizer_EG.step(), the code runs. However, this results in mode collapse, converging to black images after around 60 epochs.
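A minimal sketch of that workaround ordering, using hypothetical toy stand-ins (a linear D and G, and placeholder losses that are exact negatives of each other). It shows one possible side effect: because loss_EG also depends on D's parameters, loss_EG.backward() adds its gradients to D's .grad before optimizer_D.step() runs, so the discriminator is updated with the sum of both losses' gradients. In this toy case they cancel exactly. With the repository's real D_loss and EG_loss the cancellation would at most be partial, and whether this explains the collapse to black images is an assumption, not something the thread confirms.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy networks standing in for the repository's D and G.
D = nn.Linear(4, 1)
G = nn.Linear(2, 4)
opt_D = torch.optim.SGD(D.parameters(), lr=0.1)
opt_EG = torch.optim.SGD(G.parameters(), lr=0.1)

z = torch.randn(8, 2)
out = D(G(z))
loss_D = out.mean()    # placeholder for D_loss; assumed, not the repo's loss
loss_EG = -out.mean()  # placeholder for EG_loss; exact negative by construction

opt_D.zero_grad()
loss_D.backward(retain_graph=True)
grad_from_loss_D = D.weight.grad.clone()

# Workaround order: the EG backward happens before any optimizer step,
# so no saved tensor has been modified yet and no error is raised.
opt_EG.zero_grad()
loss_EG.backward()
grad_seen_by_opt_D = D.weight.grad.clone()  # now includes loss_EG's gradient

opt_D.step()   # uses the summed gradient (all zeros in this toy case)
opt_EG.step()
```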

Shockblack commented 2 years ago

I've found a way to fix this issue, adapted from an issue raised by @NoTargetFish. Basically, the gradients must be zeroed and the forward pass recomputed before each update: first for the discriminator, then for the generator and encoder, so that loss_EG is computed with the discriminator's already-updated weights. The section of the code under Training (using Conditional_BiGAN.ipynb as an example) can be edited as follows:

ORIGINAL

    #compute G(z, c) and E(X)
    Gz = G(z, c)
    EX = E(images)

    #compute D(G(z, c), z, c) and D(X, E(X), c)
    DG = D(Gz, z, c)
    DE = D(images, EX, c)

    #compute losses
    loss_D = D_loss(DG, DE)
    loss_EG = EG_loss(DG, DE)
    D_loss_acc += loss_D.item()
    EG_loss_acc += loss_EG.item()

    #Discriminator training
    optimizer_D.zero_grad()
    loss_D.backward(retain_graph=True)
    optimizer_D.step()

    #Encoder & Generator training
    optimizer_EG.zero_grad()
    loss_EG.backward()
    optimizer_EG.step()
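The "is at version 6; expected version 5" part of the error refers to a per-tensor version counter: autograd records a tensor's version when it saves it during the forward pass, and every in-place update increments it. In the ORIGINAL ordering, optimizer_D.step() runs after the forward pass that saved D's weights but before loss_EG.backward() reads them. The sketch below, with a hypothetical toy D, makes the counter visible through the private `_version` attribute (an internal detail that may change between PyTorch releases).

```python
import torch
import torch.nn as nn

# Hypothetical toy discriminator; only the version bookkeeping matters here.
D = nn.Linear(4, 1)
optimizer_D = torch.optim.SGD(D.parameters(), lr=0.1)

out = D(torch.randn(2, 4, requires_grad=True))
version_when_saved = D.weight._version   # what autograd expects at backward time
out.sum().backward(retain_graph=True)    # graph kept, as in the ORIGINAL code
optimizer_D.step()                       # in-place update bumps the counter
version_after_step = D.weight._version   # a second backward here would now fail
print(version_when_saved, version_after_step)
```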

MODIFIED

    # Start with Discriminator Training
    optimizer_D.zero_grad()

    #compute G(z, c) and E(X)
    Gz = G(z, c)
    EX = E(images)

    #compute D(G(z, c), z, c) and D(X, E(X), c)
    DG = D(Gz, z, c)
    DE = D(images, EX, c)

    #compute losses
    loss_D = D_loss(DG, DE)
    D_loss_acc += loss_D.item()

    loss_D.backward()  # retain_graph=True is no longer needed: the graph is rebuilt below
    optimizer_D.step()

    #Encoder & Generator training
    optimizer_EG.zero_grad()

    #compute G(z, c) and E(X)
    Gz = G(z, c)
    EX = E(images)

    #compute D(G(z, c), z, c) and D(X, E(X), c)
    DG = D(Gz, z, c)
    DE = D(images, EX, c)

    #compute losses
    loss_EG = EG_loss(DG, DE)
    EG_loss_acc += loss_EG.item()

    loss_EG.backward()
    optimizer_EG.step()

This will hopefully solve the issue. If you still get an error, replacing in-place operators such as EG_loss_acc += loss_EG.item() with EG_loss_acc = EG_loss_acc + loss_EG.item() might help, but I'm not sure this is necessary; the error message just makes it seem that way.
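A quick check suggests the `+=` on the loss accumulator is not the culprit: `loss_EG.item()` returns a plain Python float, so that `+=` never touches a tensor autograd is tracking. The error only arises when an in-place operation modifies a tensor that autograd saved for the backward pass, as in the contrast case below. The tensor `w` and the toy losses are hypothetical.

```python
import torch

w = torch.randn(3, requires_grad=True)

# Accumulating a logged loss: .item() returns a Python float, so this +=
# is ordinary float arithmetic and cannot affect the autograd graph.
loss = (w ** 2).sum()
EG_loss_acc = 0.0
EG_loss_acc += loss.item()
loss.backward()                      # runs fine
grad_after_loss = w.grad.clone()

# Contrast: += on a tensor that autograd saved is a real in-place op.
y = w.exp()                          # exp saves its output for the backward pass
y += 1                               # in-place add on that saved output
error = ""
try:
    y.sum().backward()
except RuntimeError as err:
    error = str(err)
print(type(EG_loss_acc).__name__, "inplace" in error)
```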

I hope this helps anyone else who wants to use this repository!

NOTE: I am still new to PyTorch and ML, so this might not be the most efficient, or even a correct, way to solve the issue. It does, however, make the code runnable, and it produces realistic results.
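On the efficiency point: in the MODIFIED loop, loss_D.backward() still propagates into G and E, and those gradients are simply discarded by optimizer_EG.zero_grad() afterwards. A common GAN idiom, swapped in here rather than taken from the repository, is to detach G(z) and E(X) for the discriminator step. The sketch below is a self-contained toy version. The names G, E, D, D_loss, EG_loss, optimizer_D and optimizer_EG mirror the thread. The layer sizes, the concatenation-based D, the BCE form of both losses, Adam, and the omission of the condition c are all assumptions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical toy BiGAN without the condition c: x has 4 features, z has 2.
G = nn.Linear(2, 4)  # generator: z -> x
E = nn.Linear(4, 2)  # encoder:   x -> z
# Assumed discriminator on the concatenated pair (x, z).
D = nn.Sequential(nn.Linear(6, 8), nn.ReLU(), nn.Linear(8, 1), nn.Sigmoid())
optimizer_D = torch.optim.Adam(D.parameters(), lr=1e-3)
optimizer_EG = torch.optim.Adam(list(G.parameters()) + list(E.parameters()), lr=1e-3)


def D_loss(DG, DE):
    # Assumed BCE form: D should score (X, E(X)) as 1 and (G(z), z) as 0.
    return (F.binary_cross_entropy(DE, torch.ones_like(DE))
            + F.binary_cross_entropy(DG, torch.zeros_like(DG)))


def EG_loss(DG, DE):
    # Assumed BCE form with the labels flipped for the encoder and generator.
    return (F.binary_cross_entropy(DG, torch.ones_like(DG))
            + F.binary_cross_entropy(DE, torch.zeros_like(DE)))


def train_step(images, z):
    # Discriminator step: detached G(z) and E(X) keep this backward pass out
    # of G and E, so retain_graph=True is not needed.
    optimizer_D.zero_grad()
    DG = D(torch.cat([G(z).detach(), z], dim=1))
    DE = D(torch.cat([images, E(images).detach()], dim=1))
    loss_D = D_loss(DG, DE)
    loss_D.backward()
    optimizer_D.step()

    # Encoder/generator step: a fresh forward pass through the updated D.
    # loss_EG.backward() also fills D's .grad, which the next
    # optimizer_D.zero_grad() clears before it is ever used.
    optimizer_EG.zero_grad()
    DG = D(torch.cat([G(z), z], dim=1))
    DE = D(torch.cat([images, E(images)], dim=1))
    loss_EG = EG_loss(DG, DE)
    loss_EG.backward()
    optimizer_EG.step()
    return loss_D.item(), loss_EG.item()


images, z = torch.randn(16, 4), torch.randn(16, 2)
print(train_step(images, z))
```

Detaching only changes which gradients are computed for the discriminator step. D's update is the same as in the MODIFIED code, since D's gradient does not depend on G's or E's parameters.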