lucidrains / audiolm-pytorch

Implementation of AudioLM, a SOTA Language Modeling Approach to Audio Generation out of Google Research, in Pytorch
MIT License

Missing skip connection for cross attention? #150

Closed — lonzi closed this issue 1 year ago

lonzi commented 1 year ago

Hi, thanks for the implementation, it is really helpful!

While using the `Transformer` class, I noticed a significant performance degradation when activating the optional cross attention branch. I suspect this is due to a missing skip connection in the following line of code: https://github.com/lucidrains/audiolm-pytorch/blob/main/audiolm_pytorch/audiolm_pytorch.py#L456

I suggest changing it to `x = cross_attn(x, context = context, mask = context_mask) + x`
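For illustration, here is a minimal sketch of the pattern being discussed (not the library's actual code): in a pre-norm transformer, each sublayer's output should be added back to its input, including the optional cross-attention branch. The `CrossAttentionBlock` name and its internals are hypothetical, used only to show the residual fix.

```python
import torch
import torch.nn as nn

class CrossAttentionBlock(nn.Module):
    """Hypothetical pre-norm cross-attention sublayer illustrating the fix."""

    def __init__(self, dim, heads = 4):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, heads, batch_first = True)

    def forward(self, x, context):
        # attend from the (normalized) sequence to the conditioning context
        out, _ = self.attn(self.norm(x), context, context)
        # buggy form: return out          (residual dropped, signal from
        #                                  earlier layers is discarded)
        return out + x  # fixed form: keep the skip connection
```

Without the `+ x`, the block replaces the running representation with the attention output instead of refining it, which matches the degradation described above.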

lucidrains commented 1 year ago

@lonzi i was just trying to see if you were paying attenti.. just kidding

that was one of my high school teacher's favorite lines :laughing:

thank you for catching this Alon!

lucidrains commented 1 year ago

are you doing TTS?

lonzi commented 1 year ago

Thanks for fixing! No, I'm working on drum generation, using the transformer mentioned above on top of a pre-trained EnCodec model.