yashbonde / dall-e-baby

When Dall E was a baby trained on a bit of data
MIT License

Error detected in SoftmaxBackward #1

Closed · robvanvolt closed 3 years ago

robvanvolt commented 3 years ago

Really cool project that I tried to recreate. Unfortunately, after testing, discrete_vae.py gives the following error because of an inplace operation I cannot find:

Model is now CUDA!
[TRAIN - 0] GS: 1, Loss: 1.19789:   0%|▊                                                                                                                                                        | 1/202 [00:06<21:26,  6.40s/it]:: Entering Testing Mode
[TEST - 0]:   0%|                                                                                                                                                                                         | 0/1 [00:02<?, ?it/s]
:::: Loss: 1.2889413833618164                                                                                                                                                                             | 0/1 [00:01<?, ?it/s]
[TRAIN - 0] GS: 2, Loss: 1.2198:   1%|█▌                                                                                                                                                        | 2/202 [00:10<16:54,  5.07s/it][W ..\torch\csrc\autograd\python_anomaly_mode.cpp:104] Warning: Error detected in SoftmaxBackward. Traceback of forward call that caused the error:
    trainer.train(
    _, loss, _=model(d)
    result = self.forward(*input, **kwargs)
    return self.module(*inputs[0], **kwargs[0])
    result = self.forward(*input, **kwargs)
    softmax = F.softmax(encoding, dim = 1)
    ret = input.softmax(dim)
 (function _print_stack)
[TRAIN - 0] GS: 2, Loss: 1.2198:   1%|█▌                                                                                                                                                        | 2/202 [00:11<18:45,  5.63s/it] 
Traceback (most recent call last):
    trainer.train(
    loss.backward()
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
    Variable._execution_engine.run_backward(
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [64, 3000, 16, 16]], which is output 0 of SoftmaxBackward, is at version 1; expected version 0 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!

Could you give me a hint as to where the inplace operation might be? I'm running this code on Windows with Python 3.8.7 (64-bit), if that helps!
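For reference, here is a minimal standalone sketch (not taken from the repo) that reproduces the same class of error: SoftmaxBackward saves the softmax output for its backward pass, so mutating that output in place and then calling backward() fails with the same version-counter complaint.

import torch
import torch.nn.functional as F

x = torch.randn(4, 10, requires_grad=True)
softmax = F.softmax(x, dim=1)  # autograd saves this output for SoftmaxBackward (version 0)
# in-place scatter_ mutates the saved tensor, bumping its version from 0 to 1
softmax.scatter_(1, torch.argmax(softmax, dim=1).unsqueeze(1), 1)
loss = softmax.sum()
loss.backward()  # RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation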

yashbonde commented 3 years ago

Thanks for your interest in this. You can find the inplace operation in the LeakyReLU inside the Residual block, on this line. The code looks like this:

self.resblock = nn.Sequential(
    nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
    nn.BatchNorm2d(out_channels),
    nn.LeakyReLU(inplace=True),  # <-- the in-place op: overwrites its input tensor directly
    nn.Conv2d(out_channels, out_channels, kernel_size=1, bias=False),
    nn.BatchNorm2d(out_channels),
)

In all the latest experiments I have not used the Residual block, because it did not add any improvement that could be directly attributed to it. There is more information on inplace=True in this discussion.
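If you do want to keep the Residual block, a sketch of the same block with the in-place activation simply switched off (assuming nothing else relies on the in-place behaviour):

self.resblock = nn.Sequential(
    nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
    nn.BatchNorm2d(out_channels),
    nn.LeakyReLU(),  # inplace defaults to False: returns a new tensor instead of overwriting the input
    nn.Conv2d(out_channels, out_channels, kernel_size=1, bias=False),
    nn.BatchNorm2d(out_channels),
)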

Hope this helps.

robvanvolt commented 3 years ago

Thank you for the quick response!

I tried "add_residual = False" and commenting out the complete ResidualLayer, but still the same error message appears.

Unfortunately, the SoftmaxBackward is still expecting version 0 as above.

Do you have another idea to solve the problem?

yashbonde commented 3 years ago

Since the error is talking about a tensor with shape [64, 3000, 16, 16], it must be referring to the quantised image. So this must be in the forward() method of class VQVAE_v3, here.

robvanvolt commented 3 years ago

Thank you very much, I was finally able to find the inplace operation. Using

softmax = softmax.scatter(1, torch.argmax(softmax, dim = 1).unsqueeze(1), 1)

instead of

softmax = softmax.scatter_(1, torch.argmax(softmax, dim = 1).unsqueeze(1), 1)

fixes the inplace warning! :-)
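For anyone hitting the same thing, a rough sketch of the corrected quantisation step in context (reconstructed from the traceback above, not copied verbatim from the repo):

softmax = F.softmax(encoding, dim = 1)  # autograd saves this output for SoftmaxBackward
# out-of-place scatter returns a new one-hot tensor and leaves the saved softmax output untouched
softmax = softmax.scatter(1, torch.argmax(softmax, dim = 1).unsqueeze(1), 1)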