PDillis / stylegan3-fun

Modifications of the official PyTorch implementation of StyleGAN3. Let's easily generate images and videos with StyleGAN2/2-ADA/3!

Unidentified AssertionError When Using 'projector.py' #27

Closed YouliangHUANG closed 1 year ago

YouliangHUANG commented 1 year ago

Describe the bug Unidentified AssertionError when I run projector.py.

To Reproduce Steps to reproduce the behavior:

  1. In the root directory of this project, execute this command: "python projector.py --network=my-pretrained-models/StyleGAN2-Ada-DEM1024-CLAHE.pkl --cfg=stylegan2 --target=targets/RiverValley.png"
  2. See error

Expected behavior I'm not sure what should happen, but something is clearly wrong.

Error Information

    Setting up PyTorch plugin "bias_act_plugin"...
    /home/MYUSERID/anaconda3/envs/pytorch180-A100/lib/python3.8/site-packages/scipy/__init__.py:146: UserWarning: A NumPy version >=1.16.5 and <1.23.0 is required for this version of SciPy (detected version 1.23.3)
      warnings.warn(f"A NumPy version >={np_minversion} and <{np_maxversion}"
    projector.py:447: DeprecationWarning: LANCZOS is deprecated and will be removed in Pillow 10 (2023-07-01). Use Resampling.LANCZOS instead.
      target_pil = target_pil.resize((G.img_resolution, G.img_resolution), PIL.Image.LANCZOS)
    Done.
    Projecting in W latent space...
    Starting from W midpoint using 10000 samples...
    Setting up PyTorch plugin "upfirdn2d_plugin"... Done.
    Traceback (most recent call last):
      File "projector.py", line 549, in <module>
        run_projection()  # pylint: disable=no-value-for-parameter
      File "/home/MYUSERID/anaconda3/envs/pytorch180-A100/lib/python3.8/site-packages/click/core.py", line 1128, in __call__
        return self.main(*args, **kwargs)
      File "/home/MYUSERID/anaconda3/envs/pytorch180-A100/lib/python3.8/site-packages/click/core.py", line 1053, in main
        rv = self.invoke(ctx)
      File "/home/MYUSERID/anaconda3/envs/pytorch180-A100/lib/python3.8/site-packages/click/core.py", line 1395, in invoke
        return ctx.invoke(self.callback, **ctx.params)
      File "/home/MYUSERID/anaconda3/envs/pytorch180-A100/lib/python3.8/site-packages/click/core.py", line 754, in invoke
        return __callback(*args, **kwargs)
      File "/home/MYUSERID/anaconda3/envs/pytorch180-A100/lib/python3.8/site-packages/click/decorators.py", line 26, in new_func
        return f(get_current_context(), *args, **kwargs)
      File "projector.py", line 456, in run_projection
        projected_w_steps, run_config = project(
      File "projector.py", line 178, in project
        synth_features = vgg16(synth_images, resize_images=False, return_lpips=True)
      File "/home/MYUSERID/anaconda3/envs/pytorch180-A100/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
        result = self.forward(*input, **kwargs)
      File "<string>", line 71, in forward
    AssertionError

Environment

BTW, I use slurm to submit my jobs to the lab's server. I have successfully trained on my own dataset. The dataset is not human faces; the images are grayscale digital elevation maps (DEMs) with a resolution of 1024x1024. The log doesn't identify the source of this error. Any help in solving it is appreciated.

PDillis commented 1 year ago

Since this error comes from a PyTorch module it's hard to assess, but I can see it originates in the pre-trained VGG16. The easiest way for me to help is if you provide a sample image that you wish to project, regardless of your trained model. I'll try to project it to e.g. FFHQ and see what happens.

YouliangHUANG commented 1 year ago

I think I might know what happened. In projector.py, lines 169-178:

    synth_images = G.synthesis(ws, noise_mode='const') # synth_images here is [1, 1, 256, 256]

    # Downsample image to 256x256 if it's larger than that. VGG was built for 224x224 images.
    synth_images = (synth_images + 1) * (255/2)
    if synth_images.shape[2] > 256:
        synth_images = F.interpolate(synth_images, size=(256, 256), mode='area')

    # Features for synth images.
    if loss_paper == 'sgan2':
        synth_features = vgg16(synth_images, resize_images=False, return_lpips=True) # pretrained VGG16 might require input with a shape of [1, 3, 256, 256]

Since my dataset contains 8-bit grayscale images, the shape of "synth_images" is [1, 1, 256, 256] instead of [1, 3, 256, 256], which triggers the AssertionError. Is there a possible fix for this bug? Can we simply duplicate the contents of the only channel to the other two channels?

Update (2022/11/19): I tried "torch.repeat_interleave()" to duplicate the channel and reshape the tensor to [1, 3, 256, 256], and the AssertionError is no longer triggered. In addition, the corresponding image-saving code needs to be adjusted to handle grayscale images.
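For reference, a minimal sketch of the torch.repeat_interleave() workaround described above (the shapes are assumptions taken from the log):

    import torch

    # Grayscale synthesis output, as reported above: [1, 1, 256, 256]
    synth_images = torch.randn(1, 1, 256, 256)
    if synth_images.shape[1] == 1:
        # Copy the single channel into three channels -> [1, 3, 256, 256],
        # matching the RGB input the pre-trained VGG16 expects
        synth_images = torch.repeat_interleave(synth_images, repeats=3, dim=1)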

PDillis commented 1 year ago

Thanks for the update. Indeed, I hadn't considered projection with models trained on non-RGB data, since this fork also lets you train on RGBA data. I'll expand the code to let you use grayscale models for projection, and I'll let you know when it's done so you can do a git pull and test it.

I will basically use .repeat like so:

    synth_img_shape = synth_images.shape
    if synth_img_shape[1] == 1:
        # Tile the single channel into three -> [N, 3, H, W]
        synth_images = synth_images.repeat(1, 3, 1, 1)

However, if you want to do a pull request with this change (assuming it works), then I'd be happy to accept it :)
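(For what it's worth, for a single-channel input, synth_images.repeat(1, 3, 1, 1) and torch.repeat_interleave(synth_images, 3, dim=1) produce identical [N, 3, H, W] tensors, so either works here.)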

PDillis commented 1 year ago

d28d0af solves this; let me know if there are any more issues (I don't have grayscale models to fully test for correctness).

YouliangHUANG commented 1 year ago

Thanks for the effort. I will open a pull request later to fix the remaining bugs in grayscale image saving.
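For context, a hypothetical sketch of what grayscale-aware image saving could look like (the helper name is an assumption, not the actual PR; the [-1, 1] to uint8 conversion follows the snippet quoted earlier in this thread):

    import torch
    import PIL.Image

    def save_synth_image(synth_image: torch.Tensor, path: str) -> None:
        # Map [1, C, H, W] in [-1, 1] to an [H, W, C] uint8 array
        img = (synth_image + 1) * (255 / 2)
        img = img.permute(0, 2, 3, 1).clamp(0, 255).to(torch.uint8)[0].cpu().numpy()
        if img.shape[-1] == 1:
            PIL.Image.fromarray(img[:, :, 0], 'L').save(path)  # single channel -> mode 'L'
        else:
            PIL.Image.fromarray(img, 'RGB').save(path)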