lucidrains / audiolm-pytorch

Implementation of AudioLM, a SOTA Language Modeling Approach to Audio Generation out of Google Research, in Pytorch
MIT License

v0.23.2 causing OOM #125

Closed · LWprogramming closed this issue 1 year ago

LWprogramming commented 1 year ago

Using this script, which tries running AudioLM end-to-end (basically the demo notebook as a single Python script), I found that 0.23.2 reliably resulted in OOM on an A100 GPU during inference, once the fine transformer got to around 93%. I just tried it again on 0.22.3 and it was able to finish, and by rough approximation (running nvidia-smi repeatedly in a separate terminal) the older version uses about 10k MiB less memory. Has anyone else been able to replicate this?
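(As an aside, peak usage can also be measured from inside the script rather than by polling nvidia-smi. A minimal sketch, with a made-up helper name; note that `max_memory_allocated` only counts tensor allocations, so it will read lower than nvidia-smi's figure, which includes the CUDA context and cached memory:)

```python
import torch

# Hypothetical helper (not part of the repo): report peak GPU memory for a call,
# instead of eyeballing nvidia-smi from a separate terminal.
def run_and_report_peak(fn, *args, **kwargs):
    torch.cuda.reset_peak_memory_stats()
    out = fn(*args, **kwargs)
    peak_mib = torch.cuda.max_memory_allocated() / 1024 ** 2
    print(f"peak allocated: {peak_mib:.0f} MiB")
    return out
```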

I haven't looked closely at the code changes yet, so I don't have a hypothesis for what causes this, but I thought to look at the most recent change because the OOM stack trace included `audiolm_pytorch.py", line 955, in forward: attn_bias = self.pos_bias_mlp(rel_dist.float())`, which was introduced here.

lucidrains commented 1 year ago

@LWprogramming thanks for raising a valid issue again! it was exactly what you thought it was: the computation of the 2d relative positional bias in the fine transformer

do you want to see if this commit helps? (would also appreciate a code review, as this has always been super confusing for me :dizzy_face:)
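(For context, an MLP-computed 2D relative position bias of this kind typically looks something like the sketch below. This is illustrative only, not the repo's actual code; the key point is that the MLP's activations grow with the square of the sequence length, which is where the memory goes.)

```python
import torch
from torch import nn

# Illustrative sketch of a 2D relative position bias produced by a small MLP,
# for tokens indexed by (time step, quantizer index). Names/sizes are made up.
class NaiveDynamicPositionBias(nn.Module):
    def __init__(self, heads, hidden_dim=32):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(2, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, heads),
        )

    def forward(self, coords):
        # coords: (n, 2) integer positions, one row per token
        rel = coords[:, None, :] - coords[None, :, :]   # (n, n, 2) pairwise offsets
        bias = self.mlp(rel.float())                    # (n, n, heads): activations grow as n^2
        return bias.permute(2, 0, 1)                    # (heads, n, n), added to attention logits
```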

lucidrains commented 1 year ago

@LWprogramming if this doesn't work, i may just have to turn to other techniques to generate relative positions across the quantizer positions

akmalmasud96 commented 1 year ago

Hi @lucidrains, I am facing the same issue of CUDA OOM while training SoundStream on a 24 GB GPU. It happens with 2-second audio at batch size 4. Is this expected behavior?

lucidrains commented 1 year ago

@akmalmasud96 yes for soundstream that is expected. but we are talking about the fine transformer here

akmalmasud96 commented 1 year ago

@lucidrains yes, I understand that this is about the fine transformer. I guessed that this is due to the recent updates. Could you please guide me on which part of the model consumes that much memory? One second of audio takes 4 GB of memory and two seconds take 5.5 GB (with a batch size of 1). Could you please explain the reason for such behaviour?

LWprogramming commented 1 year ago

> do you want to see if this commit helps? (would also appreciate a code review, as this has always been super confusing for me 😵)

The fine transformer is definitely substantially faster now (probably in line with 0.22.3), and memory usage at a glance seems pretty similar, definitely in the 20-30k MiB range as before! I'm also new to this, so I watched some videos on how relative 2D encodings work, but I'm still working through some confusion. Is there a specific paper you're following here?
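(For readers following along, one common way to cut that memory, which may or may not be what the commit above does, is to run the bias MLP only over the unique relative offsets and then gather the result back out, so the expensive activations scale with the small offset grid rather than with n². A sketch, reusing the hypothetical MLP from above:)

```python
import torch

# Illustrative only: evaluate the bias MLP over unique relative offsets
# rather than all n^2 pairs, then gather back to the full (heads, n, n) bias.
def bias_from_unique_offsets(mlp, coords, heads):
    # coords: (n, 2) integer (time step, quantizer index) positions
    rel = coords[:, None, :] - coords[None, :, :]               # (n, n, 2) pairwise offsets
    flat = rel.reshape(-1, 2)
    unique, inverse = torch.unique(flat, dim=0, return_inverse=True)
    bias = mlp(unique.float())                                  # (num_unique, heads): small
    bias = bias[inverse].reshape(rel.shape[0], rel.shape[1], heads)
    return bias.permute(2, 0, 1)                                # (heads, n, n)
```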

lucidrains commented 1 year ago

@LWprogramming nice! thanks for testing out the changes and confirming it is fixed!

i think it was in some old vision transformer paper, but i forgot its name. no worries, i'll just review the code once or twice later this week