autonomousvision / stylegan-xl

[SIGGRAPH'22] StyleGAN-XL: Scaling StyleGAN to Large Diverse Datasets
MIT License

TypeError while trying to use the Discriminator #77

Open convexalpha opened 2 years ago

convexalpha commented 2 years ago

I am trying to use the Discriminator in the following way:

import torch
import dnnlib  # from the stylegan_xl repository
import legacy  # from the stylegan_xl repository

device = torch.device('cuda')
network_pkl = 'https://s3.eu-central-1.amazonaws.com/avg-projects/stylegan_xl/models/imagenet256.pkl'

# Load the pickle once and take both networks from it
with dnnlib.util.open_url(network_pkl) as f:
    networks = legacy.load_network_pkl(f)
G = networks['G'].to(device)
D = networks['D'].to(device)

# Batch of 4 latent codes and one-hot labels for ImageNet class 281
z = torch.randn([4, 64], device=device)
label_idx = 281
label_shape = 1000
c = torch.zeros([4, label_shape], device=device)
c[:, label_idx] = 1

fake_image = G(z, c)
logits = D(fake_image, c)

But this raises the following error:

TypeError                                 Traceback (most recent call last)
/mnt/Data1/vmisra/stylegan_xl/gan_sketching_nb.ipynb Cell 5' in <cell line: 3>()
      1 # Print torch summary of D
      2 fake_image = G(z, c)
----> 3 logits = D(fake_image, c)

File ~/anaconda3/envs/sgxl/lib/python3.9/site-packages/torch/nn/modules/module.py:1130, in Module._call_impl(self, *input, **kwargs)
   1126 # If we don't have any hooks, we want to skip the rest of the logic in
   1127 # this function, and just call forward.
   1128 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1129         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1130     return forward_call(*input, **kwargs)
   1131 # Do not call functions when jit is used
   1132 full_backward_hooks, non_full_backward_hooks = [], []

File /mnt/Data1/vmisra/stylegan_xl/pg_modules/discriminator.py:207, in ProjectedDiscriminator.forward(self, x, c)
    204 x_aug = x_aug.add(1).div(2)
    206 # apply F-specific normalization
--> 207 x_n = Normalize(feat.normstats['mean'], feat.normstats['std'])(x_aug)
    209 # upsample if smaller, downsample if larger + VIT
    210 if self.interp224 or bb_name in VITS:

TypeError: 'NoneType' object is not subscriptable
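The last frame shows ProjectedDiscriminator.forward subscripting feat.normstats, which is None for at least one feature network. A quick way to see which backbones are affected is sketched below. It assumes that D.feature_networks is the dict-like mapping from backbone name to feature network that forward iterates over; that name is inferred from the bb_name/feat variables in the traceback and has not been checked against the repository.

```python
def find_missing_normstats(feature_networks):
    """Return the backbone names whose `normstats` is missing or None.

    `feature_networks` is assumed (hypothetically) to be a dict-like object,
    e.g. an nn.ModuleDict, mapping backbone name -> feature network.
    """
    return [
        bb_name
        for bb_name, feat in feature_networks.items()
        # getattr with a default also catches networks lacking the attribute
        if getattr(feat, "normstats", None) is None
    ]
```

After loading the pickle, print(find_missing_normstats(D.feature_networks)) should list the backbones that would trigger this TypeError.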

RichardSunnyMeng commented 10 months ago

I ran into the same problem. After checking the source code, it seems that the relevant parameters (the per-backbone normalization statistics in feat.normstats) are not restored by load_network_pkl. As a workaround, you can add the following code to the function "forward" in ProjectedDiscriminator, just before the normalization line:

    if bb_name == "tf_efficientnet_lite0":
        mean = [0.5, 0.5, 0.5]
        std = [0.5, 0.5, 0.5]
    elif bb_name == "deit_base_distilled_patch16_224":
        mean = [0.485, 0.456, 0.406]
        std = [0.229, 0.224, 0.225]

and change the normalization line to:

    x_n = Normalize(mean, std)(x_aug)
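An alternative to editing discriminator.py is to fill in the missing statistics on the loaded D once, using the same values as the workaround above. The sketch below is a minimal version under stated assumptions: restore_normstats and FALLBACK_NORMSTATS are hypothetical names, D.feature_networks is assumed to map backbone names to feature networks (inferred from the traceback, not confirmed), and normstats is assumed to be an ordinary attribute holding a dict, as the subscripting in forward suggests.

```python
# Statistics copied from the workaround above; other backbones are not covered.
FALLBACK_NORMSTATS = {
    "tf_efficientnet_lite0": {"mean": [0.5, 0.5, 0.5], "std": [0.5, 0.5, 0.5]},
    "deit_base_distilled_patch16_224": {
        "mean": [0.485, 0.456, 0.406],
        "std": [0.229, 0.224, 0.225],
    },
}


def restore_normstats(feature_networks, table=FALLBACK_NORMSTATS):
    """Fill in `feat.normstats` wherever it is None.

    Hypothetical helper: assumes `feature_networks` maps backbone name ->
    object with a plain `normstats` attribute. Returns the names of
    backbones that still have no statistics because the table lacks them.
    """
    missing = []
    for bb_name, feat in feature_networks.items():
        if getattr(feat, "normstats", None) is not None:
            continue  # statistics were loaded correctly; leave them alone
        if bb_name in table:
            stats = table[bb_name]
            # Copy the lists so patched networks do not share mutable state
            feat.normstats = {"mean": list(stats["mean"]), "std": list(stats["std"])}
        else:
            missing.append(bb_name)
    return missing
```

Call restore_normstats(D.feature_networks) once after loading and before D(fake_image, c); a non-empty return value lists backbones that still need statistics. Unlike a bare if/elif inside forward, this leaves the repository code untouched and reports unknown backbones explicitly instead of failing later with an undefined mean or std.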