pytorch / vision

Datasets, Transforms and Models specific to Computer Vision
https://pytorch.org/vision
BSD 3-Clause "New" or "Revised" License

Concatenate torchvision.datasets.FakeData with another dataset -> cannot load it #3517

Closed: zhifengkong closed this issue 3 years ago

zhifengkong commented 3 years ago

🐛 Bug

If you concatenate a dataset such as CIFAR10 with FakeData and iterate over the result with a DataLoader, you get an error.

To Reproduce

Steps to reproduce the behavior:

  1. cifar_dataset = torchvision.datasets.CIFAR10(...)
  2. fake_dataset = torchvision.datasets.FakeData(...)
  3. train_data = torch.utils.data.ConcatDataset([cifar_dataset, fake_dataset])
  4. train_loader = torch.utils.data.DataLoader(train_data, ...)
  5. for data in train_loader: ...  # raises an error
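The steps above can be condensed into a self-contained sketch that needs only PyTorch. To avoid downloading CIFAR10, it assumes two hypothetical stand-in datasets: IntLabelDataset (plain int labels, like CIFAR10) and TensorLabelDataset (0-dim long tensor labels, like the FakeData line the issue points to). The helper try_one_batch is also hypothetical. The FakeData-like dataset is placed first so that the first label in the batch is a tensor, which makes the failure deterministic.

```python
import torch
from torch.utils.data import ConcatDataset, DataLoader, Dataset


class IntLabelDataset(Dataset):
    """Hypothetical stand-in for CIFAR10: returns (image, int label)."""

    def __init__(self, size=2):
        self.size = size

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        return torch.zeros(3, 4, 4), idx % 10  # plain Python int


class TensorLabelDataset(Dataset):
    """Hypothetical stand-in for the affected FakeData: 0-dim tensor labels."""

    def __init__(self, size=2, num_classes=10):
        self.size = size
        self.num_classes = num_classes

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        # Same expression as quoted in the issue: yields a 0-dim torch.long tensor
        target = torch.randint(0, self.num_classes, size=(1,), dtype=torch.long)[0]
        return torch.zeros(3, 4, 4), target


def try_one_batch(dataset):
    """Collate the whole dataset into one batch; return labels or the error."""
    loader = DataLoader(dataset, batch_size=len(dataset), shuffle=False)
    try:
        _, labels = next(iter(loader))
        return labels
    except TypeError as exc:
        return exc


# Batch labels are (tensor, tensor, int, int): default collation fails on them
mixed = ConcatDataset([TensorLabelDataset(), IntLabelDataset()])
result = try_one_batch(mixed)
print(type(result).__name__)  # TypeError
```

With shuffle=True, which dataset supplies the first sample of a batch varies from batch to batch, which is part of why this kind of failure is hard to pin down.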

Additional context

This happens because the labels in CIFAR10 are Python ints, while the labels in FakeData are tensors. When the DataLoader collates samples from both datasets into a batch, the labels look like [0, 1, 2, 3, tensor(0), 3, 4, 5, 6, tensor(2), ...], and the default collate function cannot combine them.
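The mismatch can be isolated by calling the collate function on the labels alone. This sketch assumes PyTorch 1.11 or later, where default_collate is importable from torch.utils.data; the label values are illustrative only. When the first label is a tensor, default_collate stacks the labels with torch.stack, which rejects the Python ints that follow.

```python
import torch
from torch.utils.data import default_collate

# Illustrative labels: ints as CIFAR10 returns them, 0-dim tensors as FakeData did
int_labels = [0, 1, 2, 3]
tensor_labels = [torch.tensor(0), torch.tensor(2)]

# Homogeneous int labels collate into a 1-D int64 tensor
ok = default_collate(int_labels)
print(ok)  # tensor([0, 1, 2, 3])

# Tensor first, then ints: torch.stack raises a TypeError
try:
    default_collate(tensor_labels + int_labels)
    failed = False
except TypeError:
    failed = True
print(failed)  # True
```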

I can work around this by passing target_transform=int when creating fake_dataset. However, the problem is very hard to debug. I think the default target type in the FakeData source code should be int instead of a long tensor.
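The target_transform=int workaround can be sketched without torchvision. The dataset classes below are hypothetical stand-ins; TensorLabelDataset applies target_transform to the label inside __getitem__, mirroring how torchvision datasets handle that argument. With int applied, every label in the concatenated dataset is a Python int and the batch collates into a single int64 tensor.

```python
import torch
from torch.utils.data import ConcatDataset, DataLoader, Dataset


class TensorLabelDataset(Dataset):
    """Hypothetical FakeData stand-in with 0-dim tensor labels."""

    def __init__(self, size=2, num_classes=10, target_transform=None):
        self.size = size
        self.num_classes = num_classes
        self.target_transform = target_transform

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        target = torch.randint(0, self.num_classes, size=(1,), dtype=torch.long)[0]
        # Apply the label transform, as torchvision datasets do
        if self.target_transform is not None:
            target = self.target_transform(target)
        return torch.zeros(3, 4, 4), target


class IntLabelDataset(Dataset):
    """Hypothetical CIFAR10 stand-in with plain int labels."""

    def __len__(self):
        return 2

    def __getitem__(self, idx):
        return torch.zeros(3, 4, 4), idx


# int() turns each 0-dim tensor label into a Python int
fixed = ConcatDataset([TensorLabelDataset(target_transform=int), IntLabelDataset()])
loader = DataLoader(fixed, batch_size=4, shuffle=False)
images, labels = next(iter(loader))
print(images.shape, labels.dtype)  # torch.Size([4, 3, 4, 4]) torch.int64
```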

Here: https://pytorch.org/vision/0.8/_modules/torchvision/datasets/fakedata.html#FakeData, FakeData.__getitem__ computes target = torch.randint(0, self.num_classes, size=(1,), dtype=torch.long)[0], which is a (0-dim) long tensor. It should be an int.

cc @pmeier @fmassa @vfdev-5

fmassa commented 3 years ago

Yes, the target should be an int. That was the case in the first version of torchvision, but with the introduction of scalar (0-dim) tensors in PyTorch, that snippet started returning a tensor.

I'm happy to accept a PR that adds a .item() call to the aforementioned line to fix the issue.
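A minimal sketch of what the suggested change does to the quoted expression, assuming only PyTorch (num_classes is an illustrative value). Indexing the 1-element tensor returned by torch.randint gives a 0-dim tensor; appending .item() converts it to a plain Python int, which collates the same way as CIFAR10's labels. This illustrates the idea only; it is not the code of the merged PR.

```python
import torch

num_classes = 10  # illustrative value

# Expression quoted in the issue: a 0-dim long tensor
before = torch.randint(0, num_classes, size=(1,), dtype=torch.long)[0]

# Same expression with .item() appended: a plain Python int
after = torch.randint(0, num_classes, size=(1,), dtype=torch.long)[0].item()

print(type(before).__name__, before.dim())  # Tensor 0
print(type(after).__name__)                 # int
```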

avijit9 commented 3 years ago

Is anybody working on this? Otherwise, I can send a PR.

pmeier commented 3 years ago

@avijit9 Go ahead!

avijit9 commented 3 years ago

@pmeier Shouldn't this issue be closed?

pmeier commented 3 years ago

Indeed it should. If the PR description contains one of GitHub's closing keywords (such as closes, fixes, or resolves) together with the issue number, GitHub closes the issue automatically when the PR is merged.

You used a keyword in your PR that is not recognized by GitHub:

solves #3517