csteinmetz1 / auraloss

Collection of audio-focused loss functions in PyTorch
Apache License 2.0

MelSTFTLoss does not correctly register filterbanks as buffers #34

Open csteinmetz1 opened 2 years ago

csteinmetz1 commented 2 years ago
# setup mel filterbank
if self.scale == "mel":
    assert sample_rate != None  # Must set sample rate to use mel scale
    assert n_bins <= fft_size  # Must be more FFT bins than Mel bins
    fb = librosa.filters.mel(sample_rate, fft_size, n_mels=n_bins)
    self.fb = torch.tensor(fb).unsqueeze(0)
elif self.scale == "chroma":
    assert sample_rate != None  # Must set sample rate to use chroma scale
    assert n_bins <= fft_size  # Must be more FFT bins than chroma bins
    fb = librosa.filters.chroma(sample_rate, fft_size, n_chroma=n_bins)
    self.fb = torch.tensor(fb).unsqueeze(0)

if scale is not None and device is not None:
    self.fb = self.fb.to(self.device)  # move filterbank to device

This breaks computing the loss on GPU. Because self.fb is assigned as a plain tensor attribute, calling .to(device) or .cuda() on the loss module does not move it, so it stays on the CPU unless a device was passed to the constructor. This is simple to resolve and should only require registering the filterbank as a buffer.

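Something like self.register_buffer("fb", torch.tensor(fb).unsqueeze(0)), since register_buffer expects a tensor rather than the NumPy array that librosa returns.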
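
To illustrate the difference, here is a minimal sketch and not the actual auraloss code. The class names PlainAttr and Buffered are hypothetical, and a random tensor stands in for the librosa filterbank so the example runs without librosa. It uses a dtype cast on CPU to show the same behaviour a device move would have: Module.to() touches registered buffers but ignores plain tensor attributes.

```python
import torch


class PlainAttr(torch.nn.Module):
    """Hypothetical module that stores the filterbank as a plain attribute."""

    def __init__(self, n_bins=8, n_freqs=17):
        super().__init__()
        fb = torch.rand(n_bins, n_freqs)  # stand-in for a librosa filterbank
        self.fb = fb.unsqueeze(0)  # plain attribute: not tracked by the module


class Buffered(torch.nn.Module):
    """Hypothetical module that registers the filterbank as a buffer."""

    def __init__(self, n_bins=8, n_freqs=17):
        super().__init__()
        fb = torch.rand(n_bins, n_freqs)  # stand-in for a librosa filterbank
        # Buffers follow the module through .to(), .cuda(), and state_dict().
        self.register_buffer("fb", fb.unsqueeze(0))

    def forward(self, mag):
        # mag: (batch, n_freqs, frames) -> (batch, n_bins, frames)
        return torch.matmul(self.fb, mag)


# A dtype cast is used here because it works without a GPU;
# .to("cuda") follows the same rule for which tensors get converted.
plain = PlainAttr().to(torch.float64)
buffered = Buffered().to(torch.float64)

print(plain.fb.dtype)     # torch.float32 (left behind)
print(buffered.fb.dtype)  # torch.float64 (moved with the module)
```

With the buffer registered, the explicit device argument and the manual self.fb.to(self.device) call should no longer be needed, since moving the loss module moves the filterbank with it.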