eriklindernoren / PyTorch-GAN

PyTorch implementations of Generative Adversarial Networks.
MIT License

Can I train on grayscale images with MUNIT? #170

Open edwardcho opened 2 years ago

edwardcho commented 2 years ago

Hello Sir,

I'm interested in image-to-image translation, so I tried your code on my own dataset.

My dataset is as follows:

  1. grayscale (1 channel)
  2. 256 x 256

When I start training, I hit the following error.

Namespace(b1=0.5, b2=0.999, batch_size=4, channels=1, checkpoint_interval=-1, dataset_name='noise2clip', decay_epoch=2, dim=64, epoch=0, img_height=256, img_width=256, lr=0.0001, n_cpu=8, n_downsample=2, n_epochs=4, n_residual=3, sample_interval=400, style_dim=8)
/home/itsme/anaconda3/lib/python3.7/site-packages/torchvision/transforms/transforms.py:288: UserWarning: Argument interpolation should be of type InterpolationMode instead of int. Please, use InterpolationMode enum.
  "Argument interpolation should be of type InterpolationMode instead of int. "
../../data/noise2clip/trainA
../../data/noise2clip/valA
Traceback (most recent call last):
  File "munit.py", line 171, in <module>
    for i, batch in enumerate(dataloader):
  File "/home/itsme/anaconda3/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 521, in __next__
    data = self._next_data()
  File "/home/itsme/anaconda3/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1203, in _next_data
    return self._process_data(data)
  File "/home/itsme/anaconda3/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1229, in _process_data
    data.reraise()
  File "/home/itsme/anaconda3/lib/python3.7/site-packages/torch/_utils.py", line 434, in reraise
    raise exception
RuntimeError: Caught RuntimeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/itsme/anaconda3/lib/python3.7/site-packages/torch/utils/data/_utils/worker.py", line 287, in _worker_loop
    data = fetcher.fetch(index)
  File "/home/itsme/anaconda3/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 49, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/itsme/anaconda3/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 49, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/data/TESTBOARD/additional_networks/generation/PyTorch-GAN_eriklindernoren/implementations/munit/datasets.py", line 40, in __getitem__
    img_A = self.transform(img_A)
  File "/home/itsme/anaconda3/lib/python3.7/site-packages/torchvision/transforms/transforms.py", line 61, in __call__
    img = t(img)
  File "/home/itsme/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/itsme/anaconda3/lib/python3.7/site-packages/torchvision/transforms/transforms.py", line 226, in forward
    return F.normalize(tensor, self.mean, self.std, self.inplace)
  File "/home/itsme/anaconda3/lib/python3.7/site-packages/torchvision/transforms/functional.py", line 351, in normalize
    tensor.sub_(mean).div_(std)
RuntimeError: output with shape [1, 256, 256] doesn't match the broadcast shape [3, 256, 256]

How can I train in my case (grayscale data with MUNIT)?

Thanks. Edward Cho.

Ignoramus-Sage commented 2 years ago

This is my hypothesis: the error stems from your images having only 1 channel along the channel dimension. What you can do is add two channels, i.e. stack the same image along the channel axis, duplicating the values you already have (R = G = B = gray). That should run, though I can't speak to performance since I'm not familiar with the paper or the code.

algocompretto commented 2 years ago

Your transform pipeline is set up for RGB (a 3-channel Normalize), but your images are grayscale. Since you run with channels=1, make the whole data loading pipeline 1-channel instead:

import torch
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder

transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),  # single-channel mean/std
])

dataloader = torch.utils.data.DataLoader(
    ImageFolder("dataset_folder_path/", transform=transform),
    batch_size=16,
    shuffle=True,
)