Hi Kris,
Thanks for pointing out the issue.
The code runs well on my computer, where I am using PyTorch 1.4.0, so I suspect the difference comes from the PyTorch version.
If that doesn't help, please let me know and I will look into it further.
Update: everything works smoothly with these earlier versions of PyTorch and torchvision:
https://pytorch.org/get-started/previous-versions/
pytorch==1.2.0 torchvision==0.4.0
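To confirm that an environment actually matches this pin, you can compare the installed version strings against it. The following is a minimal sketch that assumes only the standard `torch.__version__` and `torchvision.__version__` attributes; the names `PINNED`, `strip_local_tag` and `version_mismatches` are hypothetical helpers, not part of the repository.

```python
from typing import Dict, Optional, Tuple

# Versions reported to work in this thread (hypothetical constant name).
PINNED: Dict[str, str] = {"torch": "1.2.0", "torchvision": "0.4.0"}


def strip_local_tag(version: str) -> str:
    """Drop a local build suffix such as '+cu92' from a version string."""
    return version.split("+", 1)[0]


def version_mismatches(
    installed: Dict[str, str], pinned: Dict[str, str] = PINNED
) -> Dict[str, Tuple[Optional[str], str]]:
    """Return {package: (installed, pinned)} for each package that differs from the pin."""
    result: Dict[str, Tuple[Optional[str], str]] = {}
    for name, wanted in pinned.items():
        found = installed.get(name)
        if found is None or strip_local_tag(found) != wanted:
            result[name] = (found, wanted)
    return result


if __name__ == "__main__":
    try:
        import torch
        import torchvision
    except ImportError:
        print("torch and/or torchvision are not installed")
    else:
        diff = version_mismatches(
            {"torch": torch.__version__, "torchvision": torchvision.__version__}
        )
        print(diff or "environment matches the pinned versions")
```

An empty result means the running interpreter uses exactly the pinned releases (ignoring CUDA build tags).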
Hi. ZhenyueQin.
I changed this code in solver_gan.py, at line 379.
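Before:
if train_val_test == 'train':
    self.reset_grad()
    # Optimise generator.
    if cur_step % self.n_critic == 0:
        train_step_G.backward(retain_graph=True)
        # Note: this updates the generator's weights in place before
        # train_step_V.backward() below has run; if train_step_V depends on
        # the generator's output, newer PyTorch versions reject that backward.
        self.g_optimizer.step()
    # Optimise value network.
    if cur_step % self.n_critic == 0:
        train_step_V.backward()
        self.v_optimizer.step()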
After:
if train_val_test == 'train':
    self.reset_grad()
    # Optimise generator and value network together.
    if cur_step % self.n_critic == 0:
        # Run both backward passes before either optimizer modifies any weights.
        train_step_G.backward(retain_graph=True)
        train_step_V.backward()
        self.g_optimizer.step()
        self.v_optimizer.step()
    # Previous value-network step, now merged into the block above.
    # if cur_step % self.n_critic == 0:
    #     train_step_V.backward()
    #     self.v_optimizer.step()
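Why the reordering helps can be reproduced outside the repository. The sketch below is a minimal stand-alone example, assuming a PyTorch release in which `optimizer.step()` updates parameters in place in a way autograd's version counters detect (reportedly 1.5 and later; older releases modified `.data` and did not trip the check). The tiny networks and losses are hypothetical stand-ins for the generator, the value network, `train_step_G` and `train_step_V`; they are not taken from solver_gan.py.

```python
import torch
import torch.nn as nn


def build(seed: int = 0):
    """Hypothetical stand-ins for the generator G and the value network V."""
    torch.manual_seed(seed)
    # Two layers, so the second layer's weight is saved for the backward pass.
    gen = nn.Sequential(nn.Linear(4, 8), nn.Tanh(), nn.Linear(8, 4))
    val = nn.Linear(4, 1)
    g_opt = torch.optim.SGD(gen.parameters(), lr=0.1)
    v_opt = torch.optim.SGD(val.parameters(), lr=0.1)
    return gen, val, g_opt, v_opt


def make_losses(gen, val):
    fake = gen(torch.randn(16, 4))
    loss_g = -val(fake).mean()        # stand-in for train_step_G
    loss_v = val(fake).pow(2).mean()  # stand-in for train_step_V; shares gen's graph
    return loss_g, loss_v


def original_order() -> str:
    """Step G between the two backward passes, as in the 'before' code."""
    gen, val, g_opt, v_opt = build()
    loss_g, loss_v = make_losses(gen, val)
    g_opt.zero_grad()
    v_opt.zero_grad()
    loss_g.backward(retain_graph=True)
    g_opt.step()  # modifies gen's weights in place
    try:
        loss_v.backward()  # still needs the pre-step weights saved by autograd
    except RuntimeError as err:
        return "RuntimeError: " + str(err).splitlines()[0]
    v_opt.step()
    return "ok"


def fixed_order() -> str:
    """Run both backward passes first, then step, as in the 'after' code."""
    gen, val, g_opt, v_opt = build()
    loss_g, loss_v = make_losses(gen, val)
    g_opt.zero_grad()
    v_opt.zero_grad()
    loss_g.backward(retain_graph=True)
    loss_v.backward()
    g_opt.step()
    v_opt.step()
    return "ok"


if __name__ == "__main__":
    print("original order:", original_order())
    print("fixed order:   ", fixed_order())
```

One consequence worth checking: in the reordered version, `train_step_V.backward()` runs before `self.g_optimizer.step()`, so if `train_step_V` depends on the generator's output, its gradients also accumulate into the generator's parameters and become part of the generator update. That is what happens in this toy sketch; whether it is acceptable for the repository depends on how `train_step_V` is defined.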
@krk-krk-krk I changed it as well and it works for me.
Thanks @krk-krk-krk, this worked for me. No need to downgrade the PyTorch version.
Hello Zhenyue Qin,
Thank you for your implementation! I'm trying to run the code locally and I'm running into some issues. While running both, here is what I get:
Basically, in solver_gan.py at line 379, where you do:
and a problem arises in train_step_V.backward()...
My environment specs are:
Do you have any idea what is going wrong?
Best regards, Kris