Closed: LWprogramming closed this issue 1 year ago
@LWprogramming thanks for raising a valid issue again! it was exactly what you thought it was, the computation of the 2D relative positional bias in the fine transformer
do you want to see if this commit helps? (would also appreciate a code review, as this has always been super confusing for me :dizzy_face:)
@LWprogramming if this doesn't work, i may just have to turn to other techniques to generate relative positions across the quantizer positions
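For readers following the discussion, here is an illustrative sketch of what a 2D relative position over (time step, quantizer index) coordinates looks like, and why the pairwise table is quadratic in sequence length. This is plain Python for illustration only, not the audiolm-pytorch source; the coordinate layout and helper names are assumptions.

```python
# Illustrative sketch (not the audiolm-pytorch code): in the fine
# transformer, each token sits at a 2D coordinate (time step, quantizer
# index). A relative positional bias needs the pairwise offset between
# every pair of coordinates, so the table is quadratic in token count.

def grid_coords(n_time, n_quantizers):
    """Flattened (time, quantizer) coordinate for each token."""
    return [(t, q) for t in range(n_time) for q in range(n_quantizers)]

def relative_offsets(coords):
    """Pairwise (dt, dq) offsets -- an n x n x 2 structure."""
    return [[(tj - ti, qj - qi) for (tj, qj) in coords]
            for (ti, qi) in coords]

coords = grid_coords(n_time=4, n_quantizers=3)   # 12 tokens total
rel = relative_offsets(coords)
print(len(rel), len(rel[0]))  # -> 12 12, i.e. a 12 x 12 pairwise table
```

In the real model these pairwise offsets would then be turned into a per-head bias added to the attention logits; the point here is just that the structure grows as n².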
Hi @lucidrains, I am facing the same CUDA OOM issue training SoundStream on 24 GB of GPU memory. It runs on 2-second audio with batch size 4. Is this expected behavior?
@akmalmasud96 yes, for soundstream that is expected. but we are talking about the fine transformer here
@lucidrains yes, I understand that this is about the fine transformer. I guessed that this is due to the recent updates. Could you please guide me on which part of the model consumes that much memory? 1 second of audio takes 4 GB, and 2 seconds takes 5.5 GB (with a batch size of 1). Could you explain the reason for this behaviour?
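A back-of-envelope estimate helps explain the scaling being asked about here. The dimensions below (heads, hidden size, tokens per second) are assumptions for illustration, not values read from the library, and the measured totals above also include model weights and other activations, which is why the observed growth is less than a clean 4x.

```python
# Rough memory estimate (assumed dimensions): a per-head bias over all
# token pairs needs heads * n^2 floats, and pushing the n^2 relative
# distances through an MLP materializes n^2 * hidden_dim intermediate
# activations on top of that. Both terms are quadratic in token count.

def pairwise_bias_bytes(n_tokens, heads=8, hidden_dim=512, bytes_per_float=4):
    bias = heads * n_tokens ** 2 * bytes_per_float
    mlp_hidden = n_tokens ** 2 * hidden_dim * bytes_per_float
    return bias + mlp_hidden

one_sec = pairwise_bias_bytes(n_tokens=600)    # hypothetical token count for 1 s
two_sec = pairwise_bias_bytes(n_tokens=1200)   # doubled audio length
print(two_sec / one_sec)  # -> 4.0: doubling audio quadruples this term
```

So any component that builds a full pairwise table grows quadratically with audio length, while fixed costs (weights, optimizer state) do not, which matches memory growing noticeably but not 4x between 1 s and 2 s.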
The fine transformer is definitely substantially faster now (probably in line with 0.22.3, and memory usage at a glance seems pretty similar, definitely in the 20-30k MiB range as before)! I'm also new to this, so I watched some videos on how relative 2D encodings work, but I'm still working through some confusion. Is there a specific paper you're following here?
@LWprogramming nice! thanks for testing out the changes and confirming it is fixed!
i think it was in some old vision transformer paper, but i forgot the name of it. no worries, i'll just review the code once or twice later this week
Using this script, which tries running AudioLM end-to-end (basically the demo notebook in a single Python script), I found that 0.23.2 reliably resulted in OOM on the A100 GPU during inference when the fine transformer got to around 93%. I just tried it again on 0.22.3 and it was able to finish, and by rough approximation (i.e. running `nvidia-smi` repeatedly in a separate terminal) I found the older version uses about 10k MiB less memory. Has anyone else been able to replicate this?

I haven't looked closely at the code changes yet, so I don't have a hypothesis for what causes this, but I thought to look at the most recent change because the OOM stack trace included

```
audiolm_pytorch.py", line 955, in forward
    attn_bias = self.pos_bias_mlp(rel_dist.float())
```

which was introduced here.
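Earlier in the thread, switching to other techniques for generating relative positions was floated. One commonly used memory-lighter alternative is the T5-style bucketed relative bias: instead of running an MLP over every pairwise distance, distances are clamped into a small number of buckets and the bias is a table lookup of size heads x buckets. The sketch below is a simplified illustration of that idea, not what this repo implements (T5 itself uses log-spaced buckets rather than the linear ones here).

```python
# Sketch of bucketed relative bias indexing (simplified; illustrative
# only). Each signed relative distance maps to one of num_buckets
# indices, so the learned parameters are heads * num_buckets instead of
# an MLP applied to every one of the n^2 pairwise distances.

def bucket(rel_dist, num_buckets=32, max_distance=128):
    """Map a signed relative distance to a bucket index in [0, num_buckets)."""
    # negative distances take the upper half of the bucket range
    sign_offset = num_buckets // 2 if rel_dist < 0 else 0
    d = min(abs(rel_dist), max_distance)
    # linear bucketing for illustration (T5 uses log-spaced buckets)
    idx = d * (num_buckets // 2 - 1) // max_distance
    return sign_offset + idx

print(bucket(0), bucket(5), bucket(-5))  # -> 0 0 16
```

The pairwise bucket indices are still n x n, but they are integers gathered from a tiny table, so there are no n² MLP hidden activations to hold in memory during the forward pass.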