lucidrains / BS-RoFormer

Implementation of Band Split Roformer, SOTA Attention network for music source separation out of ByteDance AI Labs
MIT License

Input data shape #3

ZFTurbo closed this issue 11 months ago

ZFTurbo commented 11 months ago

Hello, there is a problem with the input data shape. Your example uses:

x = torch.randn(2, 131680)
target = torch.randn(2, 131680)

loss = model(x, target = target)
loss.backward()

# after much training

out = model(x)

But in practice there must also be a batch dimension:

x = torch.randn(10, 2, 131680)
target = torch.randn(10, 2, 131680)

loss = model(x, target = target)
loss.backward()

# after much training

out = model(x)

If I shape the input like in the second example, stft fails:

RuntimeError: stft(torch.cuda.FloatTensor[6, 2, 133728], n_fft=2048, hop_length=512, win_length=2048, window=None, normalized=0, onesided=None, return_complex=1) : expected a 1D or 2D tensor
lucidrains commented 11 months ago

@ZFTurbo i'm using 2 as the batch size, with an implicit channel dimension of 1
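In other words, the (2, 131680) example is read as a batch of two mono clips. A minimal sketch of that convention, assuming the channel axis sits between batch and time; `ensure_channel_dim` is a hypothetical helper, not part of bs_roformer:

```python
import torch

def ensure_channel_dim(audio: torch.Tensor) -> torch.Tensor:
    """Return audio shaped (batch, channels, time).

    Hypothetical helper: a 2D tensor is treated as (batch, time) mono audio
    and gets an explicit channel axis of size 1; a 3D tensor is assumed to
    already be (batch, channels, time).
    """
    if audio.ndim == 2:
        return audio.unsqueeze(1)  # (b, t) -> (b, 1, t)
    if audio.ndim == 3:
        return audio
    raise ValueError(f"expected 2D or 3D audio, got shape {tuple(audio.shape)}")

mono = torch.randn(2, 131680)        # 2 mono clips, channel dim implicit
stereo = torch.randn(10, 2, 131680)  # 10 stereo clips
print(ensure_channel_dim(mono).shape)    # torch.Size([2, 1, 131680])
print(ensure_channel_dim(stereo).shape)  # torch.Size([10, 2, 131680])
```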

ZFTurbo commented 11 months ago

Is it possible to rework it for stereo input?

lucidrains commented 11 months ago

@ZFTurbo ah i see

in stereo, the channel dimension would be 2?

ZFTurbo commented 11 months ago

Yes )

lucidrains commented 11 months ago

ok, so i just do stft on each channel and concat? sorry, but i'm actually not a domain expert (besides the transformer part)

ZFTurbo commented 11 months ago

I think here is a simple example of how it can be done: https://github.com/kuielab/sdx23/blob/mdx_AB/my_submission/src/tfc_tdf_v3.py
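One way to get a per-channel STFT past the "expected a 1D or 2D tensor" check is to fold the channel axis into the batch axis before calling `torch.stft`, then unfold it afterwards, roughly the trick used in the linked file. A minimal sketch, assuming a Hann window (the error message shows `window=None`) and an output layout that stacks real and imaginary parts per channel; `n_fft=2048` and `hop_length=512` match the error message, and the helper name is hypothetical:

```python
import torch

def multichannel_stft(audio, n_fft=2048, hop_length=512):
    """STFT of (batch, channels, time) audio, one transform per channel.

    torch.stft only accepts 1D or 2D input, so channels are folded into the
    batch axis for the call and unfolded afterwards. Returns a real tensor
    of shape (batch, channels * 2, freqs, frames); for each channel the two
    slots hold the real and imaginary parts.
    """
    b, c, t = audio.shape
    window = torch.hann_window(n_fft, device=audio.device)
    spec = torch.stft(
        audio.reshape(b * c, t),  # each channel becomes its own row
        n_fft=n_fft,
        hop_length=hop_length,
        window=window,
        return_complex=True,
    )                                    # (b * c, freqs, frames), complex
    spec = torch.view_as_real(spec)      # (b * c, freqs, frames, 2)
    spec = spec.permute(0, 3, 1, 2)      # (b * c, 2, freqs, frames)
    f, n = spec.shape[-2:]
    return spec.reshape(b, c * 2, f, n)  # concatenate channels and re/im parts

x = torch.randn(10, 2, 131680)
spec = multichannel_stft(x)
print(spec.shape)  # torch.Size([10, 4, 1025, 258])
```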

lucidrains commented 11 months ago

@ZFTurbo ahh got it, yea, like i guessed

yup, i can do this, and i guess the multi-stft loss can also be computed per channel and just summed
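A sketch of what "per channel and just summed" could look like for a multi-resolution STFT loss: an L1 distance between complex spectrograms at several window sizes, computed for each channel separately and added up. The window sizes, `hop = n_fft // 4`, the Hann window, and the function name are illustrative assumptions, not the library's actual settings:

```python
import torch
import torch.nn.functional as F

def multi_resolution_stft_loss(pred, target, n_ffts=(2048, 1024, 512)):
    """Sum of L1 distances between complex STFTs at several resolutions.

    pred and target are (batch, channels, time). Each channel is compared
    separately and the per-channel, per-resolution losses are summed.
    """
    total = pred.new_zeros(())
    for ch in range(pred.shape[1]):
        for n_fft in n_ffts:
            window = torch.hann_window(n_fft, device=pred.device)
            stft_kwargs = dict(n_fft=n_fft, hop_length=n_fft // 4,
                               window=window, return_complex=True)
            # view_as_real turns complex (b, f, n) into real (b, f, n, 2)
            pred_spec = torch.view_as_real(torch.stft(pred[:, ch], **stft_kwargs))
            target_spec = torch.view_as_real(torch.stft(target[:, ch], **stft_kwargs))
            total = total + F.l1_loss(pred_spec, target_spec)
    return total

pred = torch.randn(2, 2, 16000, requires_grad=True)
target = torch.randn(2, 2, 16000)
loss = multi_resolution_stft_loss(pred, target)
loss.backward()
```

Because the channel loop is outermost, the stereo loss is exactly the sum of two mono losses, which matches the "just summed" idea.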

lucidrains commented 11 months ago

tada

try 0.1.0

import torch
from bs_roformer import BSRoformer

model = BSRoformer(
    dim = 512,
    depth = 12,
    stereo = True,
    time_transformer_depth = 1,
    freq_transformer_depth = 1,
    freqs_per_bands = (512, 513)  # in the paper, they divide into ~60 bands, test with 1 for starters
)

x = torch.randn(2, 2, 131680)
target = torch.randn(2, 2, 131680)

loss = model(x, target = target)
loss.backward()

# after much training

out = model(x)
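For reference, a onesided STFT with n_fft=2048 (the value in the error message above) has n_fft // 2 + 1 = 1025 frequency bins, so `freqs_per_bands = (512, 513)` covers the whole spectrum with two bands. A sketch of splitting a spectrogram into such bands, assuming the band sizes must add up to the bin count; the helper name is hypothetical and this is not the library's own band-split code:

```python
import torch

def split_into_bands(spec, freqs_per_bands):
    """Split a (batch, freqs, frames) spectrogram along the frequency axis.

    Assumes the band sizes cover every frequency bin exactly once.
    """
    n_freqs = spec.shape[1]
    if sum(freqs_per_bands) != n_freqs:
        raise ValueError(
            f"bands cover {sum(freqs_per_bands)} bins, spectrogram has {n_freqs}")
    return torch.split(spec, list(freqs_per_bands), dim=1)

n_fft = 2048
spec = torch.stft(torch.randn(2, 131680), n_fft=n_fft, hop_length=512,
                  window=torch.hann_window(n_fft), return_complex=True)
bands = split_into_bands(spec, (512, 513))
print([tuple(b.shape) for b in bands])  # [(2, 512, 258), (2, 513, 258)]
```

With ~60 bands, as in the paper, only the tuple changes; the sum check stays the same.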
lucidrains commented 11 months ago

ok, calling it a day, happy friday!

ZFTurbo commented 11 months ago

Thank you. It works for me now.

One more proposal: for the music source separation (MSS) task you sometimes need more than one stem in the output.

so the target can have a shape like (batch size, number of stems, 2, length):

target = torch.randn(2, 4, 2, 131680)

Maybe it's a good idea to add support for, e.g., 4-stem separation.
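A sketch of how such a multi-stem target could be assembled: stack the individual stems on a new axis after the batch axis. It assumes, as is usual for MSS datasets but not stated in this thread, that the model input is the mixture, i.e. the sum of the stems. The stem names are only an example and the audio is random:

```python
import torch

batch, n_stems, channels, length = 2, 4, 2, 131680

# e.g. vocals, drums, bass, other; random stand-ins for real stereo stems
stems = [torch.randn(batch, channels, length) for _ in range(n_stems)]

target = torch.stack(stems, dim=1)  # (batch, n_stems, channels, length)
mixture = target.sum(dim=1)         # (batch, channels, length): the model input

print(target.shape)   # torch.Size([2, 4, 2, 131680])
print(mixture.shape)  # torch.Size([2, 2, 131680])
```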

lucidrains commented 11 months ago

ok, very last commit for the day

import torch
from bs_roformer import BSRoformer

model = BSRoformer(
    dim = 512,
    depth = 12,
    stereo = True,
    num_stems = 2,
    time_transformer_depth = 1,
    freq_transformer_depth = 1,
    freqs_per_bands = (512, 513)  # in the paper, they divide into ~60 bands, test with 1 for starters
)

x = torch.randn(2, 2, 131680)
target = torch.randn(2, 2, 2, 131680)

loss = model(x, target = target)
loss.backward()

# after much training

out = model(x)
lucidrains commented 11 months ago

if that doesn't look right, i can address it next Monday

ZFTurbo commented 11 months ago

Thank you I will test it!

lucidrains commented 11 months ago

@ZFTurbo if stems are just different targets, you could also just train 4 separate models (in your example), or am i misunderstanding?

ZFTurbo commented 11 months ago

Yes, you are right. But sometimes it's better to have everything in a single model.

lucidrains commented 11 months ago

ohh got it, so 'stems' is just another synonym for adapter heads

yea, this should work then
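A toy sketch of the "shared trunk, one adapter head per stem" idea in plain PyTorch. The class name, the linear layer standing in for the shared transformer, and all sizes are hypothetical; this is not BS-RoFormer's actual code:

```python
import torch
from torch import nn

class MultiStemHeads(nn.Module):
    """Hypothetical sketch: one shared backbone, one small head per stem."""

    def __init__(self, dim, out_dim, num_stems):
        super().__init__()
        # stand-in for the shared (e.g. transformer) trunk
        self.backbone = nn.Sequential(nn.Linear(dim, dim), nn.GELU())
        # independent parameters for each stem's output head
        self.heads = nn.ModuleList([nn.Linear(dim, out_dim) for _ in range(num_stems)])

    def forward(self, x):
        # x: (batch, seq, dim) -> (batch, num_stems, seq, out_dim)
        h = self.backbone(x)  # computed once, shared by all stems
        return torch.stack([head(h) for head in self.heads], dim=1)

model = MultiStemHeads(dim=64, out_dim=32, num_stems=4)
out = model(torch.randn(2, 10, 64))
print(out.shape)  # torch.Size([2, 4, 10, 32])
```

Compared with four separate models, the trunk runs once per input and only the heads are duplicated, which is the trade-off discussed above.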

lucidrains commented 11 months ago

i guess you could also have a separate axial transformer block per stem head, but can leave that for another day

lucidrains commented 11 months ago

main issue has been addressed