mit-han-lab / gan-compression

[CVPR 2020] GAN Compression: Efficient Architectures for Interactive Conditional GANs

How to run GauGAN experiments? #1

Closed Ha0Tang closed 4 years ago

lmxyy commented 4 years ago

We will add our GauGAN model later. Stay tuned.

Ha0Tang commented 4 years ago

Hi, can you let me know when you will release the GauGAN code? Thanks.

lmxyy commented 4 years ago

We will release our compressed GauGAN model and the test codes in 2 or 3 days. The training codes may come later: we are still merging them into our repository, which may take some time.

lmxyy commented 4 years ago

Hi, we have released our compressed GauGAN model and the test codes. Check the README for instructions on using the compressed model.

s1ddok commented 4 years ago

@lmxyy while waiting for your official SPADE release, I'm trying to fill in the gaps myself. It is mostly going well; however, one thing I can't understand is how to adapt the weight transfer to SPADE blocks.

What I'm trying is this:

    idxs = transfer_Conv2d(m1.conv_0, m2.conv_0, input_index=input_index)
    idxs = transfer_Conv2d(m1.conv_1, m2.conv_1, input_index=idxs, output_index=input_index)
    if m1.learned_shortcut and m2.learned_shortcut:
        transfer_Conv2d(m1.conv_s, m2.conv_s, input_index=input_index)

but I constantly get index-out-of-bounds errors. I think the error comes from the fact that each SPADE block shrinks its number of channels. Could you please share a snippet showing how to transfer weights from a teacher SPADE block to a student one? Thanks!
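For concreteness, here is a minimal standalone sketch of the failure mode (the channel counts are made up): indices computed over a layer's input channels can exceed the output-channel count of a narrower convolution, so reusing them as an output selection overflows.

```python
import torch

# Made-up sizes: a teacher conv with 256 output and 512 input channels.
p = torch.randn(256, 512, 3, 3)

# Indices computed over the 512 input channels can reach 511; using them
# to select output channels (dimension 0, size 256) goes out of bounds.
input_index = torch.arange(384, 512)

try:
    p[input_index]
except IndexError as e:
    print("index out of bounds:", e)
```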

Edit:

I tried changing the snippet to:

    idxs = transfer_Conv2d(m1.conv_0, m2.conv_0, input_index=input_index)
    idxs = transfer_Conv2d(m1.conv_1, m2.conv_1, input_index=idxs)
    if m1.learned_shortcut and m2.learned_shortcut:
        transfer_Conv2d(m1.conv_s, m2.conv_s, input_index=input_index)

and now it passes, but I'm not sure whether it really works.

Edit 2:

It also feels like there is a typo in the transfer_Conv2d implementation:

        if input_index is not None:
            q = p.abs().sum([0, 2, 3])
            _, idxs = q.topk(m2.in_channels, largest=True)
            p = p[:, idxs]
        else:
            p = p[:, input_index]

Should it be is None?

lmxyy commented 4 years ago

Yes, Edit 2 is a typo. Thank you for pointing it out.
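With the condition flipped to is None, the L1-norm fallback only runs when no indices are handed in. A minimal sketch of just this branch (select_input_channels is a made-up name for illustration):

```python
import torch

def select_input_channels(p, in_channels, input_index=None):
    # Corrected branch: pick input channels by L1 norm only when no
    # indices were passed in; otherwise reuse the given indices.
    if input_index is None:
        q = p.abs().sum([0, 2, 3])          # L1 norm per input channel
        _, idxs = q.topk(in_channels, largest=True)
        return p[:, idxs]
    return p[:, input_index]

p = torch.zeros(4, 6, 3, 3)
p[:, 5] = 1.0                               # make channel 5 the strongest
picked = select_input_channels(p, in_channels=1)
print(picked.shape)  # torch.Size([4, 1, 3, 3])
```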

Here is a snippet of my weight-transfer implementation for MobileSPADEGenerator. I haven't cleaned it up yet, but I hope it helps:

import torch.nn as nn

# SeparableConv2d, MobileSPADE, MobileSPADEResnetBlock and
# MobileSPADEGenerator are classes from the gan-compression repository.

def transfer_conv(m1, m2, input_index, output_index=None):
    assert isinstance(m1, nn.Conv2d) and isinstance(m2, nn.Conv2d)
    p = m1.weight.data
    assert input_index is not None
    p = p[:, input_index]
    if output_index is None:
        q = p.abs().sum([1, 2, 3])
        _, idxs = q.topk(m2.out_channels, largest=True)
    else:
        idxs = output_index
    m2.weight.data = p[idxs].clone()
    if m2.bias is not None:
        m2.bias.data = m1.bias.data[idxs].clone()
    return idxs

def transfer_spconv(m1, m2, input_index, output_index=None):
    assert isinstance(m1, SeparableConv2d) and isinstance(m2, SeparableConv2d)

    def transfer_dw(dw1, dw2):
        p = dw1.weight.data
        # print(input_index.max(), p.shape)
        dw2.weight.data = p[input_index].clone()
        if dw2.bias is not None:
            dw2.bias.data = dw1.bias.data[input_index].clone()

    def transfer_pw(pw1, pw2):
        p = pw1.weight.data
        # print('!!!', input_index.max(), p.shape)
        p = p[:, input_index]
        if output_index is None:
            q = p.abs().sum([1, 2, 3])
            _, idxs = q.topk(pw2.out_channels, largest=True)
        else:
            idxs = output_index
        pw2.weight.data = p[idxs].clone()
        if pw2.bias is not None:
            pw2.bias.data = pw1.bias.data[idxs].clone()
        # return outside the bias check so idxs is returned
        # even when the pointwise conv has no bias
        return idxs

    transfer_dw(m1.conv[0], m2.conv[0])
    idxs = transfer_pw(m1.conv[2], m2.conv[2])
    return idxs

def transfer_mbspade(m1, m2, input_index=None):
    assert isinstance(m1, MobileSPADE) and isinstance(m2, MobileSPADE)
    m2.param_free_norm.running_mean = m1.param_free_norm.running_mean[input_index].clone()
    m2.param_free_norm.running_var = m1.param_free_norm.running_var[input_index].clone()
    idxs = transfer_conv(m1.mlp_shared[0], m2.mlp_shared[0], list(range(m1.mlp_shared[0].in_channels)))
    transfer_spconv(m1.mlp_gamma, m2.mlp_gamma, idxs, input_index)
    transfer_spconv(m1.mlp_beta, m2.mlp_beta, idxs, input_index)
    return input_index

def transfer_mbresnetblock1(m1, m2, input_index):
    assert input_index is not None
    assert isinstance(m1, MobileSPADEResnetBlock) and isinstance(m2, MobileSPADEResnetBlock)
    if m1.learned_shortcut:
        assert m2.learned_shortcut
        idxs = transfer_mbspade(m1.norm_0, m2.norm_0, input_index)
        idxs = transfer_conv(m1.conv_0, m2.conv_0, idxs)
        idxs = transfer_mbspade(m1.norm_1, m2.norm_1, idxs)
        idxs = transfer_conv(m1.conv_1, m2.conv_1, idxs)
        # print(len(idxs))
        transfer_mbspade(m1.norm_s, m2.norm_s, input_index)
        transfer_conv(m1.conv_s, m2.conv_s, input_index, idxs)
        return idxs
    else:
        assert not m2.learned_shortcut
        idxs = transfer_mbspade(m1.norm_0, m2.norm_0, input_index)
        idxs = transfer_conv(m1.conv_0, m2.conv_0, idxs)
        idxs = transfer_mbspade(m1.norm_1, m2.norm_1, idxs)
        idxs = transfer_conv(m1.conv_1, m2.conv_1, idxs, input_index)
        return idxs

def transfer_weight(netA, netB):
    if isinstance(netA, MobileSPADEGenerator):
        assert isinstance(netB, MobileSPADEGenerator)
        idxs = transfer_conv(netA.fc, netB.fc, list(range(netA.fc.in_channels)))
        idxs = transfer_mbresnetblock1(netA.head_0, netB.head_0, idxs)
        idxs = transfer_mbresnetblock1(netA.G_middle_0, netB.G_middle_0, idxs)
        idxs = transfer_mbresnetblock1(netA.G_middle_1, netB.G_middle_1, idxs)
        idxs = transfer_mbresnetblock1(netA.up_0, netB.up_0, idxs)
        idxs = transfer_mbresnetblock1(netA.up_1, netB.up_1, idxs)
        idxs = transfer_mbresnetblock1(netA.up_2, netB.up_2, idxs)
        idxs = transfer_mbresnetblock1(netA.up_3, netB.up_3, idxs)
    else:
        raise NotImplementedError
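The pattern throughout is that each transfer returns the indices of the surviving channels, which become the input_index of the next layer. A standalone sketch of that index threading with plain convolutions (layer sizes are made up; this is not the repository's exact code):

```python
import torch
import torch.nn as nn

def transfer_conv(m1, m2, input_index, output_index=None):
    # Keep the given input channels, then keep the output channels with
    # the largest L1 norm (or the ones forced by output_index).
    p = m1.weight.data[:, input_index]
    if output_index is None:
        q = p.abs().sum([1, 2, 3])
        _, idxs = q.topk(m2.out_channels, largest=True)
    else:
        idxs = output_index
    m2.weight.data = p[idxs].clone()
    if m2.bias is not None:
        m2.bias.data = m1.bias.data[idxs].clone()
    return idxs

teacher1 = nn.Conv2d(3, 16, 3, padding=1)
teacher2 = nn.Conv2d(16, 32, 3, padding=1)
student1 = nn.Conv2d(3, 8, 3, padding=1)
student2 = nn.Conv2d(8, 16, 3, padding=1)

idxs = transfer_conv(teacher1, student1, list(range(3)))  # seed with all inputs
idxs = transfer_conv(teacher2, student2, idxs)            # thread the indices
out = student2(student1(torch.randn(1, 3, 32, 32)))
print(out.shape)  # torch.Size([1, 16, 32, 32])
```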

s1ddok commented 4 years ago

@lmxyy thanks a lot! I will now try it out and see how it goes.

sidodan commented 4 years ago

@lmxyy, any estimate for when you might release the GauGAN training codes? Thank you!

lmxyy commented 4 years ago

> @lmxyy, any estimate for when you might release the GauGAN training codes? Thank you!

We will release the training codes in one or two weeks.

lmxyy commented 4 years ago

Our GauGAN training codes have been released. You can check training_tutorial.md for instructions on setting up GauGAN experiments.