Closed Ha0Tang closed 4 years ago
Hi, can you let me know when you will release the GauGAN code? Thanks.
We will release our compressed GauGAN model and the test code in 2 or 3 days. The training code may come later. We are trying to merge it into our repository, which may take some time.
Hi, we have released our compressed model of GauGAN and the test codes. Check the README for using our compressed model.
@lmxyy while waiting for your official SPADE release, I'm trying to fill in the gaps myself. It's mostly going well, but one thing I can't figure out is how to adapt weight transfer to SPADE blocks.
What I'm trying is this:
idxs = transfer_Conv2d(m1.conv_0, m2.conv_0, input_index=input_index)
idxs = transfer_Conv2d(m1.conv_1, m2.conv_1, input_index=idxs, output_index=input_index)
if m1.learned_shortcut and m2.learned_shortcut:
    transfer_Conv2d(m1.conv_s, m2.conv_s, input_index=input_index)
but I keep getting index-out-of-bounds errors. I think the error comes from the fact that each SPADE block shrinks the number of channels it has. Could you please share a snippet showing how to transfer weights from a teacher SPADE block to a student one? Thanks!
Edit:
I tried changing the snippet to:
idxs = transfer_Conv2d(m1.conv_0, m2.conv_0, input_index=input_index)
idxs = transfer_Conv2d(m1.conv_1, m2.conv_1, input_index=idxs)
if m1.learned_shortcut and m2.learned_shortcut:
    transfer_Conv2d(m1.conv_s, m2.conv_s, input_index=input_index)
and now it runs without errors, but I'm not sure it actually transfers the right weights.
Edit 2:
It also looks like there is a typo in the transfer_Conv2d implementation:
if input_index is not None:
    q = p.abs().sum([0, 2, 3])
    _, idxs = q.topk(m2.in_channels, largest=True)
    p = p[:, idxs]
else:
    p = p[:, input_index]
Should the condition be "is None" instead?
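For reference, the corrected branch logic asked about above can be sketched as follows. This is only an illustration, not the repository's code: the helper name select_input_channels is hypothetical, and plain Python lists stand in for the weight tensor (a 1x1 convolution viewed as an [out][in] matrix) so the sketch runs without PyTorch. When no index is supplied, input channels are ranked by L1 norm and the largest are kept; otherwise the caller's indices are used as given.

```python
def select_input_channels(weight, keep, input_index=None):
    """Prune the input channels of a weight matrix laid out as [out][in].

    Hypothetical helper for illustration only.
    """
    if input_index is None:
        # No index supplied: rank input channels by L1 norm across all outputs
        # and keep the `keep` largest (the branch the typo had inverted).
        n_in = len(weight[0])
        l1 = [sum(abs(row[c]) for row in weight) for c in range(n_in)]
        input_index = sorted(range(n_in), key=lambda c: l1[c], reverse=True)[:keep]
    # Index supplied (or just computed): gather exactly those input channels.
    pruned = [[row[c] for c in input_index] for row in weight]
    return pruned, list(input_index)


teacher = [[1.0, -5.0, 0.1],
           [2.0, 4.0, 0.2]]
pruned, idx = select_input_channels(teacher, keep=2)
print(idx)     # channel positions in the teacher's numbering
print(pruned)  # student-sized weight
```

The returned indices are positions in the teacher's channel numbering, so they are only valid for teacher tensors that have at least that many channels along the indexed dimension.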
Yes, the line you point out in Edit 2 is a typo. Thank you for catching it.
Here is a snippet of my weight-transfer implementation for MobileSPADEGenerator. I haven't cleaned it up yet, but I hope it helps:
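def transfer_conv(m1, m2, input_index, output_index=None):
    assert isinstance(m1, nn.Conv2d) and isinstance(m2, nn.Conv2d)
    p = m1.weight.data
    assert input_index is not None
    # Keep only the teacher input channels that survived in the previous layer.
    p = p[:, input_index]
    if output_index is None:
        # Rank output channels by L1 norm and keep the largest m2.out_channels.
        q = p.abs().sum([1, 2, 3])
        _, idxs = q.topk(m2.out_channels, largest=True)
    else:
        # The caller fixes the output channels (e.g. to match a residual branch).
        idxs = output_index
    m2.weight.data = p[idxs].clone()
    if m2.bias is not None:
        m2.bias.data = m1.bias.data[idxs].clone()
    return idxs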
def transfer_spconv(m1, m2, input_index, output_index=None):
    assert isinstance(m1, SeparableConv2d) and isinstance(m2, SeparableConv2d)
    def transfer_dw(dw1, dw2):
        p = dw1.weight.data
        # print(input_index.max(), p.shape)
        dw2.weight.data = p[input_index].clone()
        if dw2.bias is not None:
            dw2.bias.data = dw1.bias.data[input_index].clone()
    def transfer_pw(pw1, pw2):
        p = pw1.weight.data
        # print('!!!', input_index.max(), p.shape)
        p = p[:, input_index]
        if output_index is None:
            q = p.abs().sum([1, 2, 3])
            _, idxs = q.topk(pw2.out_channels, largest=True)
        else:
            idxs = output_index
        pw2.weight.data = p[idxs].clone()
        if pw2.bias is not None:
            pw2.bias.data = pw1.bias.data[idxs].clone()
        return idxs
    transfer_dw(m1.conv[0], m2.conv[0])
    idxs = transfer_pw(m1.conv[2], m2.conv[2])
    return idxs
def transfer_mbspade(m1, m2, input_index=None):
    assert isinstance(m1, MobileSPADE) and isinstance(m2, MobileSPADE)
    m2.param_free_norm.running_mean = m1.param_free_norm.running_mean[input_index].clone()
    m2.param_free_norm.running_var = m1.param_free_norm.running_var[input_index].clone()
    idxs = transfer_conv(m1.mlp_shared[0], m2.mlp_shared[0], list(range(m1.mlp_shared[0].in_channels)))
    transfer_spconv(m1.mlp_gamma, m2.mlp_gamma, idxs, input_index)
    transfer_spconv(m1.mlp_beta, m2.mlp_beta, idxs, input_index)
    return input_index
def transfer_mbresnetblock1(m1, m2, input_index):
    assert input_index is not None
    assert isinstance(m1, MobileSPADEResnetBlock) and isinstance(m2, MobileSPADEResnetBlock)
    if m1.learned_shortcut:
        assert m2.learned_shortcut
        idxs = transfer_mbspade(m1.norm_0, m2.norm_0, input_index)
        idxs = transfer_conv(m1.conv_0, m2.conv_0, idxs)
        idxs = transfer_mbspade(m1.norm_1, m2.norm_1, idxs)
        idxs = transfer_conv(m1.conv_1, m2.conv_1, idxs)
        # print(len(idxs))
        transfer_mbspade(m1.norm_s, m2.norm_s, input_index)
        transfer_conv(m1.conv_s, m2.conv_s, input_index, idxs)
        return idxs
    else:
        assert not m2.learned_shortcut
        idxs = transfer_mbspade(m1.norm_0, m2.norm_0, input_index)
        idxs = transfer_conv(m1.conv_0, m2.conv_0, idxs)
        idxs = transfer_mbspade(m1.norm_1, m2.norm_1, idxs)
        idxs = transfer_conv(m1.conv_1, m2.conv_1, idxs, input_index)
        return idxs
def transfer_weight(netA, netB):
    if isinstance(netA, MobileSPADEGenerator):
        assert isinstance(netB, MobileSPADEGenerator)
        idxs = transfer_conv(netA.fc, netB.fc, list(range(netA.fc.in_channels)))
        idxs = transfer_mbresnetblock1(netA.head_0, netB.head_0, idxs)
        idxs = transfer_mbresnetblock1(netA.G_middle_0, netB.G_middle_0, idxs)
        idxs = transfer_mbresnetblock1(netA.G_middle_1, netB.G_middle_1, idxs)
        idxs = transfer_mbresnetblock1(netA.up_0, netB.up_0, idxs)
        idxs = transfer_mbresnetblock1(netA.up_1, netB.up_1, idxs)
        idxs = transfer_mbresnetblock1(netA.up_2, netB.up_2, idxs)
        idxs = transfer_mbresnetblock1(netA.up_3, netB.up_3, idxs)
    else:
        raise NotImplementedError
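The two branches of transfer_mbresnetblock1 differ in how the output channels of conv_1 are chosen, and that difference is the likely cause of the index-out-of-bounds errors reported earlier in the thread. The sketch below is a simplified illustration under stated assumptions, not code from the repository: the helper names are hypothetical and per-channel scores are plain lists rather than tensors. With an identity shortcut the block input is added to the output channel by channel, so the output must reuse input_index. With a learned shortcut the output can have fewer channels than the input, so reusing input_index can point past the end of the output dimension; the output is ranked on its own, and conv_s is then given the same output index.

```python
def topk_indices(scores, k):
    """Indices of the k largest scores, highest first."""
    return sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:k]


def residual_output_index(input_index, out_scores, keep_out, learned_shortcut):
    """Choose which teacher output channels the block's last conv keeps.

    Hypothetical helper; out_scores holds one L1 score per teacher output channel.
    """
    if not learned_shortcut:
        # Identity skip: output channel i is summed with input channel i,
        # so both must use the same teacher positions.
        if max(input_index) >= len(out_scores):
            raise IndexError("input_index does not fit the output dimension")
        return list(input_index)
    # Learned skip: pick output channels independently; the shortcut conv
    # must be transferred with this same output index.
    return topk_indices(out_scores, keep_out)


# Teacher block with 8 input channels but only 4 output channels.
input_index = [7, 5]
out_scores = [0.5, 2.0, 0.1, 1.0]
print(residual_output_index(input_index, out_scores, 2, learned_shortcut=True))
```

Calling the same helper with learned_shortcut=False on this block raises IndexError, which mirrors what happens when input_index is passed as output_index for a block whose channel count shrinks.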
@lmxyy thanks a lot! I will now try to replace it and see how it goes
@lmxyy, any estimate for when you might release the GauGAN training codes? Thank you!
We will release the training codes in one or two weeks.
Our GauGAN training code has been released. You can check training_tutorial.md to set up GauGAN experiments.
We will add our GauGAN model later. Stay tuned.