YixinChen-AI / CVAE-GAN-zoos-PyTorch-Beginner

For beginners, this is the best starting point for VAEs, GANs, and CVAE-GAN. It contains AE, DAE, VAE, GAN, CGAN, DCGAN, WGAN, WGAN-GP, VAE-GAN, and CVAE-GAN, all implemented in PyTorch.

VAE-GAN code throws an error when run as-is #10

Open lmqhello opened 8 months ago

lmqhello commented 8 months ago

My environment is torch==1.8.1. The VAE example and the other examples run without problems; only this one fails.

Running your code as-is produces the following error:

```
RuntimeError                              Traceback (most recent call last)
Cell In[9], line 204
    202 output = D(recon_data)
    203 errVAE = criterion(output, real_label)
--> 204 errVAE.backward()
    205 D_G_z2 = output.mean().item()
    206 optimizerVAE.step()

File d:\ProgramData\Anaconda3\envs\pt\lib\site-packages\torch\tensor.py:245, in Tensor.backward(self, gradient, retain_graph, create_graph, inputs)
    236 if has_torch_function_unary(self):
    237     return handle_torch_function(
    238         Tensor.backward,
    239         (self,),
    (...)
    243         create_graph=create_graph,
    244         inputs=inputs)
--> 245 torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)

File d:\ProgramData\Anaconda3\envs\pt\lib\site-packages\torch\autograd\__init__.py:145, in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
    142 if retain_graph is None:
    143     retain_graph = create_graph
--> 145 Variable._execution_engine.run_backward(
    146     tensors, grad_tensors, retain_graph, create_graph, inputs,
    147     allow_unreachable=True, accumulate_grad=True)

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [16, 1, 4, 4]] is at version 2; expected version 1 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).
```

GRicciardi00 commented 8 months ago

I'm having the same problem

GRicciardi00 commented 8 months ago

I solved it by changing line 202 in the training loop to:

```python
output = D(recon_data.detach())
```

Here, `recon_data` is the output of the VAE's decoder; detaching it from the computation graph prevents gradients from being computed with respect to `recon_data`.
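For context, here is a minimal self-contained sketch of the detach pattern this fix relies on. The modules `G` and `D` are hypothetical stand-ins for the repo's decoder and discriminator (single linear layers, not the actual architecture); the point is only to show why the discriminator step must see a detached tensor while the generator step does a fresh forward pass after the in-place optimizer update:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins: G plays the role of the VAE decoder,
# D the role of the discriminator.
G = nn.Linear(4, 4)
D = nn.Linear(4, 1)
opt_D = torch.optim.SGD(D.parameters(), lr=0.1)
opt_G = torch.optim.SGD(G.parameters(), lr=0.1)
criterion = nn.BCEWithLogitsLoss()

z = torch.randn(8, 4)
real_label = torch.ones(8, 1)
fake_label = torch.zeros(8, 1)

recon_data = G(z)  # "decoder" output, attached to G's graph

# --- Discriminator step: detach so D's backward/step never
#     touches the graph that produced recon_data. ---
opt_D.zero_grad()
out_fake = D(recon_data.detach())
errD = criterion(out_fake, fake_label)
errD.backward()
opt_D.step()  # modifies D's weights IN PLACE (bumps their version counter)

# --- Generator step: run a FRESH forward pass through the updated D.
#     Reusing an output of D computed before opt_D.step() would trigger
#     the "modified by an inplace operation" RuntimeError, because the
#     saved weights would no longer match their recorded version. ---
opt_G.zero_grad()
out = D(recon_data)
errG = criterion(out, real_label)
errG.backward()  # flows through D into G without a version mismatch
opt_G.step()
```

If instead `errG.backward()` were called on a `D` output computed before `opt_D.step()`, autograd would raise exactly the version-mismatch error from the traceback above.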

lmqhello commented 8 months ago

> I solved it by changing line 202 in the training loop to:
>
> `output = D(recon_data.detach())`
>
> Here, `recon_data` is the output of the VAE's decoder; detaching it from the computation graph prevents gradients from being computed with respect to `recon_data`.

Thank you very much