steven5clu884 opened this issue 3 years ago
You need to switch your PyTorch version to 0.4.
Alternatively, you can change the CPDataset class in cp_dataset.py from:
self.transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
to:
self.transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))])
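Why the three-value Normalize breaks on a single-channel image: in recent torchvision releases, Normalize turns mean and std into tensors of shape [C, 1, 1] and subtracts them with an in-place op (on a copy when inplace=False). The dependency-free sketch below mimics the NumPy/PyTorch broadcasting rule to show why a [1, 256, 192] tensor cannot absorb a three-value mean in place, while a one-value mean works for both masks and RGB images. The helper names broadcast_shape and inplace_ok are hypothetical, not torchvision functions.

```python
from itertools import zip_longest


def broadcast_shape(a, b):
    """Broadcast two shapes NumPy/PyTorch-style, or raise ValueError.

    Shapes are aligned from the trailing dimension; each pair of sizes
    must be equal, or one of them must be 1.
    """
    result = []
    for x, y in zip_longest(reversed(a), reversed(b), fillvalue=1):
        if x != y and 1 not in (x, y):
            raise ValueError(f"shapes {tuple(a)} and {tuple(b)} are not broadcastable")
        result.append(y if x == 1 else x)
    return tuple(reversed(result))


def inplace_ok(tensor_shape, stat_shape):
    """An in-place op such as sub_() cannot change the tensor's shape,
    so it only succeeds when broadcasting leaves that shape unchanged."""
    try:
        return broadcast_shape(tensor_shape, stat_shape) == tuple(tensor_shape)
    except ValueError:
        return False


mask = (1, 256, 192)  # single-channel tensor produced by ToTensor()
rgb = (3, 256, 192)   # three-channel image tensor
print(broadcast_shape(mask, (3, 1, 1)))  # (3, 256, 192)
print(inplace_ok(mask, (3, 1, 1)))       # False: three-value mean on a mask
print(inplace_ok(mask, (1, 1, 1)))       # True: Normalize((0.5,), (0.5,))
print(inplace_ok(rgb, (1, 1, 1)))        # True: one value also covers RGB
```

The first False case corresponds to the "output with shape [1, 256, 192] doesn't match the broadcast shape [3, 256, 192]" RuntimeError reported in this thread.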
Hi,
I am having the same problem. I couldn't find torch==0.4; I believe it is deprecated (please correct me if I'm wrong). Changing the normalization code didn't work either: I still get the same error.
Thanks in advance!
If you have updated PyTorch, .view needs to be replaced with .reshape in network.py on line 135.
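Background on the .view/.reshape change: .view() never copies data, so it only works when the requested shape is compatible with the tensor's existing strides, whereas .reshape() returns a view when it can and copies otherwise. The sketch below is a simplified pure-Python model of the common case, a row-major (contiguous) layout versus a transposed one. It is an illustration only: PyTorch's real compatibility check is more permissive than a plain contiguity test, and the helper names are hypothetical.

```python
def contiguous_strides(shape):
    """Row-major strides (in elements) for a tensor of the given shape."""
    strides = []
    step = 1
    for size in reversed(shape):
        strides.append(step)
        step *= size
    return tuple(reversed(strides))


def is_contiguous(shape, strides):
    """Simplified check: the layout matches row-major order exactly."""
    return tuple(strides) == contiguous_strides(shape)


shape = (2, 3, 4)
strides = contiguous_strides(shape)  # (12, 4, 1)
print(is_contiguous(shape, strides))  # True: x.view(x.size(0), -1) works

# Swapping the last two dims (like x.transpose(1, 2)) swaps shape and strides
# but moves no data, so the result is no longer row-major.
t_shape = (shape[0], shape[2], shape[1])          # (2, 4, 3)
t_strides = (strides[0], strides[2], strides[1])  # (12, 1, 4)
print(is_contiguous(t_shape, t_strides))  # False: view() would fail, reshape() copies
```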
In cp_dataset.py, line 30, replace:
self.transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
With this line:
self.transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))])
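For reference, Normalize(mean, std) computes (x - mean) / std per channel, so with mean = std = 0.5 the [0, 1] output of ToTensor() is mapped to [-1, 1], matching the "# [-1,1]" comment in cp_dataset.py. A dependency-free sketch using a hypothetical normalize_channels helper on nested lists instead of tensors; like torchvision's broadcasting, a single mean/std value is reused for every channel:

```python
def normalize_channels(image, mean, std):
    """Apply (x - mean[c]) / std[c] to each channel c of a [C][H][W] nested list.

    A single mean/std value is reused for every channel, mirroring how a
    one-element mean broadcasts in torchvision.
    """
    n = len(image)
    if len(mean) == 1:
        mean = tuple(mean) * n
    if len(std) == 1:
        std = tuple(std) * n
    if len(mean) != n or len(std) != n:
        raise ValueError(f"{n} channel(s) but {len(mean)} mean / {len(std)} std values")
    return [
        [[(x - m) / s for x in row] for row in channel]
        for channel, m, s in zip(image, mean, std)
    ]


mask = [[[0.0, 0.5, 1.0]]]  # one channel, one row, three pixels
print(normalize_channels(mask, (0.5,), (0.5,)))  # [[[-1.0, 0.0, 1.0]]]
```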
Also, in network.py, line 35, replace:
x = x.view(x.size(0), -1)
With this line:
x = x.reshape(x.size(0), -1)
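In x.reshape(x.size(0), -1), the -1 asks PyTorch to infer that dimension from the total element count, which flattens everything after the batch dimension. A pure-Python sketch of that inference rule; infer_shape is a hypothetical helper and the example sizes are made up:

```python
from functools import reduce
from operator import mul


def numel(shape):
    """Number of elements in a tensor of the given shape."""
    return reduce(mul, shape, 1)


def infer_shape(shape, target):
    """Resolve at most one -1 in `target` so the element count matches `shape`."""
    target = tuple(target)
    if target.count(-1) > 1:
        raise ValueError("only one dimension can be inferred")
    total = numel(shape)
    known = numel(d for d in target if d != -1)
    if -1 not in target:
        if known != total:
            raise ValueError(f"cannot reshape {tuple(shape)} into {target}")
        return target
    if known == 0 or total % known:
        raise ValueError(f"cannot reshape {tuple(shape)} into {target}")
    return tuple(total // known if d == -1 else d for d in target)


x_shape = (4, 512, 4, 3)  # made-up [batch, channels, height, width]
print(infer_shape(x_shape, (x_shape[0], -1)))  # (4, 6144)
```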
I cloned the repo, created an Anaconda environment, and then ran python test.py. My checkpoints directory looks like this:
checkpoints
├── GMM
│   └── gmm_final.pth
└── TOM
    └── tom_final.pth
My traceback looks like this:
Traceback (most recent call last):
File "test.py", line 229, in <module>
main()
File "test.py", line 215, in main
test_gmm(opt, test_loader, model, board)
File "test.py", line 86, in test_gmm
for step, inputs in enumerate(test_loader.data_loader):
File "/home/wizzerking/develop/anaconda3/envs/cpvton/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 521, in __next__
data = self._next_data()
File "/home/wizzerking/develop/anaconda3/envs/cpvton/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1203, in _next_data
return self._process_data(data)
File "/home/wizzerking/develop/anaconda3/envs/cpvton/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1229, in _process_data
normalize1: torch.Size([3, 256, 192])
data.reraise()
File "/home/wizzerking/develop/anaconda3/envs/cpvton/lib/python3.7/site-packages/torch/_utils.py", line 434, in reraise
raise exception
RuntimeError: Caught RuntimeError in DataLoader worker process 0.
Original Traceback (most recent call last):
File "/home/wizzerking/develop/anaconda3/envs/cpvton/lib/python3.7/site-packages/torch/utils/data/_utils/worker.py", line 287, in _worker_loop
data = fetcher.fetch(index)
File "/home/wizzerking/develop/anaconda3/envs/cpvton/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 49, in fetch
data = [self.dataset[idx] for idx in possibly_batched_index]
File "/home/wizzerking/develop/anaconda3/envs/cpvton/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 49, in <listcomp>
data = [self.dataset[idx] for idx in possibly_batched_index]
File "/home/wizzerking/develop/cp-vton-plus/cp_dataset.py", line 150, in __getitem__
shape_ori = self.transform(parse_shape_ori) # [-1,1]
File "/home/wizzerking/develop/anaconda3/envs/cpvton/lib/python3.7/site-packages/torchvision/transforms/transforms.py", line 61, in __call__
img = t(img)
File "/home/wizzerking/develop/anaconda3/envs/cpvton/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "/home/wizzerking/develop/anaconda3/envs/cpvton/lib/python3.7/site-packages/torchvision/transforms/transforms.py", line 226, in forward
return F.normalize(tensor, self.mean, self.std, self.inplace)
File "/home/wizzerking/develop/anaconda3/envs/cpvton/lib/python3.7/site-packages/torchvision/transforms/functional.py", line 352, in normalize
print("normalize1:",tensor.sub(mean).div_(std).shape)
RuntimeError: output with shape [1, 256, 192] doesn't match the broadcast shape [3, 256, 192]
Things I have tried: in cp_dataset.py, adding .convert('RGB') after each call to Image.open. This does not help. Adding it to the person image does not help either.
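Reading the traceback: the failing call is self.transform(parse_shape_ori) at cp_dataset.py line 150, and the tensor it receives has shape [1, 256, 192], so the body-shape mask has one channel while the transform carries three-value statistics. If that mask is built separately rather than read via Image.open (an assumption based only on its name), converting the opened images to RGB would not change its channel count. One way to avoid hard-coding the count is to derive the statistics from the image mode. The helper and mode table below are a hypothetical sketch covering only a few common PIL modes:

```python
# Channels produced by ToTensor() for a few common PIL image modes.
CHANNELS_BY_MODE = {"L": 1, "RGB": 3, "RGBA": 4}


def normalize_stats(mode, mean=0.5, std=0.5):
    """Return (mean, std) tuples with one entry per channel of `mode`.

    Hypothetical helper: the result could be passed as
    transforms.Normalize(*normalize_stats(img.mode)).
    """
    try:
        n = CHANNELS_BY_MODE[mode]
    except KeyError:
        raise ValueError(f"unsupported image mode: {mode!r}") from None
    return (mean,) * n, (std,) * n


print(normalize_stats("L"))    # ((0.5,), (0.5,))
print(normalize_stats("RGB"))  # ((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
```

With uniform 0.5 statistics this is equivalent to the one-value Normalize((0.5,), (0.5,)) suggested earlier in the thread; a per-mode table only matters if different channels need different values.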