eromoe closed this 6 years ago.
Do you run it on example images or your own?
On example images. Same notebook without modification.
```python
## Fig 7 (top)
img_path = 'data/inpainting/lena.png'
mask_path = 'data/inpainting/lena_mask.png'
NET_TYPE = 'skip_depth6' # one of skip_depth4|skip_depth2|UNET|ResNet
```
Appears to be the same issue as #12
Also getting this same error on Linux with:

- PyTorch 0.2.0.post3
- Python 2.7.14
- CUDA 8.0

This is with the lena example image; I didn't edit the original code.
```
RuntimeError                              Traceback (most recent call last)
in ()
      1 img_mask_var = np_to_var(img_mask_np).type(dtype)
      2 
----> 3 plot_image_grid([img_np, img_mask_np, img_mask_np*img_np], 3,11);

/home/phil/devel/ML/deep-image-prior/utils/common_utils.pyc in plot_image_grid(images_np, nrow, factor, interpolation)
     70         interpolation: interpolation used in plt.imshow
     71     """
---> 72     grid = get_image_grid(images_np, nrow)
     73 
     74     plt.figure(figsize=(len(images_np)+factor,12+factor))

/home/phil/devel/ML/deep-image-prior/utils/common_utils.pyc in get_image_grid(images_np, nrow)
     57     '''Creates a grid from a list of images by concatenating them.'''
     58     images_torch = [torch.from_numpy(x) for x in images_np]
---> 59     torch_grid = torchvision.utils.make_grid(images_torch, nrow)
     60 
     61     return torch_grid.numpy()

/home/phil/py-virt-env/env/local/lib/python2.7/site-packages/torchvision/utils.pyc in make_grid(tensor, nrow, padding, normalize, range, scale_each, pad_value)
     33     # if list of tensors, convert to a 4D mini-batch Tensor
     34     if isinstance(tensor, list):
---> 35         tensor = torch.stack(tensor, dim=0)
     36 
     37     if tensor.dim() == 2:  # single image H x W

/home/phil/py-virt-env/env/local/lib/python2.7/site-packages/torch/functional.pyc in stack(sequence, dim, out)
     62     inputs = [t.unsqueeze(dim) for t in sequence]
     63     if out is None:
---> 64         return torch.cat(inputs, dim)
     65     else:
     66         return torch.cat(inputs, dim, out=out)

RuntimeError: inconsistent tensor sizes at /pytorch/torch/lib/TH/generic/THTensorMath.c:2709
```
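The traceback shows `make_grid` calling `torch.stack` on the list of images, and stacking requires every tensor to have the same shape. A quick way to check this before plotting is to compare the shapes of the arrays passed to `plot_image_grid`. The sketch below uses NumPy as a stand-in for the torch tensors and assumes the notebook's arrays are in C x H x W layout; `report_shapes` is a hypothetical helper, and the tiny 4x4 arrays are placeholders, not the real lena data.

```python
import numpy as np

def report_shapes(images_np):
    """Return each array's shape and whether they all agree.

    torch.stack (used by make_grid on a list) needs identical shapes,
    so a False here predicts the 'inconsistent tensor sizes' error.
    """
    shapes = [tuple(img.shape) for img in images_np]
    return shapes, len(set(shapes)) == 1

# Placeholder arrays: an RGB image and a single-channel mask (C x H x W).
img_np = np.zeros((3, 4, 4), dtype=np.float32)
img_mask_np = np.ones((1, 4, 4), dtype=np.float32)

shapes, ok = report_shapes([img_np, img_mask_np, img_mask_np * img_np])
print(shapes, ok)  # [(3, 4, 4), (1, 4, 4), (3, 4, 4)] False

# np.stack fails for the same reason torch.stack does.
try:
    np.stack([img_np, img_mask_np])
except ValueError as err:
    print("stack failed:", err)
```

Note that `img_mask_np * img_np` itself succeeds, because broadcasting expands the 1-channel mask to 3 channels; only the unmultiplied mask has a mismatched shape.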
Yes, it works now. Thanks for the quick response.
So the cause was that lena.png has 3 channels but lena_mask.png has only 1 channel. Thank you @DmitryUlyanov.
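The thread does not show the exact change that fixed the notebook, so the following is only a sketch of one possible workaround, assuming the images are NumPy arrays in C x H x W layout: repeat any single-channel image along the channel axis so that every entry passed to the grid has the same number of channels. `match_channels` is a hypothetical helper, not part of the repository's `utils`.

```python
import numpy as np

def match_channels(images_np):
    """Repeat 1-channel (1, H, W) arrays to the largest channel count
    in the list so all arrays can be stacked into one grid."""
    max_c = max(img.shape[0] for img in images_np)
    out = []
    for img in images_np:
        if img.shape[0] == 1 and max_c > 1:
            # Copy the single channel max_c times along axis 0.
            img = np.repeat(img, max_c, axis=0)
        out.append(img)
    return out

# Placeholder shapes mirroring the notebook: RGB image plus 1-channel mask.
img_np = np.random.rand(3, 4, 4).astype(np.float32)
img_mask_np = (np.random.rand(1, 4, 4) > 0.5).astype(np.float32)

images = match_channels([img_np, img_mask_np, img_mask_np * img_np])
print([im.shape for im in images])  # [(3, 4, 4), (3, 4, 4), (3, 4, 4)]
print(np.stack(images).shape)       # (3, 3, 4, 4)
```

Under these assumptions, the resulting list could then be passed to `plot_image_grid` in place of the original one; the mask is shown as a grayscale image rendered in RGB.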
env:
In `inpainting.ipynb`, I got this error at the Visualize step:

```python
plot_image_grid([img_np, img_mask_np, img_mask_np*img_np], 3,11);
```