lucidrains / BS-RoFormer

Implementation of Band Split Roformer, SOTA Attention network for music source separation out of ByteDance AI Labs

MelRoformer parameters from paper #26

Open ZFTurbo opened 9 months ago

ZFTurbo commented 9 months ago

I'm trying to reproduce the paper's model, but I have had no luck.

My current settings, which give me a batch size of only 2 on 48 GB of memory:

  dim: 192
  depth: 8
  stereo: true
  num_stems: 1
  time_transformer_depth: 1
  freq_transformer_depth: 1
  num_bands: 60
  dim_head: 64
  heads: 8
  attn_dropout: 0.1
  ff_dropout: 0.1
  flash_attn: True
  dim_freqs_in: 1025
  sample_rate: 44100  # needed for mel filter bank from librosa
  stft_n_fft: 2048
  stft_hop_length: 512
  stft_win_length: 2048
  stft_normalized: False
  mask_estimator_depth: 2
  multi_stft_resolution_loss_weight: 1.0
  multi_stft_resolutions_window_sizes: !!python/tuple
  - 4096
  - 2048
  - 1024
  - 512
  - 256
  multi_stft_hop_size: 147
  multi_stft_normalized: False
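
For reference, a minimal sketch of instantiating the model with these settings (it assumes the config keys above map one-to-one onto MelBandRoformer keyword arguments; only the main ones are shown):

# Minimal sketch, assuming the YAML keys above are passed straight through
# as MelBandRoformer keyword arguments (remaining keys omitted for brevity).
import torch
from bs_roformer import MelBandRoformer

model = MelBandRoformer(
    dim = 192,
    depth = 8,
    stereo = True,
    num_stems = 1,
    num_bands = 60,
    dim_head = 64,
    heads = 8,
    flash_attn = True,
    sample_rate = 44100,
    stft_n_fft = 2048,
    stft_hop_length = 512,
    stft_win_length = 2048,
    mask_estimator_depth = 2,
)

x = torch.randn(1, 2, 352768)  # (batch, stereo channels, ~8 s at 44.1 kHz)
with torch.no_grad():
    out = model(x)             # (1, 2, 352768), matching the torchinfo report below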

As input I give 8 seconds at 44100 Hz, so the length is 352800 samples.

I ran my model through torchinfo:

from torchinfo import summary
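# input_size = (batch, channels, samples); 352768 samples is roughly 8 s of stereo audio at 44.1 kHz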
summary(model, input_size=(1, 2, 352768))

The report is:

==============================================================================================================
Layer (type:depth-idx)                                       Output Shape              Param #
==============================================================================================================
MelBandRoformer                                              [1, 2, 352768]            56,503,768
├─ModuleList: 1-1                                            --                        --
│    └─ModuleList: 2-1                                       --                        384
│    │    └─Transformer: 3-77                                [60, 690, 192]            (recursive)
│    │    └─Transformer: 3-78                                [690, 60, 192]            (recursive)
│    └─ModuleList: 2-2                                       --                        384
│    │    └─Transformer: 3-79                                [60, 690, 192]            (recursive)
│    │    └─Transformer: 3-80                                [690, 60, 192]            (recursive)
│    └─ModuleList: 2-3                                       --                        384
│    │    └─Transformer: 3-81                                [60, 690, 192]            (recursive)
│    │    └─Transformer: 3-82                                [690, 60, 192]            (recursive)
│    └─ModuleList: 2-4                                       --                        384
│    │    └─Transformer: 3-83                                [60, 690, 192]            (recursive)
│    │    └─Transformer: 3-84                                [690, 60, 192]            (recursive)
│    └─ModuleList: 2-5                                       --                        384
│    │    └─Transformer: 3-85                                [60, 690, 192]            (recursive)
│    │    └─Transformer: 3-86                                [690, 60, 192]            (recursive)
│    └─ModuleList: 2-6                                       --                        384
│    │    └─Transformer: 3-87                                [60, 690, 192]            (recursive)
│    │    └─Transformer: 3-88                                [690, 60, 192]            (recursive)
│    └─ModuleList: 2-7                                       --                        384
│    │    └─Transformer: 3-89                                [60, 690, 192]            (recursive)
│    │    └─Transformer: 3-90                                [690, 60, 192]            (recursive)
│    └─ModuleList: 2-8                                       --                        384
│    │    └─Transformer: 3-91                                [60, 690, 192]            (recursive)
│    │    └─Transformer: 3-92                                [690, 60, 192]            (recursive)
├─BandSplit: 1-2                                             [1, 690, 60, 192]         --
│    └─ModuleList: 2                                         --                        --
....
├─ModuleList: 1                                              --                        --
│    └─MaskEstimator: 2-9                                    [1, 690, 7916]            --
==============================================================================================================
Total params: 69,102,468
Trainable params: 69,102,404
Non-trainable params: 64
Total mult-adds (G): 8.35
==============================================================================================================
Input size (MB): 2.82
Forward/backward pass size (MB): 703.40
Params size (MB): 232.17
Estimated Total Size (MB): 938.40
==============================================================================================================

From the report I would expect a batch size of more than 48 to fit, but in the end I can only use a batch size of 2.

GPU memory usage for batch size = 2: [screenshot]

To follow the paper I must increase dim to 384, increase depth to 12, and decrease stft_hop_length to 441 (10 ms at 44100 Hz). In that case the batch size will be only 1, or the model will not fit in memory at all.

Any ideas on how to deal with such large memory usage?

lucidrains commented 9 months ago

there are a bunch of techniques

gradient accumulation is what you can use immediately, but also look into gradient checkpointing etc.
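
A generic sketch of gradient accumulation in PyTorch (the tiny linear model, MSE loss, and random data below are placeholders just so the loop runs; only the accumulation pattern is the point):

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

# placeholders so the loop is runnable; in practice `model` would be the RoFormer
# and the dataset would yield (mixture, target) audio pairs
model = nn.Linear(16, 16)
optimizer = torch.optim.Adam(model.parameters(), lr = 1e-4)
dataloader = DataLoader(TensorDataset(torch.randn(64, 16), torch.randn(64, 16)), batch_size = 2)

accum_steps = 16   # effective batch size = 2 (micro batch) * 16 = 32

optimizer.zero_grad()
for step, (mix, target) in enumerate(dataloader):
    loss = nn.functional.mse_loss(model(mix), target)  # placeholder loss
    (loss / accum_steps).backward()   # gradients accumulate across micro batches
    if (step + 1) % accum_steps == 0:
        optimizer.step()              # one update per `accum_steps` micro batches
        optimizer.zero_grad()

Gradient checkpointing (e.g. torch.utils.checkpoint) instead trades extra compute for memory by recomputing activations during the backward pass, and can be combined with accumulation.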

ZFTurbo commented 9 months ago

I already use gradient accumulation. But in the paper the authors report that the model fits in 32 GB with batch size = 6.

lucidrains commented 9 months ago

how many GPUs do you have?

lucidrains commented 9 months ago

ah, well I can't help you with that, but this is a common issue and you ought to be able to figure it out with the resources online

lucidrains commented 9 months ago

the biggest gun I can bring out would be reversible networks, which could work for this architecture. maybe at a later date. for now, maybe accumulate gradients and wait it out?

lucidrains commented 9 months ago

are you using mixed precision with flash attention turned on?

ZFTurbo commented 9 months ago

I use mixed precision, but I'm not sure about Flash Attention.

ZFTurbo commented 9 months ago

> how many GPUs do you have?

I usually train on 1 GPU )

lucidrains commented 9 months ago

definitely turn on flash attention if you have the right hardware

it is a flag on init (flash_attn)
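
For example, a sketch combining that init flag with mixed precision via torch.autocast (flash attention kernels generally require fp16/bf16 and recent GPU hardware; the shapes and hyperparameters here are placeholders):

import torch
from bs_roformer import MelBandRoformer

# flash_attn is set at construction time, same flag as in the config above
model = MelBandRoformer(dim = 192, depth = 8, stereo = True, flash_attn = True).cuda()

x = torch.randn(1, 2, 352768, device = 'cuda')  # stereo mixture, ~8 s at 44.1 kHz

# mixed precision: the attention matmuls run in bf16, letting the flash kernel be used
with torch.autocast(device_type = 'cuda', dtype = torch.bfloat16):
    out = model(x)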

turian commented 7 months ago

Wow! Reversible networks would be cool

@ZFTurbo Are we confident that MelRoFormer was trained on stereo stems? (If not, i.e. if they trained on mono stems, that would effectively halve their batch size.) Perhaps they didn't, because: a) Bandsplit-RNN didn't either, as far as I know, and b) training individual models per stem suggests that, like Bandsplit-RNN, they didn't add the multi-stem/multi-channel functionality that @lucidrains did.

HTDemucs and its variants, on the other hand, were trained on all stems simultaneously and then fine-tuned on individual stems. This leads to models 4x the size, but I am a little surprised the authors didn't include this easy win, given how many GPUs they used for training.

deyituo commented 7 months ago

@ZFTurbo how did you collect the songs?

carlosalberto2000 commented 6 months ago

@ZFTurbo Do you intend to release a trained model? I'm looking forward to testing this approach against Demucs, but I don't have the means or knowledge to train a model myself.

ZFTurbo commented 6 months ago

> @ZFTurbo Do you intend to release a trained model? I'm looking forward to testing this approach against Demucs, but I don't have the means or knowledge to train a model myself.

I posted some weights here: https://github.com/ZFTurbo/Music-Source-Separation-Training/tree/main?tab=readme-ov-file#vocal-models

But the SDR metric is not really great.

carlosalberto2000 commented 6 months ago

Thank you a lot! I really appreciate your work. So the crown currently still goes to "MDX23C for vocals + HTDemucs4 FT", right? How do you think they achieved such a high score in MDX'23?

jarredou commented 6 months ago

> How do you think they achieved such a high score in MDX'23?

Only BS-RoFormer was evaluated during MDX23 by ByteDance-SAMI; Mel-RoFormer came after the contest.

carlosalberto2000 commented 6 months ago

> How do you think they achieved such a high score in MDX'23?
>
> Only BS-RoFormer was evaluated during MDX23 by ByteDance-SAMI; Mel-RoFormer came after the contest.

Oh, I see, I got mixed up by the names. Are you aware of any trained BS-RoFormer model?