pytorch / tutorials

PyTorch tutorials.
https://pytorch.org/tutorials/
BSD 3-Clause "New" or "Revised" License

SOLVED - IndexError: too many indices for tensor of dimension 0 from DCGAN tutorial #552

Closed: seyonechithrananda closed this issue 5 years ago.

seyonechithrananda commented 5 years ago

```
IndexError: Traceback (most recent call last):
  File "/Applications/anaconda3/lib/python3.7/site-packages/torch/utils/data/_utils/worker.py", line 99, in _worker_loop
    samples = collate_fn([dataset[i] for i in batch_indices])
  File "/Applications/anaconda3/lib/python3.7/site-packages/torch/utils/data/_utils/worker.py", line 99, in <listcomp>
    samples = collate_fn([dataset[i] for i in batch_indices])
  File "/Applications/anaconda3/lib/python3.7/site-packages/torchvision/datasets/folder.py", line 140, in __getitem__
    sample = self.transform(sample)
  File "/Applications/anaconda3/lib/python3.7/site-packages/torchvision/transforms/transforms.py", line 61, in __call__
    img = t(img)
  File "/Applications/anaconda3/lib/python3.7/site-packages/torchvision/transforms/transforms.py", line 164, in __call__
    return F.normalize(tensor, self.mean, self.std, self.inplace)
  File "/Applications/anaconda3/lib/python3.7/site-packages/torchvision/transforms/functional.py", line 208, in normalize
    tensor.sub_(mean[:, None, None]).div_(std[:, None, None])
IndexError: too many indices for tensor of dimension 0
```
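The last frame of the traceback is the key: `normalize` turns `mean` and `std` into tensors and indexes them with `[:, None, None]`, which needs at least one dimension to slice. A minimal sketch of the failure, assuming only that `torch` is installed; `mimic_normalize` is a hypothetical helper that copies the indexing from that frame (using out-of-place `-` and `/` instead of `sub_`/`div_`), not the torchvision function itself:

```python
import torch


def mimic_normalize(tensor, mean, std):
    """Hypothetical stand-in reproducing the indexing from the last traceback frame."""
    mean = torch.as_tensor(mean, dtype=tensor.dtype)
    std = torch.as_tensor(std, dtype=tensor.dtype)
    # [:, None, None] turns a per-channel vector of shape (C,) into (C, 1, 1)
    # so it broadcasts over a (C, H, W) image. A 0-dim tensor has no axis
    # for ':' to slice, so this indexing raises IndexError.
    return (tensor - mean[:, None, None]) / std[:, None, None]


img = torch.rand(3, 4, 4)  # fake 3-channel image with values in [0, 1)

# Per-channel tuples: works.
print(mimic_normalize(img, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)).shape)  # torch.Size([3, 4, 4])

# Bare floats for mean and std: fails the same way as the traceback.
try:
    mimic_normalize(img, 0.5, 0.5)
except IndexError as err:
    print("IndexError:", err)
```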

I get this error when trying to load the CelebA images from the following path using `ImageFolder`:

```python
dataroot = "/Users/seyonechithrananda/dcgan/data"
dataset = dset.ImageFolder(root=dataroot,
                           transform=transforms.Compose([
                               transforms.Resize(image_size),
                               transforms.CenterCrop(image_size),
                               transforms.ToTensor(),
                               transforms.Normalize(0.5, 0.5, 0.5), (0.5, 0.5, 0.5)
                           ]))
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                         shuffle=True, num_workers=workers)
device = torch.device("cpu")
real_batch = next(iter(dataloader))
plt.figure(figsize=(8, 8))
plt.axis('off')
plt.title("Training Images")
plt.imshow(np.transpose(vutils.make_grid(real_batch[0].to(device)[:64], padding=2, normalize=True).cpu(), (1, 2, 0)))
```
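Note where the parentheses fall on the `Normalize` line: Python reads it as `transforms.Normalize(0.5, 0.5, 0.5)` followed by a separate tuple in the `Compose` list, so `mean` and `std` arrive as bare floats. The sketch below shows this argument binding without torchvision; `FakeNormalize` is a hypothetical class that only assumes the same positional order as `transforms.Normalize(mean, std, inplace=False)`, and the other transforms are replaced by placeholder strings:

```python
class FakeNormalize:
    """Hypothetical stand-in with the same positional parameters as transforms.Normalize."""

    def __init__(self, mean, std, inplace=False):
        self.mean = mean
        self.std = std
        self.inplace = inplace


# The Compose list exactly as written in the report, with placeholders for the other transforms.
steps = [
    "Resize",
    "CenterCrop",
    "ToTensor",
    FakeNormalize(0.5, 0.5, 0.5), (0.5, 0.5, 0.5)
]

norm = steps[3]
print(len(steps))                         # 5: the tuple is its own list element
print(norm.mean, norm.std, norm.inplace)  # 0.5 0.5 0.5 (the third 0.5 lands in inplace)
print(type(steps[4]).__name__)            # tuple
```

Even if `Normalize` accepted scalars, the stray tuple would fail next, because `Compose` calls every element as `t(img)` (line 61 of the traceback) and a tuple is not callable.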

seyonechithrananda commented 5 years ago

This issue can be solved by changing the arguments to `transforms.Normalize` from `(0.5, 0.5, 0.5), (0.5, 0.5, 0.5)` to `(0.5,), (0.5)`.
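A single-value tuple works on RGB images because a length-1 `mean`/`std` becomes shape `(1, 1, 1)` and broadcasts across all three channels; the three-value form also works once each tuple is passed as one argument, as in `transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))`. A sketch of that arithmetic in plain `torch`, assuming the `(x - mean) / std` formula from the traceback; `per_channel` is a hypothetical helper and the `Normalize` calls appear only in comments:

```python
import torch

img = torch.rand(3, 4, 4)  # fake RGB image after ToTensor(), values in [0, 1)


def per_channel(values):
    # Same reshaping as in the traceback: a (C,) vector becomes (C, 1, 1).
    return torch.as_tensor(values, dtype=torch.float32)[:, None, None]


# Like transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)): one value per channel.
three = (img - per_channel((0.5, 0.5, 0.5))) / per_channel((0.5, 0.5, 0.5))

# Like transforms.Normalize((0.5,), (0.5,)): one value broadcast over all channels.
one = (img - per_channel((0.5,))) / per_channel((0.5,))

print(torch.allclose(three, one))                      # True
print(bool(one.min() >= -1) and bool(one.max() <= 1))  # True: values now lie in [-1, 1]
```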

ghltshubh commented 4 years ago

@seyonechithrananda Can you please amend `(0.5,), (0.5)` to `(0.5,), (0.5,)`?
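The amendment matters because parentheses alone do not make a tuple in Python; the trailing comma does. So `(0.5)` is just the float `0.5`, which becomes a 0-dim tensor and, in the torchvision version shown in the traceback, hits the same `[:, None, None]` error for `std`. A quick check, assuming `torch` is installed:

```python
import torch

print(type((0.5)).__name__)   # float: the parentheses only group
print(type((0.5,)).__name__)  # tuple: the trailing comma makes a 1-element tuple

print(torch.as_tensor((0.5)).dim())   # 0: indexing with [:, None, None] fails
print(torch.as_tensor((0.5,)).dim())  # 1: reshapes to (1, 1, 1) and broadcasts
```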