TobiasSunderdiek / cartoon-gan

Implementation of CartoonGAN [Chen et al., CVPR18] with PyTorch
https://tobiassunderdiek.github.io/cartoon-gan/
MIT License

Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor) should be the same #5

Closed: bdlneto closed this issue 2 years ago

bdlneto commented 3 years ago

When running the Google Colab training on the GPU, this error appears. Apparently the input data is on the GPU while the model weights are on the CPU. I could not find a way to solve this problem.

TobiasSunderdiek commented 2 years ago

Hi, first of all: thank you for filing this issue, and sorry for the late reply! This may happen when the vgg16 model is loaded for the first time. I think I missed a to(device) in the VGG-16 cell:

    except FileNotFoundError:
        vgg16 = models.vgg16(pretrained=True)
        vgg16 = vgg16.to(device)

It has been a long time since you opened this issue, but maybe this still helps?

Best regards!
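The fix above adds the missing to(device) in the branch that downloads vgg16 for the first time. A minimal sketch of the same pattern follows, assuming the weights are cached as a state_dict on disk. The names build_stand_in_model and load_model are hypothetical, and the tiny convolutional model only stands in for models.vgg16(pretrained=True) so the sketch runs without a download. It is not the notebook's actual code. The design choice is to call .to(device) once, after the try/except, so both the "load cached weights" and the "first run" paths end up on the same device as the input.

```python
import torch
import torch.nn as nn

# Use the GPU when available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def build_stand_in_model() -> nn.Module:
    # Hypothetical tiny model standing in for models.vgg16(pretrained=True).
    return nn.Sequential(nn.Conv2d(3, 8, kernel_size=3, padding=1), nn.ReLU())


def load_model(path: str, device: torch.device) -> nn.Module:
    model = build_stand_in_model()
    try:
        # Reuse weights saved by an earlier run (assumed caching scheme).
        model.load_state_dict(torch.load(path, map_location="cpu"))
    except FileNotFoundError:
        # First run: keep the freshly built weights and cache them.
        torch.save(model.state_dict(), path)
    # Move to the device once, after try/except, so BOTH branches are covered.
    return model.to(device)


# Usage: create the input on the same device as the model's weights.
# model = load_model("weights.pt", device)
# x = torch.randn(1, 3, 32, 32, device=device)
# features = model(x)
```

If the .to(device) call sat only inside the try branch, the first run would leave the weights on the CPU while the input is on CUDA, which produces exactly the "Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor)" error from this issue.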

bdlneto commented 2 years ago

Hi!

I deleted the original vgg16 weights and downloaded them again with this new code. Now the training is working!

Thank You!

TobiasSunderdiek commented 2 years ago

Hi, great, thank you for your reply and for trying this idea out!

Thank you for your help and for making this repo better with it!

Best regards!

[Edit:] Notebook updated with fixed code and credits