lucidrains / BS-RoFormer

Implementation of Band Split Roformer, SOTA Attention network for music source separation out of ByteDance AI Labs
MIT License
433 stars 16 forks

MelRoformer parameters from paper #26

Open ZFTurbo opened 11 months ago

ZFTurbo commented 11 months ago

I'm trying to reproduce the paper model, but so far I have had no luck.

My current settings, which only let me fit a batch size of 2 in 48 GB of memory:

  dim: 192
  depth: 8
  stereo: true
  num_stems: 1
  time_transformer_depth: 1
  freq_transformer_depth: 1
  num_bands: 60
  dim_head: 64
  heads: 8
  attn_dropout: 0.1
  ff_dropout: 0.1
  flash_attn: True
  dim_freqs_in: 1025
  sample_rate: 44100  # needed for mel filter bank from librosa
  stft_n_fft: 2048
  stft_hop_length: 512
  stft_win_length: 2048
  stft_normalized: False
  mask_estimator_depth: 2
  multi_stft_resolution_loss_weight: 1.0
  multi_stft_resolutions_window_sizes: !!python/tuple
  - 4096
  - 2048
  - 1024
  - 512
  - 256
  multi_stft_hop_size: 147
  multi_stft_normalized: False
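
(For completeness, a minimal sketch of how I feed this config to the model, assuming the YAML keys map one-to-one onto the MelBandRoformer constructor kwargs; the config.yaml path is just a placeholder:)

import yaml
from bs_roformer import MelBandRoformer

# 'config.yaml' is a placeholder path holding exactly the keys listed above;
# unsafe_load is needed because of the !!python/tuple tag
with open('config.yaml') as f:
    cfg = yaml.unsafe_load(f)

model = MelBandRoformer(**cfg)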

As input I give 8 seconds at 44100 Hz, so the length is 352800 samples. (The summary below uses 352768 = 689 × 512, a multiple of the hop length, which corresponds to the 690 STFT frames in the reported shapes.)

I run the model through torchinfo:

from torchinfo import summary
summary(model, input_size=(1, 2, 352768))

Report is:

==============================================================================================================
Layer (type:depth-idx)                                       Output Shape              Param #
==============================================================================================================
MelBandRoformer                                              [1, 2, 352768]            56,503,768
├─ModuleList: 1-1                                            --                        --
│    └─ModuleList: 2-1                                       --                        384
│    │    └─Transformer: 3-77                                [60, 690, 192]            (recursive)
│    │    └─Transformer: 3-78                                [690, 60, 192]            (recursive)
│    └─ModuleList: 2-2                                       --                        384
│    │    └─Transformer: 3-79                                [60, 690, 192]            (recursive)
│    │    └─Transformer: 3-80                                [690, 60, 192]            (recursive)
│    └─ModuleList: 2-3                                       --                        384
│    │    └─Transformer: 3-81                                [60, 690, 192]            (recursive)
│    │    └─Transformer: 3-82                                [690, 60, 192]            (recursive)
│    └─ModuleList: 2-4                                       --                        384
│    │    └─Transformer: 3-83                                [60, 690, 192]            (recursive)
│    │    └─Transformer: 3-84                                [690, 60, 192]            (recursive)
│    └─ModuleList: 2-5                                       --                        384
│    │    └─Transformer: 3-85                                [60, 690, 192]            (recursive)
│    │    └─Transformer: 3-86                                [690, 60, 192]            (recursive)
│    └─ModuleList: 2-6                                       --                        384
│    │    └─Transformer: 3-87                                [60, 690, 192]            (recursive)
│    │    └─Transformer: 3-88                                [690, 60, 192]            (recursive)
│    └─ModuleList: 2-7                                       --                        384
│    │    └─Transformer: 3-89                                [60, 690, 192]            (recursive)
│    │    └─Transformer: 3-90                                [690, 60, 192]            (recursive)
│    └─ModuleList: 2-8                                       --                        384
│    │    └─Transformer: 3-91                                [60, 690, 192]            (recursive)
│    │    └─Transformer: 3-92                                [690, 60, 192]            (recursive)
├─BandSplit: 1-2                                             [1, 690, 60, 192]         --
│    └─ModuleList: 2                                         --                        --
....
├─ModuleList: 1                                              --                        --
│    └─MaskEstimator: 2-9                                    [1, 690, 7916]            --
==============================================================================================================
Total params: 69,102,468
Trainable params: 69,102,404
Non-trainable params: 64
Total mult-adds (G): 8.35
==============================================================================================================
Input size (MB): 2.82
Forward/backward pass size (MB): 703.40
Params size (MB): 232.17
Estimated Total Size (MB): 938.40
==============================================================================================================

From the report I would expect to fit a batch size of more than 48, but in practice I can only use a batch size of 2.

GPU memory usage for batch size = 2: [screenshot]

To follow the paper I would need to increase dim to 384, depth to 12, and decrease stft_hop_length to 441 (i.e. 10 ms). In that case the batch size would be only 1, or the model would not fit in memory at all.
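
Relative to the config above, the paper-matching overrides would be just these keys (my reading of the paper, not settings I have verified):

  dim: 384
  depth: 12
  stft_hop_length: 441  # 10 ms at 44100 Hz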

Any ideas on how to deal with such high memory usage?

lucidrains commented 11 months ago

there's a bunch of techniques

gradient accumulation is what you can use immediately, but also look into gradient checkpointing etc.
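
something like this for the accumulation part (a generic PyTorch sketch; model, loader and loss_fn below are toy stand-ins, not this repo's API):

import torch

# toy stand-ins so the sketch runs on its own; swap in the roformer and your real dataloader
model = torch.nn.Linear(16, 1)
loader = [(torch.randn(4, 16), torch.randn(4, 1)) for _ in range(32)]
loss_fn = torch.nn.MSELoss()

accum_steps = 8  # effective batch size = per-step batch size * accum_steps
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

optimizer.zero_grad()
for step, (x, y) in enumerate(loader):
    loss = loss_fn(model(x), y)
    (loss / accum_steps).backward()  # scale so the accumulated gradient averages over the virtual batch
    if (step + 1) % accum_steps == 0:
        optimizer.step()
        optimizer.zero_grad()

gradient checkpointing (torch.utils.checkpoint) is the other easy lever: it recomputes activations in the backward pass instead of storing them, trading compute for memory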

ZFTurbo commented 11 months ago

I use gradient accumulation now. But in the paper the authors report that the model fits in 32 GB with a batch size of 6.

lucidrains commented 11 months ago

how many GPUs do you have?

lucidrains commented 11 months ago

ah, well I can't help you with that, but this is a common issue and you ought to be able to figure it out with the resources online

lucidrains commented 11 months ago

the biggest gun I can bring out would be reversible networks, which could work for this architecture. maybe at a later date. for now, maybe accumulate gradients and wait it out?

lucidrains commented 11 months ago

are you using mixed precision with flash attention turned on?

ZFTurbo commented 11 months ago

I use mixed precision, but I'm not sure about Flash Attention.

ZFTurbo commented 11 months ago

> how many GPUs do you have?

I usually train on 1 GPU )

lucidrains commented 11 months ago

definitely turn on flash attention if you have the right hardware

it's the flash_attn flag on init
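
e.g. (a rough sketch following the readme-style usage, where passing a target makes the model return the loss; the exact kwargs are whatever you already feed it from your config):

import torch
from bs_roformer import MelBandRoformer

model = MelBandRoformer(
    dim = 192,
    depth = 8,
    time_transformer_depth = 1,
    freq_transformer_depth = 1,
    flash_attn = True  # the flag in question; picks the flash / memory-efficient sdpa kernels when available
)

x = torch.randn(2, 352800)       # (batch, time), mono here just to keep the sketch small
target = torch.randn(2, 352800)

loss = model(x, target = target)
loss.backward()

keep the forward / backward under your AMP autocast as you already do - the flash kernels only kick in for fp16 / bf16 inputs on supported GPUs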

turian commented 10 months ago

Wow! Reversible networks would be cool

@ZFTurbo Are we confident that MelRoFormer was trained on stereo stems? (If not, i.e. if they trained on mono stems, that would halve their effective batch size.) Perhaps they didn't, because: a) Bandsplit-RNN didn't either, as far as I know, and b) training individual models per stem suggests that, like Bandsplit-RNN, they didn't add the multi-stem/multi-channel functionality that @lucidrains did.

HTDemucs and its variants, on the other hand, trained all stems simultaneously and then fine-tuned on individual stems. This leads to models 4x the size, but I am a little surprised the authors didn't include this easy win, given how many GPUs they used for training.

deyituo commented 10 months ago

@ZFTurbo how do you collect the songs?

carlosalberto2000 commented 9 months ago

@ZFTurbo Do you intend to release a trained model? Looking forward to testing this approach against Demucs, but I don't have the means or knowledge to train a model.

ZFTurbo commented 9 months ago

> @ZFTurbo Do you intend to release a trained model? Looking forward to testing this approach against Demucs, but I don't have the means or knowledge to train a model.

I posted some weights here: https://github.com/ZFTurbo/Music-Source-Separation-Training/tree/main?tab=readme-ov-file#vocal-models

But the SDR metric is not that great.

carlosalberto2000 commented 9 months ago

Thank you a lot! I really appreciate your work. So the crown currently still goes to "MDX23C for vocals + HTDemucs4 FT", right? How do you think they achieved such a high score in MDX'23?

jarredou commented 9 months ago

> How do you think they achieved such a high score in MDX'23?

Only BS-Roformer was evaluated during MDX23 by ByteDance-SAMI; Mel-Roformer came after the contest.

carlosalberto2000 commented 9 months ago

> How do you think they achieved such a high score in MDX'23?
>
> Only BS-Roformer was evaluated during MDX23 by ByteDance-SAMI; Mel-Roformer came after the contest.

Oh, I see. I got the names mixed up. Are you aware of any trained BS-RoFormer model?