minar09 / cp-vton-plus

Official implementation for "CP-VTON+: Clothing Shape and Texture Preserving Image-Based Virtual Try-On", CVPRW 2020
https://minar09.github.io/cpvtonplus/
MIT License

RuntimeError: output with shape [1, 256, 192] doesn't match the broadcast shape [3, 256, 192] #81

Open steven5clu884 opened 2 years ago

steven5clu884 commented 2 years ago

I cloned the repo, created an Anaconda environment, and then ran python test.py. My checkpoints are laid out as:

    checkpoints
    ├── GMM
    │   └── gmm_final.pth
    └── TOM
        └── tom_final.pth

The traceback is:

    Traceback (most recent call last):
      File "test.py", line 229, in <module>
        main()
      File "test.py", line 215, in main
        test_gmm(opt, test_loader, model, board)
      File "test.py", line 86, in test_gmm
        for step, inputs in enumerate(test_loader.data_loader):
      File "/home/wizzerking/develop/anaconda3/envs/cpvton/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 521, in __next__
        data = self._next_data()
      File "/home/wizzerking/develop/anaconda3/envs/cpvton/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1203, in _next_data
        return self._process_data(data)
      File "/home/wizzerking/develop/anaconda3/envs/cpvton/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1229, in _process_data
    normalize1: torch.Size([3, 256, 192])
        data.reraise()
      File "/home/wizzerking/develop/anaconda3/envs/cpvton/lib/python3.7/site-packages/torch/_utils.py", line 434, in reraise
        raise exception
    RuntimeError: Caught RuntimeError in DataLoader worker process 0.
    Original Traceback (most recent call last):
      File "/home/wizzerking/develop/anaconda3/envs/cpvton/lib/python3.7/site-packages/torch/utils/data/_utils/worker.py", line 287, in _worker_loop
        data = fetcher.fetch(index)
      File "/home/wizzerking/develop/anaconda3/envs/cpvton/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 49, in fetch
        data = [self.dataset[idx] for idx in possibly_batched_index]
      File "/home/wizzerking/develop/anaconda3/envs/cpvton/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 49, in <listcomp>
        data = [self.dataset[idx] for idx in possibly_batched_index]
      File "/home/wizzerking/develop/cp-vton-plus/cp_dataset.py", line 150, in __getitem__
        shape_ori = self.transform(parse_shape_ori) # [-1,1]
      File "/home/wizzerking/develop/anaconda3/envs/cpvton/lib/python3.7/site-packages/torchvision/transforms/transforms.py", line 61, in __call__
        img = t(img)
      File "/home/wizzerking/develop/anaconda3/envs/cpvton/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
        return forward_call(*input, **kwargs)
      File "/home/wizzerking/develop/anaconda3/envs/cpvton/lib/python3.7/site-packages/torchvision/transforms/transforms.py", line 226, in forward
        return F.normalize(tensor, self.mean, self.std, self.inplace)
      File "/home/wizzerking/develop/anaconda3/envs/cpvton/lib/python3.7/site-packages/torchvision/transforms/functional.py", line 352, in normalize
        print("normalize1:",tensor.sub(mean).div_(std).shape)
    RuntimeError: output with shape [1, 256, 192] doesn't match the broadcast shape [3, 256, 192]

Things I have tried: in cp_dataset.py, I added .convert('RGB') after each call to Image.open, for instance:

    if self.stage == 'GMM':
        c = Image.open(osp.join(self.data_path, 'cloth', c_name))
        c = c.convert('RGB')

        cm = Image.open(osp.join(self.data_path, 'cloth-mask', c_name)).convert('L')
        cm = cm.convert('RGB')

This does not help. I also tried the same conversion on the person image:

    im = Image.open(osp.join(self.data_path, 'image', im_name))
    im = im.convert('RGB')

This does not help either.

rocketeerli commented 2 years ago

You need to change your PyTorch version to 0.4. Alternatively, in the CPDataset class in cp_dataset.py, change

    self.transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

to

    self.transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))])

slaifan commented 2 years ago

Hi,

I am having the same problem. I couldn't find torch==0.4; I believe it is deprecated (please correct me if I'm wrong). Changing the normalization code didn't work either: I still get the same error.

Thanks in advance!

zakmicallef commented 1 year ago

If you have updated PyTorch, .view needs to be replaced with .reshape in network.py, line 135.
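The comment does not say which error .view raised. A common one after upgrading is that .view only succeeds when the requested shape is compatible with the tensor's existing memory layout (its strides), whereas .reshape returns a view when it can and copies the data otherwise. A self-contained illustration with a deliberately non-contiguous tensor (not the actual tensor from network.py):

```python
import torch

# A (2, 3, 4, 5) tensor made non-contiguous by swapping dimensions 1 and 2.
x = torch.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5).transpose(1, 2)
print(x.is_contiguous())  # False

# Flatten everything except the batch dimension, as in network.py.
try:
    x.view(x.size(0), -1)
    view_ok = True
except RuntimeError:
    view_ok = False  # the strides cannot be merged into a single dimension
print(view_ok)  # False

# reshape falls back to copying the data when a view is impossible.
flat = x.reshape(x.size(0), -1)
print(flat.shape)  # torch.Size([2, 60])

# The copy holds the same values as an explicit contiguous() followed by view().
print(torch.equal(flat, x.contiguous().view(x.size(0), -1)))  # True
```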

niranjanakella commented 1 year ago

In cp_dataset.py, line 30

Replace:

    self.transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

With:

    self.transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))])

Also, in network.py, line 35, replace:

x = x.view(x.size(0), -1)

With this line:

x = x.reshape(x.size(0), -1)