photosynthesis-team / piq

Measures and metrics for image2image tasks. PyTorch.
Apache License 2.0

MS-SSIM error when running on GPU #362

Closed: pomelyu closed this issue 1 year ago

pomelyu commented 1 year ago

Describe the bug

An error is raised when running MS-SSIM on the GPU:

  File "/home/chinyu.chien/.local/lib/python3.10/site-packages/piq/ms_ssim.py", line 190, in forward
    score = multi_scale_ssim(x=x, y=y, kernel_size=self.kernel_size, kernel_sigma=self.kernel_sigma,
  File "/home/chinyu.chien/.local/lib/python3.10/site-packages/piq/ms_ssim.py", line 78, in multi_scale_ssim
    msssim_val = _compute_msssim(
  File "/home/chinyu.chien/.local/lib/python3.10/site-packages/piq/ms_ssim.py", line 233, in _multi_scale_ssim
    msssim_val = torch.prod((mcs_ssim ** scale_weights.view(-1, 1, 1)), dim=0).mean(1)
  File "/root/miniconda/lib/python3.10/site-packages/torch/_tensor.py", line 32, in wrapped
    return f(*args, **kwargs)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu 

To Reproduce

import piq
import torch

criterion = piq.MultiScaleSSIMLoss().cuda()
x = torch.rand(1, 3, 256, 256).cuda()
y = torch.rand(1, 3, 256, 256).cuda()

z = criterion(x, y)

Expected behavior

The loss is computed on the GPU without an error.

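Additional context

The error is caused by line 163 in piq/ms_ssim.py. There, scale_weights is stored as a plain tensor attribute, so .cuda() does not move it to the GPU and it stays on the CPU: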

        if scale_weights is None:
            # Values from MS-SSIM paper
            self.scale_weights = torch.tensor([0.0448, 0.2856, 0.3001, 0.2363, 0.1333])
        else:
            self.scale_weights = scale_weights
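As a stopgap that does not depend on how the module stores the weights, the weighting step itself can align scale_weights with the per-scale values before combining them. A minimal sketch: weighted_product is a hypothetical helper (not part of piq) that mirrors the failing line from the traceback, assuming the per-scale values are stacked with shape (num_scales, batch, channels):

```python
import torch


def weighted_product(mcs_ssim: torch.Tensor, scale_weights: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper mirroring the failing line of _multi_scale_ssim.

    Assumes mcs_ssim has shape (num_scales, batch, channels) and
    scale_weights has shape (num_scales,).
    """
    # Tensor.to(other) returns a tensor with other's dtype AND device,
    # so CPU weights follow CUDA (or float64) inputs automatically.
    scale_weights = scale_weights.to(mcs_ssim)
    # Weight each scale, multiply across scales, then average over channels.
    return torch.prod(mcs_ssim ** scale_weights.view(-1, 1, 1), dim=0).mean(1)


weights = torch.tensor([0.0448, 0.2856, 0.3001, 0.2363, 0.1333])  # CPU, float32
device = "cuda" if torch.cuda.is_available() else "cpu"
values = torch.rand(5, 2, 3, device=device)
score = weighted_product(values, weights)  # no device mismatch; shape (2,)
```

Because Tensor.to(other) matches both device and dtype, this also covers half or double inputs. It only treats the symptom in the functional path; registering the weights as a buffer fixes the module itself.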

register_buffer should be used to set up the constant value, so that .cuda() and .to() move it together with the module:

        if scale_weights is None:
            # Values from MS-SSIM paper
            self.register_buffer("scale_weights", torch.tensor([0.0448, 0.2856, 0.3001, 0.2363, 0.1333]))
        else:
            self.register_buffer("scale_weights", scale_weights)
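A quick way to see the difference is to compare the two storage styles directly. This sketch uses two hypothetical toy modules (not piq classes) and casts them with .double(), which goes through the same Module._apply path as .cuda(), so the behavior can be checked on a CPU-only machine:

```python
import torch
from torch import nn

WEIGHTS = [0.0448, 0.2856, 0.3001, 0.2363, 0.1333]  # values from the MS-SSIM paper


class PlainAttr(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        # Plain tensor attribute: ignored by .to(), .cuda(), .double(), ...
        self.scale_weights = torch.tensor(WEIGHTS)


class WithBuffer(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        # Registered buffer: converted and moved together with the module.
        self.register_buffer("scale_weights", torch.tensor(WEIGHTS))


plain = PlainAttr().double()
buffered = WithBuffer().double()
print(plain.scale_weights.dtype)     # torch.float32 (left untouched)
print(buffered.scale_weights.dtype)  # torch.float64 (cast with the module)

if torch.cuda.is_available():
    print(PlainAttr().cuda().scale_weights.device)   # cpu, the bug from this issue
    print(WithBuffer().cuda().scale_weights.device)  # cuda:0
```

Registered buffers are also included in state_dict() by default; passing persistent=False to register_buffer keeps the weights out of checkpoints while still moving them with the module.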