lucidrains / BS-RoFormer

Implementation of Band Split Roformer, SOTA Attention network for music source separation out of ByteDance AI Labs
MIT License

Hard to train with 8 s audio at a batch size of 2 #16

Closed: jinhonglu closed this issue 10 months ago

jinhonglu commented 10 months ago

Hi, according to the paper, "We do not use the In-House dataset for ablation study. The effective batch size is 64 (i.e., 4 for each GPU) using accumulate grad batches=2." and "Our model takes a segment of 8-seconds waveform for input and output." This is for the L=6 model. My understanding is that each GPU processes 2 segments of 8 s per forward/backward pass, i.e., before gradient accumulation is counted. However, I can hardly fit that while training on a single V100.
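The numbers quoted from the paper fit together as follows. This is a minimal sketch that assumes 16 GPUs (the count mentioned later in this thread for the L=6 model), a 44.1 kHz sample rate, and stereo input; the helper names are hypothetical.

```python
# Sketch: relate the paper's effective batch size to what one GPU holds per pass.
# Assumptions (not stated in the quoted excerpt): 16 GPUs, 44.1 kHz, stereo.

def micro_batch_per_gpu(effective_batch, num_gpus, accumulate_steps):
    """Segments each GPU must hold in one forward/backward pass."""
    per_gpu = effective_batch // num_gpus        # 64 // 16 = 4, the "4 for each GPU"
    assert per_gpu * num_gpus == effective_batch, "batch must split evenly across GPUs"
    micro = per_gpu // accumulate_steps          # 4 // 2 = 2 segments per pass
    assert micro * accumulate_steps == per_gpu, "per-GPU batch must split evenly into steps"
    return micro

def segment_samples(seconds, sample_rate=44_100, channels=2):
    """Total waveform samples in one training segment, summed over channels."""
    return int(seconds * sample_rate) * channels

micro = micro_batch_per_gpu(effective_batch=64, num_gpus=16, accumulate_steps=2)
print(micro)                       # 2
print(segment_samples(8))          # 705600
print(micro * segment_samples(8))  # 1411200
```

So the memory that has to fit on one card is that of a micro-batch of 2 segments; the accumulation only changes how many such passes happen before each optimizer step.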

Another thing is the model size. The paper says: "The numbers of parameters for BS-RoFormer and BS-Transformer with L=6 are 72.2M and 72.5M." However, when I follow the default settings from the paper, my model has 108M parameters:

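```python
from bs_roformer import BSRoformer as Model  # the user's `Model` is assumed to be BSRoformer

model = Model(dim=384, depth=6, time_transformer_depth=1, freq_transformer_depth=1,
              heads=8, attn_dropout=0.1, ff_dropout=0.1, dim_head=48, stereo=True)
```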
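To see where the parameters go, a back-of-envelope count helps. This sketch assumes a standard pre-norm layout (RMSNorm scale plus bias-free QKV/output projections for attention; RMSNorm scale plus two biased linear layers with expansion factor 4 for the feed-forward) and one time transformer plus one frequency transformer in each of the `depth` blocks. The helper functions and the 62-band figure are illustrative assumptions, not values read from the repo.

```python
# Rough parameter estimate under the assumed layer layout described above.
# These helpers are hypothetical and are not part of BS-RoFormer.

def attention_params(dim, heads, dim_head):
    inner = heads * dim_head
    return dim + dim * inner * 3 + inner * dim            # norm + fused qkv + out proj

def feedforward_params(dim, mult=4):
    hidden = dim * mult
    return dim + (dim * hidden + hidden) + (hidden * dim + dim)  # norm + up + down

def transformer_stack_params(dim, depth, time_depth, freq_depth, heads, dim_head):
    layers = depth * (time_depth + freq_depth)            # time + freq transformer per block
    return layers * (attention_params(dim, heads, dim_head) + feedforward_params(dim))

def band_mlp_first_layer_params(dim, num_bands, mult=4):
    # one separate dim -> dim*mult projection per band (num_bands is illustrative)
    return num_bands * (dim * dim * mult + dim * mult)

stack = transformer_stack_params(dim=384, depth=6, time_depth=1, freq_depth=1,
                                 heads=8, dim_head=48)
print(f"{stack / 1e6:.1f}M")                                  # 21.3M
print(f"{band_mlp_first_layer_params(384, 62) / 1e6:.1f}M")  # 36.7M
```

Under these assumptions the attention and feed-forward layers account for well under a third of 72M, and a single per-band hidden layer of width `4 * dim` already adds tens of millions of parameters, so the per-band modules are where the count is most sensitive to configuration.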

Anything wrong?

jinhonglu commented 10 months ago

I changed the mask_estimator_model according to #15; the model now has 72.1M parameters, close to the number reported in the paper.

However, I still cannot fit two 8 s segments in a batch.

lucidrains commented 10 months ago

@jinhonglu do you know about gradient accumulation?

lucidrains commented 10 months ago

@jinhonglu ah, you are unable to take advantage of flash attention because of your v100
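Why flash attention matters here can be seen from the size of the attention score matrices that standard attention materializes. This is a back-of-envelope sketch; the 44.1 kHz sample rate, STFT hop length of 512 with centered frames, 62 bands, 8 heads, fp32 scores, and micro-batch of 2 are all assumptions, and the helper names are hypothetical.

```python
# Size of one attention score tensor per layer, for standard (non-flash) attention.
# All hyperparameters below are assumptions for illustration.

def stft_frames(num_samples, hop_length=512):
    return num_samples // hop_length + 1       # frame count for a centered STFT

def score_matrix_bytes(num_sequences, heads, seq_len, bytes_per_elem=4):
    return num_sequences * heads * seq_len * seq_len * bytes_per_elem

batch, bands, heads = 2, 62, 8
frames = stft_frames(8 * 44_100)               # 690 frames for an 8 s segment

# time transformer: one sequence of `frames` tokens per (batch item, band)
time_bytes = score_matrix_bytes(batch * bands, heads, frames)
# frequency transformer: one sequence of `bands` tokens per (batch item, frame)
freq_bytes = score_matrix_bytes(batch * frames, heads, bands)

print(frames)                                  # 690
print(f"{time_bytes / 2**30:.2f} GiB")         # 1.76 GiB
print(f"{freq_bytes / 2**30:.2f} GiB")         # 0.16 GiB
```

Under these assumptions each time-transformer layer produces a score tensor of roughly 1.8 GiB, and training keeps such activations around for the backward pass in every layer. Flash attention computes attention in tiles without materializing these matrices, which is why losing it on a V100 hurts most on the time axis, where the sequence is long.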

jinhonglu commented 10 months ago

@lucidrains

> @jinhonglu do you know about gradient accumulation?

yeah, that means the model runs forward and backward on several batches and accumulates their gradients before each optimizer step; that is why I say the actual batch size per GPU (per pass) is 2
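The reason accumulation helps with memory is that averaging gradients over equal-sized micro-batches reproduces the full-batch gradient for a mean-reduced loss, while only one micro-batch is in memory at a time. A minimal sketch on a hypothetical one-weight linear model (not BS-RoFormer), with hand-written gradients:

```python
# Gradient accumulation on a toy model: loss = mean((w*x - y)^2), scalar weight w.

def grad_mse(w, xs, ys):
    """Gradient of mean((w*x - y)^2) with respect to w."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)

def accumulated_grad(w, xs, ys, micro_batch):
    steps = len(xs) // micro_batch
    total = 0.0
    for i in range(steps):                              # forward + backward per micro-batch
        sl = slice(i * micro_batch, (i + 1) * micro_batch)
        total += grad_mse(w, xs[sl], ys[sl]) / steps    # scale so the sum is an average
    return total                                        # a single optimizer step uses this

xs = [1.0, 2.0, 3.0, 4.0]
ys = [2.0, 4.0, 6.0, 8.0]
w = 0.5
full = grad_mse(w, xs, ys)
accum = accumulated_grad(w, xs, ys, micro_batch=2)
print(full, accum)  # -22.5 -22.5
```

The equality relies on equal micro-batch sizes and a mean-reduced loss; with uneven micro-batches the per-step scaling would need to be weighted by size.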

> @jinhonglu ah, you are unable to take advantage of flash attention because of your v100

Yes, I can't use flash attention, but according to the paper the L=6 model was also trained on 16 V100-32G GPUs. I am not sure how they fit two 8 s segments per GPU. Has anyone managed to fit this?

lucidrains commented 10 months ago

@jinhonglu i don't think this is an issue with the architecture. you'll need to stretch your resources with mixed precision, activation checkpointing, gradient accumulation, etc.
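The three suggestions can be combined in one training loop. This is a minimal PyTorch sketch on a hypothetical toy network (`ToyNet` stands in for the real model and random tensors stand in for audio); it uses `torch.autocast` for reduced precision, `torch.utils.checkpoint.checkpoint` for activation checkpointing, and a `GradScaler` with `accumulate_steps` for gradient accumulation. fp16 with loss scaling is assumed on GPU because the V100 has no native bf16 support; on CPU the sketch falls back to bf16 so it still runs.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class ToyNet(nn.Module):
    """Hypothetical stand-in for a deep network; replace with your own model."""
    def __init__(self, dim=64, depth=4):
        super().__init__()
        self.blocks = nn.ModuleList(
            nn.Sequential(nn.LayerNorm(dim), nn.Linear(dim, dim * 4), nn.GELU(),
                          nn.Linear(dim * 4, dim))
            for _ in range(depth)
        )

    def forward(self, x, use_checkpoint=True):
        for block in self.blocks:
            if use_checkpoint and self.training:
                # activation checkpointing: recompute this block's activations in backward
                x = x + checkpoint(block, x, use_reentrant=False)
            else:
                x = x + block(x)
        return x

device = "cuda" if torch.cuda.is_available() else "cpu"
amp_dtype = torch.float16 if device == "cuda" else torch.bfloat16
model = ToyNet().to(device)
opt = torch.optim.AdamW(model.parameters(), lr=1e-3)
# loss scaling guards fp16 gradients against underflow; disabled on CPU
scaler = torch.cuda.amp.GradScaler(enabled=(device == "cuda"))

accumulate_steps = 2
losses = []
model.train()
opt.zero_grad(set_to_none=True)
for step in range(4):
    x = torch.randn(2, 128, 64, device=device)       # micro-batch of 2 random "segments"
    with torch.autocast(device_type=device, dtype=amp_dtype):
        loss = (model(x) - x).pow(2).mean() / accumulate_steps
    scaler.scale(loss).backward()                     # gradients add up across micro-batches
    losses.append(loss.item() * accumulate_steps)
    if (step + 1) % accumulate_steps == 0:            # one optimizer step per accumulation cycle
        scaler.step(opt)
        scaler.update()
        opt.zero_grad(set_to_none=True)
```

Checkpointing trades compute for memory: activations inside each checkpointed block are discarded after the forward pass and recomputed during backward, so peak memory drops at the cost of roughly one extra forward pass.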

lucidrains commented 10 months ago

feel free to ask for help on laion or in the discussions tab