Closed zhifengkong closed 3 years ago
Yes, the target should be an `int`, and it was the case in the first version of torchvision, but with the introduction of scalar tensors in PyTorch that snippet started returning a tensor.
I'm happy to accept a PR adding a `.item()` call to the aforementioned line to fix the issue.
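As a standalone sketch of the proposed fix (with `self.num_classes` replaced by a local variable for illustration):

```python
import torch

num_classes = 10  # stands in for self.num_classes in FakeData

# current behavior: indexing the 1-element tensor yields a 0-dim tensor
target = torch.randint(0, num_classes, size=(1,), dtype=torch.long)[0]
assert isinstance(target, torch.Tensor)

# proposed fix: .item() converts the scalar tensor to a plain Python int
target = torch.randint(0, num_classes, size=(1,), dtype=torch.long).item()
assert isinstance(target, int)
```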
Anybody working on this? I can send a PR otherwise.
@avijit9 Go ahead!
@pmeier Shouldn't this issue be closed?
Indeed it should. If the PR contains a certain keyword together with the issue number, GitHub will close the issue automatically when the PR is merged.
You used a keyword in your PR that is not recognized by GitHub: `solves #3517`. GitHub only auto-closes issues for the keywords `close`/`closes`/`closed`, `fix`/`fixes`/`fixed`, and `resolve`/`resolves`/`resolved`, e.g. `Fixes #3517`.
🐛 Bug
If you concatenate a dataset such as CIFAR10 with FakeData, you get an error.
To Reproduce
Steps to reproduce the behavior:
```python
cifar_dataset = torchvision.datasets.CIFAR10(...)
fake_dataset = torchvision.datasets.FakeData(...)
train_data = ConcatDataset([cifar_dataset, fake_dataset])
train_loader = DataLoader(train_data, ...)
for data in train_loader:
    ...
```

then an error is raised.

Additional context
The reason this happens is that the labels in CIFAR10 are `int`s while the labels in FakeData are tensors. When they are concatenated and collated into a batch, the batch labels look like `[0, 1, 2, 3, tensor(0), 3, 4, 5, 6, tensor(2), ...]`.
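The type mismatch can be reproduced without the datasets themselves: the default collate function batches tensor labels by stacking them, which breaks when plain ints are mixed in (a minimal sketch of the mismatch, not the exact `default_collate` code path):

```python
import torch

# labels as they come out of the concatenated dataset: plain ints from
# CIFAR10 mixed in with 0-dim long tensors from FakeData
labels = [0, 1, 2, 3, torch.tensor(0), torch.tensor(2)]

# stacking requires every element to be a tensor, so the mix fails
try:
    torch.stack(labels)
    mixed_ok = True
except TypeError as exc:
    mixed_ok = False
    print(f"batching mixed labels fails: {exc}")
```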
I can work around this bug by passing

target_transform=int

when I load `fake_dataset`. However, the failure is very hard to debug, so I think the default target type in the FakeData source code should be `int` instead of a long tensor. Here: https://pytorch.org/vision/0.8/_modules/torchvision/datasets/fakedata.html#FakeData, in `__getitem__`:

```python
target = torch.randint(0, self.num_classes, size=(1,), dtype=torch.long)[0]
```

This is a long tensor; it should be an `int`.

cc @pmeier @fmassa @vfdev-5