bytedance / music_source_separation


stft and istft are placed outside forward #68

Open Blakey-Gavin opened 7 months ago


Hello, I made a few modifications, mainly moving `stft` and `istft` out of the model's `forward` function. Below is my modified `training_step` function in `lightning_modules.py`:

```python
from typing import Tuple

import torch


def training_step(self, batch_data: Tuple, batch_idx: int) -> torch.Tensor:
    input_data, target_data = self.batch_data_preprocessor(batch_data)
    batch_size, channels_num, segment_samples = input_data.shape

    # Fold channels into the batch dimension: (batch * channels, samples).
    input = input_data.reshape(batch_size * channels_num, segment_samples)

    # Create the window on the same device as the input instead of calling .cuda().
    window = torch.hann_window(self.window_length, device=input.device)

    # STFT is now computed here, outside the model's forward().
    input_spec = torch.stft(input, self.window_length, self.hop_length,
                            window=window, onesided=self.onesided)

    output_spec = self.model(input_spec)

    # Inverse STFT back to waveforms of the original segment length.
    output = torch.istft(output_spec, self.window_length, self.hop_length,
                         window=window, onesided=self.onesided,
                         length=segment_samples)

    output = output.reshape(batch_size, channels_num, segment_samples)

    # Calculate loss.
    loss = self.loss_function(
        output=output,
        target=target_data,
        mixture=input_data,
    )
    return loss
```
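One way to narrow down whether the transforms themselves are at fault is to run the same STFT → model → ISTFT path with the model replaced by an identity function and check that the input comes back unchanged. The sketch below is hypothetical: the helper name `stft_istft_roundtrip` and the `window_length`/`hop_length` values are illustrative rather than the repo's configuration, and it assumes a recent PyTorch, where `torch.stft` on real input requires `return_complex=True` and `torch.istft` accepts the resulting complex spectrogram.

```python
import torch


def stft_istft_roundtrip(
    waveforms: torch.Tensor,
    window_length: int = 2048,  # illustrative value, not the repo's setting
    hop_length: int = 441,      # illustrative value, not the repo's setting
) -> torch.Tensor:
    """Run (batch, channels, samples) audio through STFT -> identity -> ISTFT."""
    batch_size, channels_num, segment_samples = waveforms.shape
    flat = waveforms.reshape(batch_size * channels_num, segment_samples)
    window = torch.hann_window(window_length, device=waveforms.device)

    spec = torch.stft(flat, n_fft=window_length, hop_length=hop_length,
                      window=window, center=True, return_complex=True)

    # Identity "model": if reconstruction fails here, the transforms are the problem.
    output_spec = spec

    output = torch.istft(output_spec, n_fft=window_length, hop_length=hop_length,
                         window=window, center=True, length=segment_samples)
    return output.reshape(batch_size, channels_num, segment_samples)


torch.manual_seed(0)
x = torch.randn(2, 2, 44100)
y = stft_istft_roundtrip(x)
print(torch.max(torch.abs(x - y)).item())  # should be close to zero
```

If this check passes on your setup, the transform pair is probably not the cause, and the next place to look would be what the model returns, for example whether `output_spec` is still in the format `torch.istft` expects.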

I have made the corresponding changes in `separator.py` and the other scripts, but after training, every SDR value reported during evaluation is 0:

(screenshot of the evaluation output not preserved)
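One observation that may help: under the common energy-ratio definition, SDR = 10·log10(‖s‖² / ‖s − ŝ‖²), an all-zero estimate ŝ scores exactly 0 dB, because the error energy then equals the reference energy. SDR values that are all exactly 0 could therefore indicate that the separated output is silent. The sketch below assumes this simple definition; the helper `energy_ratio_sdr` and its `eps` clipping are hypothetical, and the repo's own evaluation code may compute SDR differently.

```python
import numpy as np


def energy_ratio_sdr(reference: np.ndarray, estimate: np.ndarray,
                     eps: float = 1e-10) -> float:
    """SDR in dB: 10 * log10(||s||^2 / ||s - s_hat||^2), energies clipped at eps."""
    signal_energy = np.sum(reference ** 2)
    error_energy = np.sum((estimate - reference) ** 2)
    return float(10.0 * np.log10(max(signal_energy, eps) / max(error_energy, eps)))


rng = np.random.default_rng(0)
reference = rng.standard_normal((2, 44100))

print(energy_ratio_sdr(reference, np.zeros_like(reference)))  # 0.0 for a silent estimate
print(energy_ratio_sdr(reference, 0.9 * reference))           # roughly 20 dB
```

With the `eps` clipping, a silent reference paired with a silent estimate also gives 0 dB, so it may be worth inspecting both the model outputs and the evaluation targets.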

I have searched for a long time but can't find the cause. Any insight would be appreciated. Thank you.