CDOTAD / AlphaGAN-Matting

This project is an unofficial implementation of "AlphaGAN: Generative adversarial networks for natural image matting", published at BMVC 2018.

Trimap and alpha images do not match the network structure #8

Closed shartoo closed 5 years ago

shartoo commented 5 years ago
WARNING:root:Setting up a new session...
INFO:tornado.access:200 POST /env/alphaGAN (127.0.0.1) 0.28ms
WARNING:visdom:Without the incoming socket you cannot receive events from the server or register event handlers to your Visdom client.
0it [00:00, ?it/s]Traceback (most recent call last):
  File "alphaGAN_train.py", line 80, in <module>
    main()
  File "alphaGAN_train.py", line 76, in main
    gan.train(dataset)
  File "/home/xxx/work/AlphaGANMatting/model/AlphaGAN.py", line 462, in train
    for ii, data in tqdm.tqdm(enumerate(dataset)):
  File "/usr/local/anaconda3/lib/python3.6/site-packages/tqdm/_tqdm.py", line 1022, in __iter__
    for obj in iterable:
  File "/home/xxx/work/AlphaGANMatting/data/__init__.py", line 23, in __iter__
    for i, data in enumerate(self.dataloader):
  File "/usr/local/anaconda3/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 637, in __next__
    return self._process_next_batch(batch)
  File "/usr/local/anaconda3/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 658, in _process_next_batch
    raise batch.exc_type(batch.exc_msg)
RuntimeError: Traceback (most recent call last):
  File "/usr/local/anaconda3/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 138, in _worker_loop
    samples = collate_fn([dataset[i] for i in batch_indices])
  File "/usr/local/anaconda3/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 138, in <listcomp>
    samples = collate_fn([dataset[i] for i in batch_indices])
  File "/home/xxx/work/AlphaGANMatting/data/input_dataset.py", line 98, in __getitem__
    T = self.transform(trimap_img)
  File "/usr/local/anaconda3/lib/python3.6/site-packages/torchvision/transforms/transforms.py", line 60, in __call__
    img = t(img)
  File "/usr/local/anaconda3/lib/python3.6/site-packages/torchvision/transforms/transforms.py", line 163, in __call__
    return F.normalize(tensor, self.mean, self.std, self.inplace)
  File "/usr/local/anaconda3/lib/python3.6/site-packages/torchvision/transforms/functional.py", line 208, in normalize
    tensor.sub_(mean[:, None, None]).div_(std[:, None, None])
RuntimeError: output with shape [1, 2278, 3138] doesn't match the broadcast shape [3, 2278, 3138]

I downloaded the original images from http://alphamatting.com/datasets.php. Is there something I need to change?

The code that raises the error:

        input_img = Image.open(input_path).convert('RGB')
        trimap_img = Image.open(trimap_path)
        alpha_img = Image.open(alpha_path)
        bg_img = Image.open(bg_path).convert('RGB')
        fg_img = Image.open(fg_path).convert('RGB')

        #x, y = random_choice(trimap_img)

        I = self.transform(input_img)
        T = self.transform(trimap_img)
        A = self.transform(alpha_img)
        B = self.transform(bg_img)
        F = self.transform(fg_img)
CDOTAD commented 5 years ago

Try this:

     trimap_img = Image.open(trimap_path).convert('L')
     alpha_img = Image.open(alpha_path).convert('L')
shartoo commented 5 years ago

That doesn't work either.

CDOTAD commented 5 years ago

Which version of PyTorch are you using? This explanation may be the right one: output with shape [1, 28, 28] doesn't match the broadcast shape [3, 28, 28]
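
The shape mismatch in that message can be reproduced without PyTorch. The sketch below is a NumPy analog, not the torchvision code itself. It assumes only what the traceback shows: the normalization subtracts `mean[:, None, None]` in place (`tensor.sub_`). The array names and sizes are made up for illustration.

```python
import numpy as np

# Stand-in for a single-channel trimap in (C, H, W) layout (hypothetical size).
trimap = np.zeros((1, 4, 5), dtype=np.float32)

# Three per-channel means, reshaped like mean[:, None, None] in the traceback.
mean = np.array([0.5, 0.5, 0.5], dtype=np.float32)[:, None, None]

# Out-of-place subtraction broadcasts the single channel up to three channels.
out = trimap - mean
print(out.shape)

# In-place subtraction has to write into the 1-channel buffer, so it fails.
try:
    trimap -= mean
    failed = False
except ValueError as exc:
    failed = True
    print("in-place broadcast failed:", exc)
```

The in-place operation cannot enlarge its output from one channel to three, which is why a 1-channel image combined with a 3-element mean raises an error rather than being silently expanded.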

shartoo commented 5 years ago

According to that explanation, lines 76-79 of input_dataset.py should be changed to

        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
        ])

But I still get the same mismatch error. I am using PyTorch 1.0.

CDOTAD commented 5 years ago

Try the input_dataset.py from the latest update.

shartoo commented 5 years ago

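I remember making this same change yesterday and posting it here, so I'm not sure why it disappeared. In any case, modifying the code to follow your latest input_dataset.py fixes the problem.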