jorge-pessoa / pytorch-msssim

PyTorch differentiable Multi-Scale Structural Similarity (MS-SSIM) loss

Expected object of scalar type Double but got scalar type Float for argument #2 'weight' #6

Closed: alinajadebarnett closed this issue 5 years ago

alinajadebarnett commented 5 years ago

So I implemented this as follows:

```python
import pytorch_msssim
[...]
lr_loss = pytorch_msssim.MSSSIM()
[...]
lr_tensor = torch.tensor(np.expand_dims(lr_img.astype(np.float32), axis=0)).type('torch.DoubleTensor').to(DEVICE)
in_tensor = torch.tensor(np.expand_dims(sr_img.astype(np.float32), axis=0)).type('torch.DoubleTensor').to(DEVICE)
[...]
ds_in_tensor = bds(in_tensor, nhwc=True)
lr_l = lr_loss(ds_in_tensor, lr_tensor)
l2_l = l2_loss(in_tensor, org_tensor)
l = lr_l + LAMBDA * l2_l
l.backward()
```

And I'm getting this error:

```
Traceback (most recent call last):
  File "/usr/xtmp/superresoluter/superresolution/tester_msssim.py", line 137, in <module>
    lr_l = lr_loss(ds_in_tensor, lr_tensor)
  File "/home/home5/abarnett/sr/lib/python3.5/site-packages/torch/nn/modules/module.py", line 489, in __call__
    result = self.forward(*input, **kwargs)
  File "/usr/project/xtmp/superresoluter/superresolution/pytorch_msssim/__init__.py", line 133, in forward
    return msssim(img1, img2, window_size=self.window_size, size_average=self.size_average)
  File "/usr/project/xtmp/superresoluter/superresolution/pytorch_msssim/__init__.py", line 78, in msssim
    sim, cs = ssim(img1, img2, window_size=window_size, size_average=size_average, full=True, val_range=val_range)
  File "/usr/project/xtmp/superresoluter/superresolution/pytorch_msssim/__init__.py", line 41, in ssim
    mu1 = F.conv2d(img1, window, padding=padd, groups=channel)
RuntimeError: Expected object of scalar type Double but got scalar type Float for argument #2 'weight'
```

Any ideas?
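The last traceback frame is `F.conv2d(img1, window, padding=padd, groups=channel)`. The error names argument #2, `'weight'` (the SSIM window), as Float, while the input was made Double by `.type('torch.DoubleTensor')`. A minimal sketch of the same mismatch outside the library follows. It assumes a uniform 11x11 window as a hypothetical stand-in for the library's Gaussian window; the shapes and padding are illustrative, and the exact error wording varies across PyTorch versions.

```python
import torch
import torch.nn.functional as F

# Hypothetical stand-ins: a 1x3x16x16 image and one 11x11 window per channel,
# mirroring the depthwise call F.conv2d(img1, window, padding=padd, groups=channel).
channel = 3
img_double = torch.rand(1, channel, 16, 16, dtype=torch.float64)  # like .type('torch.DoubleTensor')
window = torch.ones(channel, 1, 11, 11) / 121.0  # default dtype: float32

try:
    F.conv2d(img_double, window, padding=5, groups=channel)
    mismatch_raised = False
except RuntimeError as err:
    # Input is float64, weight is float32: conv2d does not promote dtypes.
    mismatch_raised = True
    print(err)

# Casting the input to float32 makes input and weight dtypes agree.
out = F.conv2d(img_double.float(), window, padding=5, groups=channel)
print(out.dtype, out.shape)  # torch.float32 torch.Size([1, 3, 16, 16])
```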

alinajadebarnett commented 5 years ago

This was resolved by changing 'torch.DoubleTensor' to 'torch.FloatTensor' in all locations. Example below.

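```python
# Before (float64 input, raises the error above):
lr_tensor = torch.tensor(np.expand_dims(lr_img.astype(np.float32), axis=0)).type('torch.DoubleTensor').to(DEVICE)
# After (float32 input, matches the float32 window):
lr_tensor = torch.tensor(np.expand_dims(lr_img.astype(np.float32), axis=0)).type('torch.FloatTensor').to(DEVICE)
```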
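Since `lr_img.astype(np.float32)` is already single precision and `torch.tensor()` preserves the NumPy dtype, the explicit `.type('torch.DoubleTensor')` call is the only step producing float64. A small sketch checking this, using a random array as a hypothetical stand-in for `lr_img` and a locally defined `DEVICE`:

```python
import numpy as np
import torch

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hypothetical HWC image standing in for lr_img from the question.
lr_img = np.random.rand(32, 32, 3)  # NumPy default dtype: float64

# torch.tensor() keeps the NumPy dtype, so astype(np.float32) already yields float32.
as_float = torch.tensor(np.expand_dims(lr_img.astype(np.float32), axis=0)).to(DEVICE)

# .type('torch.DoubleTensor') upcasts to float64, which is what triggered the mismatch.
as_double = torch.tensor(np.expand_dims(lr_img.astype(np.float32), axis=0)).type('torch.DoubleTensor').to(DEVICE)

print(as_float.dtype, as_double.dtype)  # torch.float32 torch.float64
```

So either dropping the `.type(...)` call or switching it to `'torch.FloatTensor'` gives float32 inputs.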

SumanthMeenan commented 5 years ago

I also faced the same problem. PyTorch needs the inputs and the parameters (weights and biases) of an operation like F.conv2d to have the same dtype, and here the weights are float32. Applying .float() to the input tensor resolved the issue.
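Besides casting the inputs with `.float()`, another option, assuming you control the SSIM code (or a fork of it), is to cast the window to the input's dtype and device at call time. This is not what pytorch_msssim does as far as this thread shows. `blur_like_input` below is a hypothetical helper, and the uniform window again stands in for the Gaussian one.

```python
import torch
import torch.nn.functional as F

def blur_like_input(img: torch.Tensor, window: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: depthwise convolution that casts the window to the
    input's dtype and device instead of requiring the caller to cast the input."""
    channel = img.shape[1]
    window = window.to(dtype=img.dtype, device=img.device)
    padding = window.shape[-1] // 2
    return F.conv2d(img, window, padding=padding, groups=channel)

window = torch.ones(3, 1, 11, 11) / 121.0              # float32 window
img64 = torch.rand(1, 3, 16, 16, dtype=torch.float64)  # float64 input
img32 = img64.float()                                  # the .float() fix from the comment

out64 = blur_like_input(img64, window)
out32 = blur_like_input(img32, window)
print(out64.dtype, out32.dtype)  # torch.float64 torch.float32
```

Casting the window keeps float64 precision end to end, but double-precision convolutions are typically slower on GPUs; casting the inputs with `.float()` is the simpler fix.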