autonomousvision / stylegan-xl

[SIGGRAPH'22] StyleGAN-XL: Scaling StyleGAN to Large Diverse Datasets
MIT License

RuntimeError: output with shape [8, 1, 1024, 1024] doesn't match the broadcast shape [8, 3, 1024, 1024] #101

Open · fonzen1 opened this issue 1 year ago

fonzen1 commented 1 year ago

I want to train a StyleGAN-XL model on grayscale images from a custom dataset that has no labels. However, the model does not seem to support grayscale (single-channel) images: training crashes when the discriminator first runs. Where exactly do I need to change the code so that it accepts grayscale images?

Full error message:

Traceback (most recent call last):
  File "/usr/local/lib/python3.8/dist-packages/torch/multiprocessing/spawn.py", line 69, in _wrap
    fn(i, *args)
  File "/masters/stylegan-xl/train.py", line 49, in subprocess_fn
    training_loop.training_loop(rank=rank, **c)
  File "/stylegan-xl/training/training_loop.py", line 208, in training_loop
    misc.print_module_summary(D, [img, c])
  File "/masters/stylegan-xl/torch_utils/misc.py", line 236, in print_module_summary
    outputs = module(*inputs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1445, in _call_impl
    result = forward_call(*input, **kwargs)
  File "/masters/stylegan-xl/pg_modules/discriminator.py", line 207, in forward
    x_n = Normalize(feat.normstats['mean'], feat.normstats['std'])(x_aug)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1423, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torchvision/transforms/transforms.py", line 270, in forward
    return F.normalize(tensor, self.mean, self.std, self.inplace)
  File "/usr/local/lib/python3.8/dist-packages/torchvision/transforms/functional.py", line 363, in normalize
    return F_t.normalize(tensor, mean=mean, std=std, inplace=inplace)
  File "/usr/local/lib/python3.8/dist-packages/torchvision/transforms/functional_tensor.py", line 922, in normalize
    return tensor.sub_(mean).div_(std)
RuntimeError: output with shape [8, 1, 1024, 1024] doesn't match the broadcast shape [8, 3, 1024, 1024]
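The last frame shows the cause. torchvision's `normalize` runs the in-place `tensor.sub_(mean)`, and the feature network's statistics (`feat.normstats`) have 3 entries, one per RGB channel. Broadcasting a `[8, 1, H, W]` batch against a 3-element mean produces `[8, 3, H, W]`, but an in-place op must write into the original 1-channel tensor, so PyTorch raises. The sketch below reproduces this and shows one possible workaround: replicating the gray channel three times before normalization. The mean/std values and the `to_rgb` helper are illustrative assumptions, not taken from StyleGAN-XL.

```python
import torch

# Illustrative ImageNet-style per-channel statistics (assumed values), shaped
# [3, 1, 1] so they broadcast over [N, C, H, W] like torchvision's Normalize.
MEAN = torch.tensor([0.485, 0.456, 0.406]).view(-1, 1, 1)
STD = torch.tensor([0.229, 0.224, 0.225]).view(-1, 1, 1)


def normalize_inplace(x: torch.Tensor) -> torch.Tensor:
    """Mimic torchvision's tensor normalize: x.sub_(mean).div_(std)."""
    return x.sub_(MEAN).div_(STD)


def to_rgb(x: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: copy a single gray channel into 3 identical channels.

    repeat() allocates real memory, so later in-place ops are safe.
    Tensors that already have 3 channels are returned unchanged.
    """
    if x.shape[1] == 1:
        return x.repeat(1, 3, 1, 1)
    return x


gray = torch.rand(8, 1, 16, 16)  # small stand-in for an [8, 1, 1024, 1024] batch

# Reproduce the failure: the in-place subtraction cannot grow dim 1 from 1 to 3.
try:
    normalize_inplace(gray.clone())
    failed = False
except RuntimeError as err:
    failed = True
    print(err)

# Workaround: expand to 3 channels first, then normalize.
out = normalize_inplace(to_rgb(gray))
print(out.shape)
```

`repeat` is used instead of `expand` because `expand` returns a view in which all three channels share one memory location, and PyTorch rejects in-place writes to such tensors. In StyleGAN-XL, one place this might be applied is on `x_aug` just before the `Normalize(...)` call in `pg_modules/discriminator.py` shown in the traceback; that is an untested assumption. Another option is to convert the dataset images to RGB before training so every component sees 3 channels.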