lucidrains / audiolm-pytorch

Implementation of AudioLM, a SOTA Language Modeling Approach to Audio Generation out of Google Research, in Pytorch
MIT License

Soundstream discriminator clip_grad_norm - some params are not clipped. #256

Closed — avihu111 closed this issue 10 months ago

avihu111 commented 10 months ago

Awesome repo, thanks a lot! I have a minor fix that might make adversarial training more stable: `clip_grad_norm_` currently clips only the STFT discriminator's gradients. I guess the intention was to clip the gradients of all the discriminators?

I couldn't push a new branch to create a PR, so I'll just post the change here: Line: https://github.com/lucidrains/audiolm-pytorch/blob/898f6479aee10e5b0604a6a5a00b8a3fa6359521/audiolm_pytorch/trainer.py#L539

should change from:

        if exists(self.discr_max_grad_norm):
            self.accelerator.clip_grad_norm_(self.soundstream.stft_discriminator.parameters(), self.discr_max_grad_norm)

to:

        if exists(self.discr_max_grad_norm):
            self.accelerator.clip_grad_norm_(self.soundstream.stft_discriminator.parameters(), self.discr_max_grad_norm)
            self.accelerator.clip_grad_norm_(self.soundstream.discriminators.parameters(), self.discr_max_grad_norm)
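For reference, here is a minimal torch-only sketch of what the two calls do (the module names below are hypothetical stand-ins for the real SoundStream discriminators, and Accelerate's `clip_grad_norm_` is just a distributed-aware wrapper around torch's). It also shows an alternative: since `clip_grad_norm_` accepts any iterable of parameters, both groups can be clipped under one shared global norm instead of two independent ones.

```python
import itertools

import torch
from torch import nn

# Hypothetical stand-ins for the STFT discriminator and the
# multi-scale waveform discriminators in SoundStream.
stft_discriminator = nn.Linear(4, 4)
discriminators = nn.ModuleList([nn.Linear(4, 4), nn.Linear(4, 4)])

# Populate gradients with a dummy loss.
loss = stft_discriminator(torch.randn(2, 4)).sum()
for d in discriminators:
    loss = loss + d(torch.randn(2, 4)).sum()
loss.backward()

max_norm = 0.5

# As in the fix above: clip each discriminator group separately,
# so each group's total gradient norm is capped at max_norm
# independently of the other.
torch.nn.utils.clip_grad_norm_(stft_discriminator.parameters(), max_norm)
torch.nn.utils.clip_grad_norm_(discriminators.parameters(), max_norm)

# Alternative: one call over the chained parameters caps the
# *combined* global norm of all discriminator gradients instead.
all_params = itertools.chain(
    stft_discriminator.parameters(), discriminators.parameters()
)
torch.nn.utils.clip_grad_norm_(all_params, max_norm)
```

The two variants differ slightly: separate calls allow each group's norm to reach `max_norm` on its own, while the joint call constrains their combined norm, so it clips more aggressively when both groups have large gradients.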

Thanks!!

lucidrains commented 10 months ago

@avihu111 thanks Avihu! indeed i forgot to clip the main discriminator! thank you 🙏

avihu111 commented 10 months ago

Thanks @lucidrains! I was also wondering if I could submit pull requests. I already added to my code:

Not sure what the common practice is :) Thanks again for this awesome resource!

lucidrains commented 10 months ago

yeah absolutely! open source means you can always send your changes upstream for the common public good

thank you!