nupurkmr9 / vision-aided-gan

Ensembling Off-the-shelf Models for GAN Training (CVPR 2022 Oral)
https://www.cs.cmu.edu/~vision-aided-gan/
MIT License

Types of input/output in each discriminator #3

Closed naoki7090624 closed 2 years ago

naoki7090624 commented 2 years ago

Hi,

Thank you for sharing this great work! I would like to use the pretrained discriminator together with my from-scratch discriminator to improve my model. I added the vision-aided-gan discriminator with cv_type set to swin, vgg, or clip. (My code is structured similarly to edge-connect.)

self.discr = vision_aided_loss.Discriminator(cv_type='swin', loss_type='sigmoid', device=config.DEVICE).to(config.DEVICE)
self.discr.cv_ensemble.requires_grad_(False)  # freeze the pretrained feature extractor

When I input the generated images (B×C×H×W) and the ground-truth images into the discriminator, I got the following lossD from vgg and swin.

tensor([[1.3401],
        [1.3370],
        [1.2983],
        [1.2942],
        [1.1943],
        [1.3307],
        [1.2072],
        [1.2092]], device='cuda:0', grad_fn=<AddBackward0>)

I could backpropagate it by taking the average:

dis_loss = dis_real_loss + dis_fake_loss + torch.mean(lossD)
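
For reference, the full discriminator step I am running looks roughly like this (dis_real_loss and dis_fake_loss come from my scratch discriminator, so they are only placeholders here):

import torch
import vision_aided_loss

device = 'cuda'

# placeholder batches standing in for my real / generated images (B×C×H×W)
dis_input_real = torch.randn(8, 3, 256, 256, device=device)
dis_input_fake = torch.randn(8, 3, 256, 256, device=device)

discr = vision_aided_loss.Discriminator(cv_type='swin', loss_type='sigmoid', device=device).to(device)
discr.cv_ensemble.requires_grad_(False)  # freeze the pretrained feature extractor

# per-sample losses from the vision-aided discriminator (the [B, 1] tensor shown above)
lossD = discr(dis_input_real, for_real=True) + discr(dis_input_fake, for_real=False)

# placeholders for my scratch-discriminator losses
dis_real_loss = torch.tensor(0.0, device=device)
dis_fake_loss = torch.tensor(0.0, device=device)

dis_loss = dis_real_loss + dis_fake_loss + torch.mean(lossD)
dis_loss.backward()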

But I got the following error from clip:

Traceback (most recent call last):
  File "/home/naoki/MyProject/src/models.py", line 434, in process
    lossD = self.discr(dis_input_real, for_real=True) + self.discr(dis_input_fake, for_real=False)
  File "/home/naoki/.pyenv/versions/3.8.6/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/naoki/.pyenv/versions/3.8.6/lib/python3.8/site-packages/vision_aided_loss/cv_discriminator.py", line 187, in forward
    return self.loss_type(pred_mask, **kwargs)
  File "/home/naoki/.pyenv/versions/3.8.6/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/naoki/.pyenv/versions/3.8.6/lib/python3.8/site-packages/vision_aided_loss/cv_losses.py", line 104, in forward
    loss_ = self.losses[i](input[i], **kwargs)
  File "/home/naoki/.pyenv/versions/3.8.6/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/naoki/.pyenv/versions/3.8.6/lib/python3.8/site-packages/vision_aided_loss/cv_losses.py", line 21, in forward
    target_ = target.expand_as(input).to(input.device)
TypeError: expand_as(): argument 'other' (position 1) must be Tensor, not list

I am not familiar with these pretrained models. What are the input and output types for each discriminator?

Thank you in advance.

nupurkmr9 commented 2 years ago

Hi, thanks for your interest in our code.

In the case of the CLIP network, the default discriminator architecture is multi-level and returns a list of predictions, so loss_type should be multilevel_sigmoid_s (the single-level sigmoid loss expects a tensor, which is why expand_as() fails with a list). If you want to use the sigmoid_s loss_type with a single-level discriminator architecture, passing output_type='conv' in the arguments should enable that.
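
For example, something along these lines should work (a rough, untested sketch of the two options, reusing the variable names from your snippet):

import vision_aided_loss

device = 'cuda'

# Option 1: keep the default multi-level CLIP discriminator head
discr = vision_aided_loss.Discriminator(cv_type='clip', loss_type='multilevel_sigmoid_s', device=device).to(device)

# Option 2: single-level head via output_type='conv', usable with sigmoid_s
# discr = vision_aided_loss.Discriminator(cv_type='clip', loss_type='sigmoid_s', output_type='conv', device=device).to(device)

discr.cv_ensemble.requires_grad_(False)  # freeze the pretrained CLIP feature extractor

# same call pattern as in your code; reduce with mean/sum if the returned loss is per-sample
lossD = discr(dis_input_real, for_real=True) + discr(dis_input_fake, for_real=False)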

Let me know if this resolves the issue.