facebookresearch / denoiser

Real Time Speech Enhancement in the Waveform Domain (Interspeech 2020)

We provide a PyTorch implementation of the paper Real Time Speech Enhancement in the Waveform Domain, in which we present a causal speech enhancement model that works on the raw waveform and runs in real time on a laptop CPU. The proposed model is based on an encoder-decoder architecture with skip connections. It is optimized in both the time and frequency domains, using multiple loss functions. Empirical evidence shows that it can remove various kinds of background noise, including stationary and non-stationary noise, as well as room reverb. Additionally, we suggest a set of data augmentation techniques applied directly to the raw waveform, which further improve model performance and generalization.

STFT Loss device issues #27

Open IdoWSC opened 3 years ago

IdoWSC commented 3 years ago

Hi, when fine-tuning on a GPU machine with the STFT loss set to true in the config file, I get an error:

    solver.train()
  File "/home/wscuser/denoiser/denoiser/solver.py", line 143, in train
    train_loss = self._run_one_epoch(epoch)
  File "/home/wscuser/denoiser/denoiser/solver.py", line 50, in _run_one_epoch
    sc_loss, mag_loss = self.mrstftloss(estimate.squeeze(1), clean.squeeze(1))
  File "/anaconda/lib/python3.7/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/wscuser/denoiser/denoiser/stft_loss.py", line 138, in forward
    sc_l, mag_l = f(x, y)
  File "/anaconda/lib/python3.7/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/wscuser/denoiser/denoiser/stft_loss.py", line 94, in forward
    x_mag = stft(x, self.fft_size, self.shift_size, self.win_length, self.window)
  File "/home/wscuser/denoiser/denoiser/stft_loss.py", line 28, in stft
    x_stft = torch.stft(x, fft_size, hop_size, win_length, window)
  File "/anaconda/lib/python3.7/site-packages/torch/functional.py", line 516, in stft
    normalized, onesided, return_complex)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

Any idea why this is happening?

Thanks in advance!

adiyoss commented 3 years ago

It seems that not all tensors are on the same device: some are on the GPU and some are on the CPU. Can you please post the command you used to launch the experiment so we can reproduce the error? Also, can you please make sure the target supervision (the clean signal) is copied to the GPU as well?
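A quick way to carry out this check is to print the device of every tensor that reaches the loss. The snippet below is only a sketch: `report_devices` is a hypothetical helper, and `estimate`, `clean` and `window` are random stand-ins for the model output, the clean target and the STFT window, not the variables from solver.py. It falls back to the CPU when no GPU is available.

```python
import torch


def report_devices(**tensors):
    """Return a {name: device string} mapping for the given tensors."""
    return {name: str(t.device) for name, t in tensors.items()}


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hypothetical stand-ins for the tensors involved in the STFT loss.
estimate = torch.randn(2, 1, 16000, device=device)  # model output
clean = torch.randn(2, 1, 16000)                    # created on the CPU by default
window = torch.hann_window(400)                     # also created on the CPU by default

# On a GPU machine this shows the mismatch: estimate on cuda, the rest on cpu.
print(report_devices(estimate=estimate, clean=clean, window=window))

# Copy everything to the device the model output lives on.
clean = clean.to(estimate.device)
window = window.to(estimate.device)
devices = report_devices(estimate=estimate, clean=clean, window=window)
print(devices)
```

If the second printout still lists more than one device, the tensor that differs is the one to move.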

chadHGY commented 3 years ago

Hi @IdoWSC, I also encountered a similar problem. The main cause of this error is that "x" and "window" must be tensors on the same device. There are two possible solutions:

  1. Check the PyTorch version (suggested by the author?).
  2. Or directly modify the stft function in your PyTorch installation (possibly in "/anaconda/lib/python3.7/site-packages/torch/functional.py"):

     return _VF.stft(input, n_fft, hop_length, win_length, window.to(input.device), normalized, onesided, return_complex)

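The same cast from the second option can probably live in the project's own code instead of in site-packages, which avoids editing the installed library. Below is a minimal sketch of that idea. `stft_magnitude` is a hypothetical helper loosely modelled on the `stft` call visible in the traceback, not the repository's actual function, and it assumes a PyTorch version in which `torch.stft` accepts `return_complex=True`.

```python
import torch


def stft_magnitude(x, fft_size, hop_size, win_length, window):
    """Magnitude spectrogram of x, shaped (batch, frames, fft_size // 2 + 1)."""
    # Move the window to wherever the signal lives; a no-op if they already match.
    window = window.to(x.device)
    spec = torch.stft(x, fft_size, hop_size, win_length, window, return_complex=True)
    # Clamp before the square root so gradients stay finite on silent frames.
    power = spec.real ** 2 + spec.imag ** 2
    return torch.sqrt(torch.clamp(power, min=1e-7)).transpose(2, 1)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
x = torch.randn(2, 16000, device=device)
window = torch.hann_window(600)  # deliberately left on the CPU
mag = stft_magnitude(x, fft_size=1024, hop_size=120, win_length=600, window=window)
print(mag.shape, mag.device)
```

Because `.to()` returns the tensor unchanged when it is already on the target device, the extra call should cost very little when everything is set up correctly.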
adefossez commented 3 years ago

Hey @IdoWSC and @chadHGY , can you try again after pulling from master? This should be fixed now.

lhy0807 commented 3 years ago

> Hey @IdoWSC and @chadHGY , can you try again after pulling from master? This should be fixed now.

Hi Alexandre, I'm trying to reproduce the results and am facing the same error. My environment is Torch 1.7.1 + CUDA 11.0. I tried the method suggested by @chadHGY, changing the call to

return _VF.stft(input, n_fft, hop_length, win_length, window.to(input.device), normalized, onesided, return_complex)

and the error was eliminated.

ferugit commented 3 years ago

As indicated previously, the problem is solved by making the following change in the forward method of the STFTLoss class:

x_mag = stft(x, self.fft_size, self.shift_size, self.win_length, self.window.to(x.device))
y_mag = stft(y, self.fft_size, self.shift_size, self.win_length, self.window.to(x.device))
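An alternative to casting on every forward pass is to register the window as a module buffer, so that it follows the module whenever `.to(device)` or `.cuda()` is called. The sketch below illustrates the pattern with a hypothetical `WindowedMagnitude` module. It is not the repository's STFTLoss class and not necessarily how the fix on master was done, and the default sizes are placeholders.

```python
import torch
from torch import nn


class WindowedMagnitude(nn.Module):
    """Hypothetical minimal module; not the repository's STFTLoss."""

    def __init__(self, fft_size=1024, shift_size=120, win_length=600):
        super().__init__()
        self.fft_size = fft_size
        self.shift_size = shift_size
        self.win_length = win_length
        # A buffer is not a trainable parameter, but .to(), .cuda() and .cpu() move it.
        self.register_buffer("window", torch.hann_window(win_length))

    def forward(self, x):
        spec = torch.stft(x, self.fft_size, self.shift_size, self.win_length,
                          self.window, return_complex=True)
        return spec.abs()


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
module = WindowedMagnitude().to(device)  # the window moves together with the module
x = torch.randn(2, 16000, device=device)
mag = module(x)
print(module.window.device, mag.shape)
```

With this pattern the caller only has to move the loss module to the training device once, alongside the model.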