corenel / torchzoo

Zoo of models and datasets for PyTorch.
MIT License

TypeError occurs when using usps.py #1

Open tryerrorman opened 6 years ago

tryerrorman commented 6 years ago

Thank you for sharing this. When I use usps.py to load the data, the following error occurs:

Traceback (most recent call last):
  File "xdatatest.py", line 34, in <module>
    data = dataloader.get_next_iter()
  File "/home/fenglei/codes/gan/xgan/data/newdata_loader.py", line 44, in get_next_iter
    dataB = self.dataLoaderB.__iter__().next()
  File "/usr/local/lib/python2.7/dist-packages/torch/utils/data/dataloader.py", line 202, in __next__
    return self._process_next_batch(batch)
  File "/usr/local/lib/python2.7/dist-packages/torch/utils/data/dataloader.py", line 222, in _process_next_batch
    raise batch.exc_type(batch.exc_msg)
TypeError: Traceback (most recent call last):
  File "/usr/local/lib/python2.7/dist-packages/torch/utils/data/dataloader.py", line 41, in _worker_loop
    samples = collate_fn([dataset[i] for i in batch_indices])
  File "/home/fenglei/codes/gan/xgan/data/usps.py", line 75, in __getitem__
    img = self.transform(img)
  File "/usr/local/lib/python2.7/dist-packages/torchvision/transforms.py", line 34, in __call__
    img = t(img)
  File "/usr/local/lib/python2.7/dist-packages/torchvision/transforms.py", line 199, in __call__
    return img.resize(self.size, self.interpolation)
TypeError: an integer is required
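The last frame calls `img.resize(self.size, self.interpolation)`, which matches PIL's `Image.resize(size, resample)`. One plausible reading, assuming `usps.py` hands the transform a NumPy array rather than a PIL Image (its source is not shown in this thread), is that NumPy's `ndarray.resize` treats all of its positional arguments as dimensions of a new shape, so the list `[30, 30]` is rejected where an integer is expected. The sketch below reproduces that mismatch with a zero array standing in for one sample; `scale_like_old_torchvision` is a hypothetical helper, not a torchvision function.

```python
import numpy as np

BICUBIC = 3  # integer value of PIL's Image.BICUBIC, as passed in the reproduction code


def scale_like_old_torchvision(img, size, interpolation=BICUBIC):
    """Hypothetical stand-in for the failing frame:
    `return img.resize(self.size, self.interpolation)`."""
    return img.resize(size, interpolation)


# Assumed stand-in for one USPS sample: a float H x W x 1 NumPy array.
sample = np.zeros((28, 28, 1), dtype=np.float32)

try:
    scale_like_old_torchvision(sample, [30, 30])
except TypeError as exc:
    # ndarray.resize reads ([30, 30], 3) as two shape dimensions,
    # and the list [30, 30] is not an integer.
    print("TypeError:", exc)
```

The exact wording of NumPy's message varies between versions, but the exception type is the same `TypeError` seen in the traceback.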

corenel commented 6 years ago

Could you please provide minimal code that reproduces this problem?

tryerrorman commented 6 years ago

Thank you for your reply, @corenel. The following code reproduces the error reported above.


from PIL import Image
import torch.utils.data
import torchvision.transforms as transforms

from usps import USPS  # usps.py from this repository

# Resize to 30x30 (Scale is the older name of transforms.Resize),
# then take a random 28x28 crop and randomly flip it horizontally
osize = [30, 30]
transform_list = [
    transforms.Scale(osize, Image.BICUBIC),
    transforms.RandomCrop(28),
    transforms.RandomHorizontalFlip(),
]

# Convert to a tensor and map values from [0, 1] to [-1, 1]
transform_list += [transforms.ToTensor(),
                   transforms.Normalize((0.5, 0.5, 0.5),
                                        (0.5, 0.5, 0.5))]
trans = transforms.Compose(transform_list)

dataset = USPS(root='.', train=True, transform=trans, download=True)

dataLoader = torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=True, num_workers=2)

# Fetching the first batch raises the TypeError inside a worker process
data = next(iter(dataLoader))
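If, as the `img.resize(self.size, self.interpolation)` frame in the traceback suggests, `usps.py` passes a NumPy array to transforms that expect a PIL Image, one workaround is to convert each sample before the transforms run. The sketch below is a hypothetical wrapper, not part of torchzoo; `array_to_pil` and `PILWrapper` are illustrative names, and it assumes the dataset returns `(img, label)` pairs where `img` is an H x W or H x W x 1 array already on a 0-255 scale.

```python
import numpy as np
from PIL import Image


def array_to_pil(img):
    """Convert one sample to an 8-bit grayscale PIL Image.

    Assumption: the sample is an H x W or H x W x 1 array whose values
    are already on a 0-255 scale.
    """
    arr = np.asarray(img, dtype=np.float32)
    if arr.ndim == 3 and arr.shape[-1] == 1:
        arr = arr[..., 0]  # drop the singleton channel axis
    arr = np.clip(arr, 0, 255).astype(np.uint8)
    return Image.fromarray(arr)  # a 2-D uint8 array becomes a mode "L" image


class PILWrapper(object):
    """Hypothetical wrapper that hands PIL Images to PIL-based transforms.

    Assumes the wrapped dataset returns (img, label) pairs.
    """

    def __init__(self, dataset, transform=None):
        self.dataset = dataset
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        img, label = self.dataset[index]
        img = array_to_pil(img)  # convert before any PIL-based transform
        if self.transform is not None:
            img = self.transform(img)
        return img, label
```

With this wrapper, the dataset would be built without a transform and then wrapped, e.g. `dataset = PILWrapper(USPS(root='.', train=True, transform=None, download=True), transform=trans)`. Because the converted image has a single channel, `transforms.Normalize((0.5,), (0.5,))` may be safer than the three-channel version, since three means applied to one channel may behave differently across torchvision versions. Prepending `transforms.ToPILImage()` to the transform list is an alternative, though the array dtypes and shapes it accepts also depend on the torchvision version.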