zengchang233 / xiaoicesing2

Source code for the paper XiaoiceSing2 (Interspeech 2023).
BSD 3-Clause "New" or "Revised" License
43 stars 3 forks

Efficiency issue of discriminator #1

Open KakaruHayate opened 3 weeks ago

KakaruHayate commented 3 weeks ago

Hello, thank you very much for your work.

I am trying to port the MB-discriminator to other projects, but the discriminator is very time-consuming during training.

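import time

import torch

# Discriminator: the MB-discriminator class from this repository (import path not shown here)

if __name__ == "__main__":
    inputs = torch.randn(16, 400, 128).cuda()
    tgt = torch.randn(16, 400, 128).cuda()
    # inputs_len = torch.full((inputs.size(0),), inputs.size(1), dtype=torch.long)
    inputs_len = None

    n_mel = 128
    net = Discriminator(n_mel=n_mel).cuda()
    print(net)
    torch.cuda.synchronize()  # CUDA runs asynchronously; sync before starting the timer
    start_time = time.time()
    outputs, random_n = net(inputs, inputs_len, tgt)
    outputs, _ = net(inputs.detach(), inputs_len, tgt, random_n)  # reuse random_n for the second pass
    torch.cuda.synchronize()  # wait for all queued GPU work before stopping the timer
    Discriminator_time = time.time() - start_time
    print(f"Discriminator time: {Discriminator_time:.4f}s")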

Using the test above, the discriminator takes about 2 seconds each time. Is this expected by design, or are some parts of the code not yet updated? I am testing on an RTX A4000 (16 GB VRAM).
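CUDA kernels are launched asynchronously, and the first calls on a freshly built model also pay one-off warm-up costs (lazy library initialization, memory allocation), so a single time.time() delta around two forward calls can be misleading. A minimal sketch of a steadier measurement, assuming a hypothetical helper named time_forward that is not part of this repository:

```python
import time

import torch


def time_forward(fn, warmup=3, iters=10):
    """Return the average wall-clock seconds per call of fn().

    Warm-up calls are run first and excluded from the measurement.
    When CUDA is available, the device is synchronized before and after
    the timed loop so that queued GPU work is actually counted.
    """
    use_cuda = torch.cuda.is_available()
    for _ in range(warmup):
        fn()
    if use_cuda:
        torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    if use_cuda:
        torch.cuda.synchronize()
    return (time.perf_counter() - start) / iters


if __name__ == "__main__":
    # Stand-in workload; with the test script's objects one could instead time
    # lambda: net(inputs, inputs_len, tgt)
    x = torch.randn(256, 256)
    print(f"{time_forward(lambda: x @ x):.6f}s per call")
```

Averaging over several iterations after warm-up separates the steady-state per-step cost, which is what matters during training, from one-time startup overhead.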

zengchang233 commented 3 weeks ago

@KakaruHayate Hi, thanks for your interest in our project. This discriminator is composed of multiple sub-discriminators. Because it applies multi-window and multi-band analysis simultaneously, the total number of sub-discriminators is (number of bands) × (number of windows); in the paper this is configured as 3 × 5 = 15. In addition, each sub-discriminator has an accompanying PatchGAN discriminator, so there are 30 sub-discriminators in the model in total.

You can reduce the number of windows to improve the speed.
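The count in the reply is a product of the band and window configurations, doubled by the PatchGAN companions. A toy sketch, assuming hypothetical names (window_lengths, band_ranges) and placeholder values that are not taken from the repository, shows how removing windows shrinks the set of sub-discriminators:

```python
from itertools import product


def enumerate_sub_discriminators(window_lengths, band_ranges, with_patchgan=True):
    """List every (window, band, kind) combination in a multi-window,
    multi-band discriminator. All names and values are hypothetical."""
    kinds = ("main", "patchgan") if with_patchgan else ("main",)
    return list(product(window_lengths, band_ranges, kinds))


# Placeholder configuration: 3 mel-bin bands and 5 time windows (in frames).
bands = [(0, 48), (40, 88), (80, 128)]
windows = [32, 64, 128, 256, 400]

full = enumerate_sub_discriminators(windows, bands)
print(len(full))  # 30 = 3 bands * 5 windows * 2 kinds

# Keeping only 3 windows removes 2 * len(bands) sub-discriminators per dropped window.
reduced = enumerate_sub_discriminators(windows[:3], bands)
print(len(reduced))  # 18 = 3 bands * 3 windows * 2 kinds
```

Each entry corresponds to a separate forward pass, so the number of windows directly controls how many sub-networks run per discriminator call.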

KakaruHayate commented 3 weeks ago

Thank you for your reply. I noticed that time_lengths and freq_lengths do not seem to be used. Is that because they are not necessary?

zengchang233 commented 3 weeks ago

Please check the MultiWindowDiscriminator class. time_lengths is used to initialize multiple DiscriminatorFactory instances, which are bundled into the self.conv_layers attribute.
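The pattern described in the reply, one sub-discriminator per entry in time_lengths collected into self.conv_layers, can be sketched as follows. This is a simplified illustration, not the repository's code: WindowSubDiscriminator and MultiWindowSketch are hypothetical stand-ins for DiscriminatorFactory and MultiWindowDiscriminator, and the conv stack, crop logic, and start_idxs argument (loosely mirroring random_n in the test script) are assumptions.

```python
import torch
import torch.nn as nn


class WindowSubDiscriminator(nn.Module):
    """Hypothetical stand-in for one per-window sub-discriminator.

    Scores a (batch, time_length, n_mel) crop with a small 2-D conv stack.
    """

    def __init__(self, time_length, hidden=32):
        super().__init__()
        self.time_length = time_length
        self.net = nn.Sequential(
            nn.Conv2d(1, hidden, kernel_size=3, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(hidden, hidden, kernel_size=3, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(hidden, 1, kernel_size=3, padding=1),
        )

    def forward(self, x):
        # Add a channel axis: (batch, time, n_mel) -> (batch, 1, time, n_mel)
        return self.net(x.unsqueeze(1))


class MultiWindowSketch(nn.Module):
    """One sub-discriminator per time length, bundled in an nn.ModuleList."""

    def __init__(self, time_lengths):
        super().__init__()
        self.conv_layers = nn.ModuleList(
            [WindowSubDiscriminator(t) for t in time_lengths]
        )

    def forward(self, x, start_idxs=None):
        # x: (batch, frames, n_mel); assumes frames >= every time_length.
        # Passing the returned start_idxs back in crops another input
        # (e.g. real vs. generated) at the same positions.
        frames = x.size(1)
        if start_idxs is None:
            start_idxs = [
                torch.randint(0, frames - layer.time_length + 1, (1,)).item()
                for layer in self.conv_layers
            ]
        outputs = []
        for layer, start in zip(self.conv_layers, start_idxs):
            crop = x[:, start:start + layer.time_length, :]
            outputs.append(layer(crop))
        return outputs, start_idxs


if __name__ == "__main__":
    net = MultiWindowSketch(time_lengths=[32, 64, 128])
    real_out, starts = net(torch.randn(2, 400, 128))
    fake_out, _ = net(torch.randn(2, 400, 128), starts)
    print(len(real_out))  # 3: one output per time length
    print(real_out[0].shape)  # torch.Size([2, 1, 8, 32])
```

Because every entry in time_lengths adds one more module to conv_layers and one more forward pass, shortening that list is the most direct way to trade discriminator coverage for speed.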