Ciccios96 opened 3 years ago
Same issue here with torchvision 0.8.2 (it works properly with version 0.4.2).
The problem arises because the `grid` variable is not a plain `torch.Tensor` but a tweaked `Distribution` tensor. To fix the issue, I had to change the following lines:
In `train_fns.py`, change:

```python
torchvision.utils.save_image(fixed_Gz.float().cpu(), image_filename,
```

to:

```python
torchvision.utils.save_image(torch.from_numpy(fixed_Gz.float().cpu().numpy()), image_filename,
```
In `utils.py`, add the line:

```python
out_ims = torch.from_numpy(out_ims.numpy())
```

after the line:

```python
out_ims = torch.stack(ims, 1).view(-1, ims[0].shape[1], ims[0].shape[2],
                                   ims[0].shape[3]).data.float().cpu()
```
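To see what the two `utils.py` lines do together, here is a minimal sketch with made-up shapes. It assumes PyTorch >= 1.7, uses an empty `Distribution` subclass as a hypothetical stand-in for the repo's class, and assumes `ims` is a list of K equally shaped `(N, C, H, W)` batches (e.g. K samples for each of N classes).

```python
import torch


class Distribution(torch.Tensor):
    """Hypothetical stand-in for the repo's Distribution tensor subclass."""


# Assumed input: K batches of shape (N, C, H, W); batch k is filled with the value k.
N, K, C, H, W = 2, 3, 3, 4, 4
ims = [torch.full((N, C, H, W), float(k)).as_subclass(Distribution) for k in range(K)]

# Original line: stack to (N, K, C, H, W), then flatten to (N * K, C, H, W),
# so the K samples belonging to the same row n end up adjacent.
out_ims = torch.stack(ims, 1).view(-1, ims[0].shape[1], ims[0].shape[2],
                                   ims[0].shape[3]).data.float().cpu()

# Added line from the fix: round-trip through NumPy to drop the subclass.
out_ims = torch.from_numpy(out_ims.numpy())

print(type(out_ims).__name__, tuple(out_ims.shape))  # Tensor (6, 3, 4, 4)
print(out_ims[:, 0, 0, 0].tolist())  # [0.0, 1.0, 2.0, 0.0, 1.0, 2.0]
```

The NumPy round trip only works on CPU tensors without autograd tracking, which the preceding `.data.float().cpu()` chain already guarantees.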
What this basically does is convert the tweaked `Distribution` tensor into a plain `torch.Tensor` via `torch.from_numpy(t.numpy())`.
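The type propagation behind this can be reproduced without the repo. The sketch below assumes PyTorch >= 1.7 (the release paired with torchvision 0.8.x), where operations on an instance of a `torch.Tensor` subclass return that subclass by default; this is presumably how the generator output ends up as a `Distribution`. `Distribution` here is a bare hypothetical stand-in, and `to_plain_tensor` is a hypothetical helper name, not part of the repo.

```python
import torch


class Distribution(torch.Tensor):
    """Hypothetical stand-in for the repo's Distribution tensor subclass."""


def to_plain_tensor(t: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: return t's data as a plain torch.Tensor.

    Uses the same NumPy round trip as the fix; detach() and cpu() are added
    so that .numpy() also accepts grad-tracking or GPU tensors.
    """
    return torch.from_numpy(t.detach().cpu().numpy())


# A latent batch wrapped in the subclass, like a sampled latent vector.
z = torch.randn(4, 3, 8, 8).as_subclass(Distribution)

# Results computed from a subclass instance keep its type (PyTorch >= 1.7).
images = (z * 0.5 + 0.5).float().cpu()
print(type(images).__name__)  # Distribution (PyTorch >= 1.7)

plain = to_plain_tensor(images)
print(type(plain).__name__)  # Tensor
```

For CPU tensors that do not require grad, the helper behaves the same as the `torch.from_numpy(t.numpy())` one-liner.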
Also, just FYI, here are the scores I got for CIFAR-10 when running the `launch_cifar_ema.sh` script:
Hello, when I call `torchvision.utils.save_image()` inside the `save_and_sample()` function, I get the error in the title. Can I get some help? The error occurs in the CIFAR-10 training script.