adobe-research / convmelspec

Convmelspec: Convertible Melspectrograms via 1D Convolutions
Apache License 2.0

DFT and Torchaudio mode #11

Closed hbellafkir closed 4 months ago

hbellafkir commented 5 months ago

Switching between torchaudio and DFT mode produces different results. To reproduce:

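import torch
from scipy import signal as sig

from convmelspec.stft import ConvertibleSpectrogram as Spectrogram

FFT_SIZE = 1024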
SR = 16_000
HOP_SIZE = 512
MEL_BANDS = 64

x = torch.rand(1, SR)

wn = sig.windows.hann(FFT_SIZE, sym=True)

stft = Spectrogram(
    sr=SR,
    n_fft=FFT_SIZE,
    hop_size=HOP_SIZE,
    n_mel=None,
    padding=0,
    window=wn,
    spec_mode="DFT",
    dtype=torch.float32,
)

stft_ta = Spectrogram(
    sr=SR,
    n_fft=FFT_SIZE,
    hop_size=HOP_SIZE,
    n_mel=None,
    padding=0,
    window=wn,
    spec_mode="torchaudio",
    dtype=torch.float32,
)

# AssertionError
assert torch.allclose(stft(x), stft_ta(x), atol=1e-2)

Training in Torchaudio mode and exporting in DFT mode is not a good idea in this case. Am I doing something wrong?

urinieto commented 4 months ago

There are a couple of issues in your script:

1) convmelspec produces similar spectrograms, but not identical ones, especially at higher frequencies. If you pass random noise as input (i.e., torch.rand(1, SR)), the differences between the spectrograms will be large. This shouldn't be a problem with most audio signals.

2) If you want to use the compatible torchaudio mode, please use the Torchaudio Spectrogram class directly.
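
To see where two spectrograms disagree, rather than getting a single pass/fail from torch.allclose, a small helper along these lines can report the largest absolute and relative differences and the frequency bin where the mismatch peaks. This is only a sketch: `spectrogram_diff` is a hypothetical helper, not part of convmelspec or torchaudio, and it assumes both inputs share the shape (..., n_freq, n_frames) and have already been converted to NumPy arrays (for example with `.detach().cpu().numpy()`).

```python
import numpy as np


def spectrogram_diff(a, b, eps=1e-12):
    """Summarize how two spectrograms of shape (..., n_freq, n_frames) differ.

    Hypothetical helper for illustration; assumes the frequency axis is
    the second-to-last axis of both inputs.
    """
    a = np.asarray(a, dtype=np.float64)
    b = np.asarray(b, dtype=np.float64)
    if a.shape != b.shape:
        raise ValueError(f"shape mismatch: {a.shape} vs {b.shape}")

    abs_err = np.abs(a - b)
    # eps avoids division by zero in silent bins.
    rel_err = abs_err / (np.abs(b) + eps)

    # Reduce over every axis except the frequency axis (-2).
    other_axes = tuple(i for i in range(a.ndim) if i != a.ndim - 2)
    per_bin = abs_err.max(axis=other_axes)

    return {
        "max_abs": float(abs_err.max()),
        "max_rel": float(rel_err.max()),
        "worst_bin": int(per_bin.argmax()),
        "per_bin_max_abs": per_bin,
    }
```

Keep in mind that torch.allclose checks |a - b| <= atol + rtol * |b| with a default rtol of 1e-5, so a large-magnitude bin can pass while a small-magnitude bin fails; looking at `max_abs` and `max_rel` separately makes that easier to spot.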

Here's an updated version of your script that runs fine on my M3 MacBook Pro (note that the spectrograms are really close, since I set atol=1e-5):

import librosa
import numpy as np
import torch
import torchaudio
from scipy import signal as sig

from convmelspec.stft import ConvertibleSpectrogram as Spectrogram

FFT_SIZE = 1024
SR = 16_000
HOP_SIZE = 512
MEL_BANDS = None

def get_hann_torch(win_size, sym=True):
    wn = sig.windows.hann(win_size, sym=sym).astype(np.float32)
    return torch.from_numpy(wn)

def get_audio():
    example_audio_path = librosa.example("nutcracker")
    y, sr = librosa.load(example_audio_path, sr=SR)
    total_sec = 1
    y = y[int(sr) : (total_sec * sr + sr)].astype(np.float32)
    return torch.from_numpy(y).unsqueeze(0)

def main():

    x = get_audio()
    wn = sig.windows.hann(FFT_SIZE, sym=True)

    stft = Spectrogram(
        sr=SR,
        n_fft=FFT_SIZE,
        hop_size=HOP_SIZE,
        padding=0,
        window=wn,
        spec_mode="DFT",
        dtype=torch.float32,
    )

    stft_ta = torchaudio.transforms.Spectrogram(
        n_fft=FFT_SIZE,
        hop_length=HOP_SIZE,
        window_fn=get_hann_torch,
        power=2.0,
        center=False,
    )

    # Passes: the two spectrograms agree to within atol=1e-5
    assert torch.allclose(stft(x), stft_ta(x), atol=1e-5)
    print("DFT and Torchaudio's Spectrogram all close!")


if __name__ == "__main__":
    main()
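
For intuition on why a DFT-based spectrogram and an FFT-based one should agree up to floating-point error, the sketch below computes a power spectrogram two ways in NumPy: with np.fft.rfft, and with an explicit windowed cosine/sine basis applied by matrix multiplication. It is not convmelspec's actual implementation; the function names are hypothetical, frames are taken without padding (comparable to center=False), and the output layout is (n_frames, n_freq) rather than (n_freq, n_frames).

```python
import numpy as np


def frame_signal(x, n_fft, hop):
    """Slice a 1D signal into overlapping frames of shape (n_frames, n_fft)."""
    n_frames = 1 + (len(x) - n_fft) // hop
    idx = np.arange(n_fft)[None, :] + hop * np.arange(n_frames)[:, None]
    return x[idx]


def power_spec_fft(x, window, hop):
    """Power spectrogram via the FFT, shape (n_frames, n_fft // 2 + 1)."""
    frames = frame_signal(x, len(window), hop) * window
    return np.abs(np.fft.rfft(frames, axis=1)) ** 2


def power_spec_dft_matrix(x, window, hop):
    """Same quantity via an explicit DFT basis and matrix multiplication."""
    n_fft = len(window)
    k = np.arange(n_fft // 2 + 1)[:, None]  # frequency bin index
    n = np.arange(n_fft)[None, :]           # sample index within a frame
    angle = 2.0 * np.pi * k * n / n_fft
    # Fold the analysis window into the basis so each frame needs one matmul.
    cos_basis = np.cos(angle) * window
    sin_basis = np.sin(angle) * window
    frames = frame_signal(x, n_fft, hop)
    real = frames @ cos_basis.T
    imag = frames @ sin_basis.T
    return real ** 2 + imag ** 2


if __name__ == "__main__":
    sr, n_fft, hop = 16_000, 1024, 512
    t = np.arange(sr) / sr
    x = np.sin(2.0 * np.pi * 440.0 * t)  # one second of a 440 Hz tone
    window = np.hanning(n_fft)           # symmetric Hann window
    diff = np.abs(power_spec_fft(x, window, hop) - power_spec_dft_matrix(x, window, hop))
    print("max abs difference:", diff.max())
```

In float64 the two paths differ only by rounding error. In float32 the gap grows with signal magnitude, which is one reason an absolute tolerance alone can be a misleading pass/fail criterion.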