DmitryUlyanov / deep-image-prior

Image restoration with neural networks but without learning.
https://dmitryulyanov.github.io/deep_image_prior

tensor concatenate error in sr_prior_effect notebook #29

Open prash-p opened 6 years ago

prash-p commented 6 years ago

For me, the very last line of the sr_prior_effect notebook fails, but I'm not sure why. It raises a TypeError even though the arguments seem to be the same type (I think...)

plot_image_grid([imgs['HR_np'],
                 result_no_prior,
                 result_tv_prior,
                 result_deep_prior], factor=8, nrow=2, interpolation='lanczos')

-----------------------------------------------------------
TypeError                 Traceback (most recent call last)
<ipython-input-42-8fd807e1bf2e> in <module>()
      2                  result_no_prior,
      3                  result_tv_prior,
----> 4                  result_deep_prior], factor=8, nrow=2, interpolation='lanczos')

~/DeepImagePrior/deep-image-prior/utils/common_utils.py in plot_image_grid(images_np, nrow, factor, interpolation)
     75     images_np = [x if (x.shape[0] == n_channels) else np.concatenate([x, x, x], axis=0) for x in images_np]
     76 
---> 77     grid = get_image_grid(images_np, nrow)
     78 
     79     plt.figure(figsize=(len(images_np)+factor,12+factor))

~/DeepImagePrior/deep-image-prior/utils/common_utils.py in get_image_grid(images_np, nrow)
     57     '''Creates a grid from a list of images by concatenating them.'''
     58     images_torch = [torch.from_numpy(x) for x in images_np]
---> 59     torch_grid = torchvision.utils.make_grid(images_torch, nrow)
     60 
     61     return torch_grid.numpy()

~/anaconda2/envs/py36/lib/python3.6/site-packages/torchvision-0.2.0-py3.6.egg/torchvision/utils.py in make_grid(tensor, nrow, padding, normalize, range, scale_each, pad_value)
     33     # if list of tensors, convert to a 4D mini-batch Tensor
     34     if isinstance(tensor, list):
---> 35         tensor = torch.stack(tensor, dim=0)
     36 
     37     if tensor.dim() == 2:  # single image H x W

~/anaconda2/envs/py36/lib/python3.6/site-packages/torch/functional.py in stack(sequence, dim, out)
     62     inputs = [t.unsqueeze(dim) for t in sequence]
     63     if out is None:
---> 64         return torch.cat(inputs, dim)
     65     else:
     66         return torch.cat(inputs, dim, out=out)

TypeError: cat received an invalid combination of arguments - got (list, int), but expected one of:
 * (sequence[torch.FloatTensor] seq)
 * (sequence[torch.FloatTensor] seq, int dim)
      didn't match because some of the arguments have invalid types: (list, int)
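The error says torch.cat expected a sequence of torch.FloatTensor. get_image_grid converts each entry with torch.from_numpy, which keeps the array's dtype, so a float64 array becomes a DoubleTensor. The PyTorch version in this traceback would not concatenate mixed tensor types. Below is a minimal NumPy-only sketch for finding the offending entries before calling plot_image_grid. It assumes a dtype or type mismatch is the cause. The helper name find_non_float32 is hypothetical and is not part of the repo.

```python
import numpy as np


def find_non_float32(images):
    """Return (index, description) for each entry that is not a float32 ndarray.

    Hypothetical helper. Assumes the cause is a type mismatch: older
    torch.cat (called via torchvision.utils.make_grid) needed every tensor
    to share one type, and torch.from_numpy preserves the array's dtype.
    """
    problems = []
    for i, x in enumerate(images):
        if not isinstance(x, np.ndarray):
            problems.append((i, type(x).__name__))  # not an array at all
        elif x.dtype != np.float32:
            problems.append((i, str(x.dtype)))      # e.g. float64 -> DoubleTensor
    return problems


# Example: np.zeros defaults to float64, and the last entry is a plain list.
imgs = [np.zeros((3, 8, 8), dtype=np.float32),
        np.zeros((3, 8, 8), dtype=np.float32),
        np.zeros((3, 8, 8)),
        [[0.0]]]
print(find_non_float32(imgs))  # [(2, 'float64'), (3, 'list')]
```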
prash-p commented 6 years ago

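Seems like the entries weren't all float32 NumPy arrays. get_image_grid converts each one with torch.from_numpy, and torch.cat then expects every resulting tensor to be a torch.FloatTensor. Correct usage is: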

plot_image_grid([imgs['HR_np'],
                 out_HR_noprior_np,
                 out_HR_TV_np,
                 out_HR_deep_np], factor=8, nrow=2, interpolation='lanczos');

where each out_*_np is computed as np.clip(var_to_np(net(net_input)), 0, 1) right after the corresponding loss experiment (no prior, TV prior, deep prior) in the notebook.
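If a result does not already come straight from that expression, here is a hedged sketch of a conversion step. It aims to produce the same kind of array the suggested fix yields: float32, C x H x W, with values clipped to [0, 1]. The function name to_grid_array is hypothetical. It assumes the input is either a C x H x W array or a single-channel H x W array.

```python
import numpy as np


def to_grid_array(x):
    """Hypothetical helper: coerce an image to a float32 C x H x W array in [0, 1]."""
    arr = np.asarray(x, dtype=np.float32)  # cast float64 (or lists) to float32
    if arr.ndim == 2:                      # H x W -> 1 x H x W
        arr = arr[np.newaxis]
    return np.clip(arr, 0, 1)              # same clipping as the notebook fix


out = to_grid_array(np.array([[1.5, -0.2], [0.3, 0.7]]))
print(out.shape, out.dtype)  # (1, 2, 2) float32
```

Keeping a leading channel axis is enough here, because plot_image_grid already repeats single-channel images to three channels when other images in the list have three (line 75 of common_utils.py in the traceback).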