lucidrains / BS-RoFormer

Implementation of Band Split Roformer, the SOTA attention network for music source separation, from ByteDance AI Labs

MelBand doesn't work with stereo #10

Closed. ZFTurbo closed this issue 10 months ago.

ZFTurbo commented 10 months ago

Hello. I tried to start a training run with:

from bs_roformer import MelBandRoformer

model = MelBandRoformer(
    stereo=True,
    dim=32,
    depth=1,
    attn_dropout=0.1,
    time_transformer_depth=1,
    freq_transformer_depth=1,
)

But I got an error:

RuntimeError: split_with_sizes expects split_sizes to sum exactly to 3964 (input tensor's size at dimension -1), but got split_sizes=[24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 28, 32, 32, 32, 36, 40, 44, 44, 44, 52, 56, 60, 64, 64, 68, 76, 80, 84, 92, 100, 104, 112, 120, 124, 136, 148, 156, 164, 176, 188, 200, 216, 228, 244, 264, 280, 296, 316, 340, 364, 388, 412, 440, 472, 504]

Without stereo=True it works normally.
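
For reference, torch.split (split_with_sizes under the hood) requires the sizes to sum exactly to the length of the dimension being split. If I'm adding them up correctly, the band sizes in the trace sum to 7928, exactly twice 3964, so the per-band sizes and the packed spectrogram seem to disagree by the channel count. A minimal illustration of the constraint (a standalone sketch with made-up shapes, not the repo's code):

import torch

spec = torch.randn(1, 10, 3964)      # hypothetical tensor with a mono-sized last dimension
ok_sizes = [1982, 1982]              # sums to 3964 -> splits fine
bad_sizes = [3964, 3964]             # sums to 7928 -> same RuntimeError as above

spec.split(ok_sizes, dim=-1)         # works
try:
    spec.split(bad_sizes, dim=-1)
except RuntimeError as e:
    print(e)                         # split_with_sizes expects split_sizes to sum exactly to 3964 ...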

lucidrains commented 10 months ago

@ZFTurbo ah yeah, I forgot you need stereo for your Kaggle competitions

try 0.2.2?

ZFTurbo commented 10 months ago

Still have a problem:

  File "\bs_roformer\mel_band_roformer.py", line 445, in forward
    masks_summed = torch.zeros_like(stft_repr).scatter_add_(2, scatter_indices, masks)
RuntimeError: scatter(): Expected self.dtype to be equal to src.dtype

lucidrains commented 10 months ago

@ZFTurbo

this runs for me

import torch
from bs_roformer import MelBandRoformer

model = MelBandRoformer(
    dim = 32,
    depth = 1,
    stereo = True,
    time_transformer_depth = 1,
    freq_transformer_depth = 1
)

x = torch.randn(2, 2, 352800)        # (batch, stereo channels, samples)
target = torch.randn(2, 2, 352800)

loss = model(x, target = target)
loss.backward()

# after much training

out = model(x)

line 445 in the repository isn't the same as the line in your error trace

did you modify the file?

ZFTurbo commented 10 months ago

Sorry, I copy-pasted the code from the browser and it changed the line numbers for some reason. Now I copied the file directly, and I still get the error.

File "H:\Projects_2TB_Next\2018_02_Music_Voice_Separator\mdx_code_23_training1\nn41_mel_roformer_v1\bs_roformer\mel_band_roformer.py", line 430, in forward
    masks_summed = torch.zeros_like(stft_repr).scatter_add_(2, scatter_indices, masks)
RuntimeError: scatter(): Expected self.dtype to be equal to src.dtype

I added a print right before this line:

print(stft_repr.dtype, scatter_indices.dtype, masks.dtype)

I got the result:

torch.complex64 torch.int64 torch.complex32

I think that's the reason.
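
The mismatch is easy to reproduce outside the model, since scatter_add_ requires self and src to share a dtype. A standalone sketch (not the repo's code; shapes are invented, and the complex32 tensor is built the way the trace suggests, by viewing half-precision reals as complex):

import torch

stft_repr = torch.randn(2, 10, dtype=torch.complex64)         # stand-in for the packed spectrogram
masks = torch.view_as_complex(torch.randn(2, 10, 2).half())    # half-precision reals view as complex32
indices = torch.arange(10).repeat(2, 1)

try:
    torch.zeros_like(stft_repr).scatter_add_(1, indices, masks)
except RuntimeError as e:
    print(e)   # scatter(): Expected self.dtype to be equal to src.dtype

# upcasting the source to match the destination avoids the error
torch.zeros_like(stft_repr).scatter_add_(1, indices, masks.type(torch.complex64))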

lucidrains commented 10 months ago

@ZFTurbo that's strange, because mine shows both stft_repr and masks as complex64

what version of pytorch are you on?

would you like to try 0.2.3?

ZFTurbo commented 10 months ago

I'm using torch 2.0.1. I fixed the problem by changing:

masks = torch.view_as_complex(masks)

to:

masks = torch.view_as_complex(masks).type(torch.complex64)
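
In case it helps anyone hitting this later: torch.view_as_complex keeps the precision of its input, so half-precision (float16) reals come back as complex32 while the STFT buffer is complex64, and the explicit .type(torch.complex64) above is what reconciles the two. A tiny sketch of that behaviour (my own illustration; complex32 support in PyTorch is still partial):

import torch

reals = torch.randn(4, 2)                          # last dim of size 2 -> viewable as complex
print(torch.view_as_complex(reals).dtype)          # torch.complex64
print(torch.view_as_complex(reals.half()).dtype)   # torch.complex32

# upcast before mixing with complex64 tensors such as the STFT
print(torch.view_as_complex(reals.half()).type(torch.complex64).dtype)   # torch.complex64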

Thank you

lucidrains commented 10 months ago

@ZFTurbo yup, that's what I did here

ok, it should be good!