uyo9ko opened this issue 1 year ago
Change
real_fake_images = torch.cat((imgs[:4], decoded_images.add(1).mul(0.5)[:4]))
to
real_fake_images = torch.cat((imgs.add(1).mul(0.5)[:4], decoded_images.add(1).mul(0.5)[:4]))
Hello, I have also been researching VQGAN recently. May I ask whether your network converges well?
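In train_vqgan.py, the line
real_fake_images = torch.cat((imgs[:4], decoded_images.add(1).mul(0.5)[:4]))
should be revised to
real_fake_images = torch.cat((imgs.add(1).mul(0.5)[:4], decoded_images.add(1).mul(0.5)[:4]))
so that the real images go through the same add(1).mul(0.5) rescaling as the reconstructions before the two are concatenated.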
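A minimal sketch of the proposed fix follows. It assumes, as the fix implies, that both `imgs` and `decoded_images` hold values in [-1, 1], so `add(1).mul(0.5)` maps each of them to [0, 1]. With the original line, a real pixel at 0 (mid-gray) would stay at 0 and be shown as black, and anything below 0 would be clipped. The helper names `to_unit_range` and `make_real_fake_grid` and the random stand-in tensors are hypothetical and only for illustration; they are not taken from train_vqgan.py.

```python
import torch


def to_unit_range(x: torch.Tensor) -> torch.Tensor:
    """Map a tensor from [-1, 1] to [0, 1]; identical to x.add(1).mul(0.5)."""
    return x.add(1).mul(0.5)


def make_real_fake_grid(imgs: torch.Tensor, decoded_images: torch.Tensor, n: int = 4) -> torch.Tensor:
    """Stack the first n real and n reconstructed images, both rescaled to [0, 1]."""
    return torch.cat((to_unit_range(imgs)[:n], to_unit_range(decoded_images)[:n]))


# Hypothetical stand-ins for the training loop's tensors, with values in [-1, 1).
imgs = torch.rand(8, 3, 16, 16) * 2 - 1
decoded_images = torch.rand(8, 3, 16, 16) * 2 - 1

real_fake_images = make_real_fake_grid(imgs, decoded_images)
print(real_fake_images.shape)  # torch.Size([8, 3, 16, 16])
```

The resulting tensor can then be passed to `torchvision.utils.save_image`, which expects float images in [0, 1] and clips values outside that range. That clipping is why the real half of the grid looks wrong when only the reconstructions are rescaled.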