lucidrains / audiolm-pytorch

Implementation of AudioLM, a SOTA Language Modeling Approach to Audio Generation out of Google Research, in PyTorch
MIT License
2.39k stars 255 forks

Improve ComplexConv2d FSDP compatibility #193

Closed haydenshively closed 1 year ago

haydenshively commented 1 year ago

I couldn't get the existing ComplexConv2d workaround to behave with FSDP; I kept getting the following:

audiolm_pytorch/soundstream.py", line 194, in forward
    weight, bias = map(torch.view_as_complex, (self.weight, self.bias))
RuntimeError: Tensor must have a storage_offset divisible by 2
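The error comes from a documented constraint on `torch.view_as_complex`: the input tensor's storage offset must be divisible by 2. FSDP flattens module parameters into a single buffer and hands each module back a view at an arbitrary offset, which can violate that constraint. A minimal standalone reproduction (no FSDP needed):

```python
import torch

buf = torch.randn(6)

# even storage_offset: view_as_complex works
even = buf[0:4].reshape(2, 2)   # storage_offset 0
_ = torch.view_as_complex(even)

# odd storage_offset: raises the same RuntimeError seen under FSDP
odd = buf[1:5].reshape(2, 2)    # storage_offset 1
try:
    torch.view_as_complex(odd)
except RuntimeError as e:
    print(e)  # Tensor must have a storage_offset divisible by 2
```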

This change fixes that and works with any version of torch >= 1.12 (CUDA 11.8), at least On My Machine™️

Unfortunately it's slower than the built-in nn.Conv2d(..., dtype=torch.complex64), so anyone who doesn't care about sharding should stick with that. As such, I'm not sure if we should actually merge this or not, but wanted to drop it here for visibility.

Last thing: I haven't checked for numerical equivalence between this implementation and the old one, so feedback/help on that would be appreciated.
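One way to do that check (and a common FSDP-friendly formulation in general, sketched here as an assumption rather than the exact diff in this PR) is to keep the real and imaginary weights as separate real-valued parameters and expand the complex product (a + bi)(c + di) = (ac − bd) + (ad + bc)i into real convolutions, then compare against the built-in complex `nn.Conv2d`:

```python
import torch
import torch.nn as nn

class SplitComplexConv2d(nn.Module):
    # hypothetical sketch: real/imag parts live as ordinary real parameters,
    # so FSDP can flat-shard them without ever calling view_as_complex
    def __init__(self, dim_in, dim_out, kernel_size, **kwargs):
        super().__init__()
        self.conv_r = nn.Conv2d(dim_in, dim_out, kernel_size, bias=False, **kwargs)
        self.conv_i = nn.Conv2d(dim_in, dim_out, kernel_size, bias=False, **kwargs)
        self.bias_r = nn.Parameter(torch.zeros(dim_out))
        self.bias_i = nn.Parameter(torch.zeros(dim_out))

    def forward(self, x):
        a, b = x.real, x.imag
        # (a + bi) * (Wr + Wi i) = (Wr a - Wi b) + (Wr b + Wi a) i
        real = self.conv_r(a) - self.conv_i(b) + self.bias_r.view(1, -1, 1, 1)
        imag = self.conv_r(b) + self.conv_i(a) + self.bias_i.view(1, -1, 1, 1)
        return torch.complex(real, imag)

# numerical check against the built-in complex conv
torch.manual_seed(0)
split = SplitComplexConv2d(2, 3, 3, padding=1)
ref = nn.Conv2d(2, 3, 3, padding=1, dtype=torch.complex64)
with torch.no_grad():
    ref.weight.copy_(torch.complex(split.conv_r.weight, split.conv_i.weight))
    ref.bias.copy_(torch.complex(split.bias_r, split.bias_i))
x = torch.randn(1, 2, 8, 8, dtype=torch.complex64)
print(torch.allclose(split(x), ref(x), atol=1e-4))  # expect True, up to float tolerance
```

The four-real-conv expansion does more FLOPs than the built-in complex path, which is consistent with the slowdown noted above.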

lucidrains commented 1 year ago

i really regret trying to make this complex valued discriminator work, when Encodec clearly showed it was not necessary for getting good results

i think what i'll do is assert that FSDP is disabled for soundstream training

haydenshively commented 1 year ago

Sounds good. Can confirm that decoded audio becomes intelligible after training on ~800 hours of audio*, regardless of whether the STFT network is real or complex.

*I've only run a couple seeds so far, so YMMV -- speaking of, do you have any benchmarks?

lucidrains commented 1 year ago

this should be addressed https://github.com/lucidrains/audiolm-pytorch/commit/cc36361a55a09755ec3d457beb9dbc1effb6cbb2

for examples of models that researchers have trained and decided to share, you can look in the discussions