lucidrains / audiolm-pytorch

Implementation of AudioLM, a SOTA Language Modeling Approach to Audio Generation out of Google Research, in Pytorch
MIT License

ComplexSTFTDiscriminator diverges from paper #174

Closed by eagomez2 1 year ago

eagomez2 commented 1 year ago

The ComplexSTFTDiscriminator in https://github.com/lucidrains/audiolm-pytorch/blob/e28ff435104792c887b10eae10c83c4ca50880a9/audiolm_pytorch/soundstream.py#L208 seems to diverge from the paper's description. In particular, Figure 4 in https://arxiv.org/pdf/2107.03312.pdf seems to show a skip connection, in the form of an addition, in every ResidualUnit instance, and this is not implemented in ComplexResidualUnit here:

https://github.com/lucidrains/audiolm-pytorch/blob/e28ff435104792c887b10eae10c83c4ca50880a9/audiolm_pytorch/soundstream.py#L198

On the other hand, I am not sure if I am missing something, but it also seems that such a skip connection is not possible because consecutive ResidualUnit blocks have different channel counts. For example, the first block in Figure 4 has N=C but the one right after has N=2C. This means that the output of the second block has twice as many channels as its input (the output of the first block), so those two tensors cannot be added at the end of the second block in the form of the proposed skip connection.
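The channel mismatch can be reproduced in a few lines of PyTorch. This is only an illustration with made-up sizes (`C = 32` and a 64x64 feature map are assumptions, not values from the paper), and `try_add` is a hypothetical helper:

```python
import torch

C = 32  # assumed base channel count, for illustration only

x_in = torch.randn(1, C, 64, 64)       # input to a block with N=2C
x_out = torch.randn(1, 2 * C, 64, 64)  # that block's output, with 2C channels


def try_add(a, b):
    """Return the shape of a + b, or None if the shapes cannot be added."""
    try:
        return (a + b).shape
    except RuntimeError:
        # Channel dims C and 2C are neither equal nor 1, so broadcasting fails.
        return None


print(try_add(x_in, x_in))   # same channel count: addition works
print(try_add(x_in, x_out))  # C vs 2C channels: addition is rejected
```

So a plain identity skip only works when the block keeps the channel count (and spatial size) unchanged; otherwise the shortcut needs some kind of projection.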

Any clarification about this would be really helpful.

lucidrains commented 1 year ago

I actually came across this and wasn't sure, and made the call to omit the residual. Our options would be a parallel conv with kernel size 2 stride 2, projecting to 2C, or a residual added right after the activation.
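A rough sketch of the first option (a parallel conv with kernel size 2 and stride 2 projecting to 2C) might look like the following. It uses real-valued `nn.Conv2d` and `nn.ELU` as stand-ins for the complex layers, and the class name `ProjectedResidualUnit` and the main-branch layout are assumptions for illustration, not the repository's code:

```python
import torch
from torch import nn


class ProjectedResidualUnit(nn.Module):
    """Hypothetical real-valued unit: main branch plus a strided conv shortcut."""

    def __init__(self, chan_in, chan_out):
        super().__init__()
        # Main branch: keep channels, then downsample by 2 and change channels.
        self.main = nn.Sequential(
            nn.Conv2d(chan_in, chan_in, 3, padding=1),
            nn.ELU(),
            nn.Conv2d(chan_in, chan_out, 4, stride=2, padding=1),
        )
        # Shortcut: kernel size 2, stride 2, projecting chan_in -> chan_out,
        # so both branches end up with the same shape.
        self.shortcut = nn.Conv2d(chan_in, chan_out, 2, stride=2)

    def forward(self, x):
        return self.main(x) + self.shortcut(x)


unit = ProjectedResidualUnit(32, 64)
out = unit(torch.randn(1, 32, 16, 16))
print(out.shape)
```

With these kernel, stride, and padding values both branches halve the spatial size in the same way, so the addition is well defined even though the channel count doubles.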

Do you know how encodec built theirs?

lucidrains commented 1 year ago

@eagomez2 i'm just going to do some improvisation https://github.com/lucidrains/audiolm-pytorch/commit/d50ae9e31601f9737dbf05cac3b27816641382c9

there! now there's a residual :laughing: added an extra conv because i haven't seen many residuals that are placed right after an activation, but let me know if you think otherwise
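One way to read "an extra conv plus a residual" is sketched below: the residual wraps a conv, activation, conv stack that keeps the channel count fixed, and a separate strided conv afterwards changes channels and resolution. This is an interpretation that has not been checked against the linked commit; it uses real-valued layers instead of complex ones, and `Residual` and `residual_then_project` are hypothetical names:

```python
import torch
from torch import nn


class Residual(nn.Module):
    """Adds the input back to the output of the wrapped module."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x):
        return self.fn(x) + x


def residual_then_project(chan_in, chan_out, stride=2):
    return nn.Sequential(
        Residual(nn.Sequential(
            nn.Conv2d(chan_in, chan_in, 3, padding=1),
            nn.ELU(),
            # Extra conv, so the sum is not taken directly after the activation.
            nn.Conv2d(chan_in, chan_in, 3, padding=1),
        )),
        # Channel change and downsampling happen outside the residual.
        nn.Conv2d(chan_in, chan_out, 4, stride=stride, padding=1),
    )


block = residual_then_project(32, 64)
out = block(torch.randn(1, 32, 16, 16))
print(out.shape)
```

Because the residual branch never changes shape, the identity skip is valid without any projection on the shortcut.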

eagomez2 commented 1 year ago

Hi @lucidrains and thanks!

I am not really sure if these skip connections are fundamental in discriminators. My impression is that the identity propagation promoted by skip connections is not necessary for the task they do. Moreover, this is the first discriminator I bumped into that has them, although I may be wrong about the first statement.

On the other hand, I looked into EnCodec because you mentioned it and found that, even though the discriminator is not exactly the same, the approach is quite similar, but it uses a stack of real-valued tensors instead of complex numbers, with some differences in dilation and such. Here you can see the implementation: https://github.com/facebookresearch/encodec/blob/main/encodec/msstftd.py
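As a minimal sketch of the real-valued alternative to complex layers, the real and imaginary parts of an STFT can be stacked as two input channels for ordinary 2D convolutions. The function name `stft_as_real_channels` and the `n_fft` and `hop_length` defaults are assumptions for illustration and are not taken from EnCodec:

```python
import torch


def stft_as_real_channels(wave, n_fft=1024, hop_length=256):
    """Map (batch, samples) audio to a real (batch, 2, freq, frames) tensor."""
    window = torch.hann_window(n_fft, device=wave.device)
    spec = torch.stft(
        wave, n_fft, hop_length=hop_length, window=window, return_complex=True
    )  # complex tensor of shape (batch, freq, frames)
    # Channel 0 holds the real part, channel 1 the imaginary part.
    return torch.stack((spec.real, spec.imag), dim=1)


wave = torch.randn(2, 16000)  # two clips of random noise as dummy input
features = stft_as_real_channels(wave)
print(features.shape, features.dtype)
```

The result can be fed to a standard `nn.Conv2d(2, ...)` stack, which avoids custom complex convolutions and activations at the cost of treating the two parts as independent channels.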

Have you already successfully used the ComplexSTFTDiscriminator as it was without skip connections? I'm considering using it and comparing it with the EnCodec one to see which one works best for the task at hand.

turian commented 1 year ago

BTW, given the Encodec license change I wanted to point out that in a separate audio GAN I'm building, I found that the Encodec MSSTFT learned much more effectively than the bigvgan one.

MSSTFT is a complex multi resolution spectrogram discriminator.

With that said, it's not always clear we want better discriminators, if the generator can't keep up :) But it seemed to work for encodec

lucidrains commented 1 year ago

@eagomez2 so in the image arena, some of the more powerful discriminators definitely have residuals (stylegan2 etc)

yea, i've always thought that the complex-valued network was a bit overkill, but during the early stages of the repository, some independent researchers reported better convergence, so i bit the bullet and made it work

i think you should just try the new version with the extra conv and skip connection. but honestly i'm fairly confident either case will work well