Closed. YiingWei closed this issue 1 year ago.
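I can't debug this from the information you've given. The error is raised at style = self.style(style).unsqueeze(2).unsqueeze(3), which means the dimension of the style passed into ResMod doesn't match what the layer expects. Please provide style.shape at the point where the error occurs.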
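The traceback shows that self.style is an nn.Linear (the call goes into linear.py and then F.linear). For a 2-D input with a bias, F.linear computes torch.addmm(bias, input, weight.t()), and nn.Linear stores its weight as (out_features, in_features). "mat1 dim 1 must match mat2 dim 0" therefore means style.shape[1] != in_features. The pure-Python sketch below mimics that shape check without PyTorch. The helper name linear_output_shape and the example shapes are hypothetical and are not taken from DualStyleGAN.

```python
def linear_output_shape(input_shape, weight_shape):
    """Mimic the shape check behind F.linear for a 2-D input (sketch only).

    nn.Linear(in_features, out_features) stores weight with shape
    (out_features, in_features). For a 2-D input, F.linear calls
    addmm(bias, input, weight.t()), so input.shape[1] ("mat1 dim 1")
    must equal weight.t().shape[0] == in_features ("mat2 dim 0").
    """
    batch, in_dim = input_shape
    out_features, in_features = weight_shape
    if in_dim != in_features:
        # Same condition that triggers the RuntimeError in the traceback.
        raise ValueError(
            f"mat1 dim 1 ({in_dim}) must match mat2 dim 0 ({in_features})"
        )
    # On success, the output has one row per batch item and out_features columns.
    return (batch, out_features)
```

In the real model, the corresponding check is to print style.shape and self.style.in_features (or self.style.weight.shape) immediately before style = self.style(style). The printed shapes only "look right" if style.shape[-1] equals in_features at that exact call.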
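Also, your traceback shows out = self.res[0](out, resstyles[:, 0], interp_weights[0]) failing at line 196 of dualstylegan.py, but in my code this statement is on line 159. Did you modify the code, and could that be what's causing this?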
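One way to check whether a local copy has drifted from upstream is to locate the failing statement in both files and diff them. The sketch below uses only the standard library. find_statement and diff_sources are hypothetical helpers, and it assumes you have a pristine copy of model/dualstylegan.py (for example, from a fresh clone of the repository) to compare against.

```python
import difflib


def find_statement(source_text, statement):
    """Return the 1-based line numbers whose stripped text equals `statement`."""
    return [
        lineno
        for lineno, line in enumerate(source_text.splitlines(), start=1)
        if line.strip() == statement.strip()
    ]


def diff_sources(original_text, local_text, n=1):
    """Return a unified diff (as a list of lines) from a pristine copy to a local copy."""
    return list(
        difflib.unified_diff(
            original_text.splitlines(),
            local_text.splitlines(),
            fromfile="original",
            tofile="local",
            lineterm="",
            n=n,  # lines of context around each change
        )
    )

# Hypothetical usage (paths are assumptions):
#   from pathlib import Path
#   stmt = "out = self.res[0](out, resstyles[:, 0], interp_weights[0])"
#   local = Path("model/dualstylegan.py").read_text()
#   pristine = Path("upstream/model/dualstylegan.py").read_text()
#   print(find_statement(local, stmt), find_statement(pristine, stmt))
#   print("\n".join(diff_sources(pristine, local)))
```

If the two line numbers differ, the diff shows exactly which lines were added or changed before that statement.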
Hello, when I run Stage 1 & 2 of Progressive Fine-Tuning (Pretrain DualStyleGAN on FFHQ) with the following command:
python3 -m torch.distributed.launch --nproc_per_node=1 --master_port=8765 pretrain_dualstylegan.py --iter 3000 --batch 4 ./data/ffhq/lmdb/
I get this error:
Traceback (most recent call last):
File "pretrain_dualstylegan.py", line 473, in <module>
pretrain(args, loader, generator, discriminator, g_optim, d_optim, g_ema, encoder, vggloss, device, inject_index=7, savemodel=False)
File "pretrain_dualstylegan.py", line 189, in pretrain
fakeimg, = generator(noise, externalstyle, use_res=True, z_plus_latent=z_plus_latent)
File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 727, in _call_impl
result = self.forward(*input, **kwargs)
File "/root/DualStyleGAN-main/model/dualstylegan.py", line 196, in forward
out = self.res[0](out, resstyles[:, 0], interp_weights[0])
File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 727, in _call_impl
result = self.forward(*input, **kwargs)
File "/root/DualStyleGAN-main/model/dualstylegan.py", line 74, in forward
out = self.conv(self.norm(x, s))
File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 727, in _call_impl
result = self.forward(*input, **kwargs)
File "/root/DualStyleGAN-main/model/dualstylegan.py", line 19, in forward
style = self.style(style).unsqueeze(2).unsqueeze(3)
File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 727, in _call_impl
result = self.forward(*input, **kwargs)
File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/linear.py", line 93, in forward
return F.linear(input, self.weight, self.bias)
File "/usr/local/lib/python3.6/dist-packages/torch/nn/functional.py", line 1690, in linear
ret = torch.addmm(bias, input, weight.t())
RuntimeError: mat1 dim 1 must match mat2 dim 0
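I'm using exactly the same data and code, and I printed the relevant tensor shapes to check, but I can't see a dimension mismatch anywhere. The error still occurs, though. How should I fix it?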