ZiJianZhao / SeqGAN-PyTorch

An implementation of SeqGAN in PyTorch, following the implementation in TensorFlow.

In GAN training, rewards do not get reshaped correctly unless using CUDA #7

Open: bgenchel opened this issue 6 years ago

bgenchel commented 6 years ago

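Around lines 214-218 of the GAN training loop (the exact numbers may differ slightly for you, since I have edited the code locally):

    rewards = rollout.get_reward(samples, 16, discriminator)
    rewards = Variable(torch.Tensor(rewards))
    if args.cuda:
        rewards = torch.exp(rewards.cuda()).contiguous().view((-1,))

As written, both the `torch.exp` and the `.view((-1,))` reshape sit inside the `if args.cuda:` branch, so on CPU the rewards are neither exponentiated nor flattened. I'm pretty sure the reshaping of the `rewards` variable needs to happen regardless of whether CUDA is being used; only the `.cuda()` transfer should depend on the flag.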
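A minimal sketch of the fix the report points toward, moving only the device transfer under the CUDA check. `prepare_rewards` is a hypothetical helper, not part of the repository, and it assumes `rollout.get_reward(...)` returns a `(batch_size, seq_len)` array-like of per-token rewards. It applies the same `torch.exp` and flattening as the original line, but on every device. `Variable` is omitted because it has been a no-op wrapper since PyTorch 0.4.

```python
import torch


def prepare_rewards(rewards, use_cuda: bool) -> torch.Tensor:
    """Convert rollout rewards into a flat tensor of per-token weights.

    Hypothetical helper: `rewards` is assumed to be the (batch_size, seq_len)
    array-like returned by `rollout.get_reward(...)`.
    """
    # Works for nested lists and NumPy arrays; replaces Variable(torch.Tensor(...)).
    rewards = torch.as_tensor(rewards, dtype=torch.float32)
    if use_cuda:
        # Only the device transfer depends on the CUDA flag.
        rewards = rewards.cuda()
    # Exponentiate and flatten on every device, giving shape
    # (batch_size * seq_len,) to line up with the flattened targets.
    return torch.exp(rewards).contiguous().view(-1)
```

In the training loop this would replace the `Variable(...)` line and the whole `if` block, e.g. `rewards = prepare_rewards(rollout.get_reward(samples, 16, discriminator), args.cuda)`, so the CPU and GPU paths produce tensors of the same shape.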