csteinmetz1 / auraloss

Collection of audio-focused loss functions in PyTorch
Apache License 2.0

CUDA not working? #17

Closed turian closed 3 years ago

turian commented 3 years ago

I'm generating random MRSTFT (MultiResolutionSTFTLoss) instances.

My tensors are on CUDA and I call model.cuda(). However, I get the error shown below.

These are the params I passed to this MRSTFT:

{'fft_sizes': [512, 64, 16, 256, 16384],
 'hop_sizes': [316, 17, 10, 135, 12809],
 'scale': None,
 'scale_invariance': False,
 'w_mag': 1.762851192526309,
 'w_phs': 0.4449890262755162,
 'w_sc': 0.0,
 'win_lengths': [256, 64, 4, 256, 16384],
 'window': 'hamming_window'}
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-59-8179a07b1db0> in <module>()
      1 for x1 in x:
      2   model.cuda()
----> 3   z = model(x1.view(1, 1, -1), x)

3 frames
/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    887             result = self._slow_forward(*input, **kwargs)
    888         else:
--> 889             result = self.forward(*input, **kwargs)
    890         for hook in itertools.chain(
    891                 _global_forward_hooks.values(),

/usr/local/lib/python3.7/dist-packages/auraloss/freq.py in forward(self, x, y)
    260         mrstft_loss = 0.0
    261         for f in self.stft_losses:
--> 262             mrstft_loss += f(x, y)
    263         mrstft_loss /= len(self.stft_losses)
    264         return mrstft_loss

/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    887             result = self._slow_forward(*input, **kwargs)
    888         else:
--> 889             result = self.forward(*input, **kwargs)
    890         for hook in itertools.chain(
    891                 _global_forward_hooks.values(),

/usr/local/lib/python3.7/dist-packages/auraloss/freq.py in forward(self, x, y)
    141         # apply relevant transforms
    142         if self.scale is not None:
--> 143             x_mag = torch.matmul(self.fb, x_mag)
    144             y_mag = torch.matmul(self.fb, y_mag)
    145 

RuntimeError: Tensor for 'out' is on CPU, Tensor for argument #1 'self' is on CPU, but expected them to be on GPU (while checking arguments for baddbmm)

csteinmetz1 commented 3 years ago

Hey @turian, I think I know what the issue is. Computing the loss on GPU should work. It looks like this happens because the Mel filterbanks (self.fb) are not on the GPU when using a MelSTFTLoss. This is a known issue, and the current workaround is to manually move the filterbanks to the correct device beforehand. It will be fixed in the next auraloss release.
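For context, here is a minimal sketch of the general PyTorch behavior that would produce this kind of error. The two toy modules are hypothetical and are not auraloss code; the assumption is only that a tensor stored as a plain attribute behaves like self.fb does here. Calls such as .cuda(), .to(device) and .double() only touch parameters and registered buffers, so the sketch uses .double() to show the effect without needing a GPU.

```python
import torch


class PlainAttr(torch.nn.Module):
    """Toy module (hypothetical) that stores a tensor as a plain attribute."""

    def __init__(self):
        super().__init__()
        self.fb = torch.ones(2, 3)  # not tracked by .to() / .cuda() / .double()


class BufferAttr(torch.nn.Module):
    """Same tensor, but registered as a buffer so the module tracks it."""

    def __init__(self):
        super().__init__()
        self.register_buffer("fb", torch.ones(2, 3))


# .double() is applied to parameters and buffers in the same way as .cuda(),
# so it illustrates the difference on a CPU-only machine.
plain = PlainAttr().double()
buffered = BufferAttr().double()

print(plain.fb.dtype)     # the plain attribute is left untouched
print(buffered.fb.dtype)  # the buffer is converted along with the module
```

Registering the filterbank as a buffer is one common way to make a tensor follow its module across devices. Whether the eventual auraloss fix takes this route is not stated here.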

I have created an example that runs on both CPU and GPU for me. It demonstrates the problem and the current workaround.

It seems to me that the parameters shown in your example are not the ones causing the issue, since it only appears for me when I create a loss with scale="mel". In this example I create one normal STFT loss as well as one that uses Mel scaling. If I do not manually move the self.fb tensors to the correct device, I get the same error as you. To address this, the code below loops over the STFTLoss objects and moves the filterbanks to the correct device before computing the loss. This works for me and produces the same results as on CPU.

import torch
import auraloss

params = {'fft_sizes': [1024],
 'hop_sizes': [512],
 'scale': None,
 'scale_invariance': False,
 'n_bins': None,
 'w_mag': 1.762851192526309,
 'w_phs': 0.0,
 'w_sc': 0.0,
 'win_lengths': [1024],
 'window': 'hamming_window',
 'sample_rate': 44100}

melparams = {'fft_sizes': [1024],
 'hop_sizes': [512],
 'scale': "mel",
 'scale_invariance': False,
 'n_bins': 64,
 'w_mag': 1.762851192526309,
 'w_phs': 0.0,
 'w_sc': 0.0,
 'win_lengths': [1024],
 'window': 'hamming_window',
 'sample_rate': 44100}

# standard MRSTFT
mrstft = auraloss.freq.MultiResolutionSTFTLoss(**params)

# use mel STFTs
melmrstft = auraloss.freq.MultiResolutionSTFTLoss(**melparams)

x = (torch.rand(1,1,44100) * 2) - 1
y = (torch.rand(1,1,44100) * 2) - 1

# first compute the loss just on CPU tensors
# both work fine
mrstft_loss = mrstft(x, y)
melmrstft_loss = melmrstft(x, y)
print("cpu mrstft: ", mrstft_loss)
print("cpu melmrstft: ", melmrstft_loss)

# -------- GPU ----------

# move data to GPU
x = x.to("cuda:0")
y = y.to("cuda:0")

# compute loss on GPU
mrstft_loss = mrstft(x, y)
print("gpu mrstft: ", mrstft_loss)

# move MelSTFT loss filterbanks to GPU
# this will be done automatically in future auraloss version
for stft_loss in melmrstft.stft_losses:
    stft_loss.fb = stft_loss.fb.to("cuda:0")

# compute loss on GPU
melmrstft_loss = melmrstft(x, y)
print("gpu melmrstft: ", melmrstft_loss)

The output:

python test.py 
cpu mrstft:  tensor(1.2394)
cpu melmrstft:  tensor(0.5864)
gpu mrstft:  tensor(1.2394, device='cuda:0')
gpu melmrstft:  tensor(0.5864, device='cuda:0')
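
If you build many losses with random parameters, the loop over stft_losses can be wrapped in a small helper. This is only a sketch: move_filterbanks is a hypothetical name, not part of auraloss, and it assumes that a sub-loss only carries a tensor fb attribute when a scale such as "mel" was requested. The stand-in objects at the bottom are there so the sketch runs without auraloss or a GPU.

```python
from types import SimpleNamespace

import torch


def move_filterbanks(loss, device):
    """Move any `fb` tensors found on the sub-losses of `loss` to `device`.

    Hypothetical helper: assumes `loss.stft_losses` is iterable and that
    sub-losses without a filterbank either lack `fb` or have it set to None.
    """
    for stft_loss in loss.stft_losses:
        fb = getattr(stft_loss, "fb", None)
        if torch.is_tensor(fb):
            stft_loss.fb = fb.to(device)
    return loss


# Stand-ins for a multi-resolution loss: one sub-loss with a filterbank,
# one without (as with scale=None).
fake_loss = SimpleNamespace(
    stft_losses=[SimpleNamespace(fb=torch.ones(2, 3)), SimpleNamespace()]
)
move_filterbanks(fake_loss, torch.device("cpu"))
```

With the real objects from the script above, the call would be move_filterbanks(melmrstft, x.device), so the filterbanks always follow the input tensors.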

Let me know if this workaround works for you. Also, note that the w_phs parameter has no effect at the moment, since the phase loss term is not implemented yet. The phase term has some issues; PyTorch 1.8 should add the support it needs, but I have not moved over to the latest version yet.