pytorch / audio

Data manipulation and transformation for audio signal processing, powered by PyTorch
https://pytorch.org/audio
BSD 2-Clause "Simplified" License
2.54k stars 652 forks

Parallelization of torchaudio.functional operators #919

Open adrienchaton opened 4 years ago

adrienchaton commented 4 years ago

Hello,

Is there any way to use the torchaudio.functional operators in a parallelized way? For instance, with torchaudio.functional.lowpass_biquad, given a waveform [batch, length] and a cutoff_freq [batch, 1], I would like a whole minibatch of waveforms to be low-pass filtered, each with its own cutoff frequency.
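A minimal sketch of what a batched version could look like, written in plain PyTorch rather than with torchaudio. It assumes the RBJ Audio EQ Cookbook lowpass coefficients and a direct-form I recursion; `batched_lowpass_biquad` is a hypothetical name, not a torchaudio API. The coefficients are computed per row from `cutoff_freq` [batch, 1], so each waveform gets its own cutoff, and gradients can flow back to a model-predicted cutoff. The time loop is sequential in Python, so this trades speed for clarity.

```python
import math

import torch


def batched_lowpass_biquad(waveform, sample_rate, cutoff_freq, Q=0.707):
    """Low-pass each row of `waveform` with its own cutoff (hypothetical helper).

    waveform:    [batch, time]
    cutoff_freq: [batch, 1] in Hz (may require grad, e.g. predicted by a model)
    Coefficients assumed from the RBJ Audio EQ Cookbook lowpass; no output clamping.
    """
    w0 = 2.0 * math.pi * cutoff_freq / sample_rate       # [batch, 1]
    cos_w0 = torch.cos(w0)
    alpha = torch.sin(w0) / (2.0 * Q)
    a0 = 1.0 + alpha
    # Normalize every coefficient by a0 so the recursion has a leading 1.
    b0 = (1.0 - cos_w0) / 2.0 / a0
    b1 = (1.0 - cos_w0) / a0
    b2 = b0
    a1 = -2.0 * cos_w0 / a0
    a2 = (1.0 - alpha) / a0

    # Direct form I: vectorized over the batch, sequential over time.
    zeros = waveform.new_zeros(waveform.shape[0], 1)
    x1, x2, y1, y2 = zeros, zeros, zeros, zeros
    outputs = []
    for n in range(waveform.shape[1]):
        x0 = waveform[:, n:n + 1]
        y0 = b0 * x0 + b1 * x1 + b2 * x2 - a1 * y1 - a2 * y2
        outputs.append(y0)
        x2, x1 = x1, x0
        y2, y1 = y1, y0
    # Concatenate instead of writing in place, so autograd sees a clean graph.
    return torch.cat(outputs, dim=1)                      # [batch, time]
```

For example, `batched_lowpass_biquad(audio, 16000, cutoff)` with `audio` of shape [bs, length] and `cutoff` of shape [bs, 1] filters every row with its own cutoff in one call.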

Right now I use operations derived from the DDSP library, e.g. building batches of sinc impulse responses and applying them by FFT convolution. The cutoff_freq is predicted by a model, so this runs in parallel and training stays fast.
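For comparison, a sketch of the batched windowed-sinc approach described above: one FIR per row, built from that row's cutoff, applied by zero-padded FFT convolution. The function name, the 101-tap default and the Hann window are illustrative assumptions, not the DDSP implementation.

```python
import torch


def sinc_lowpass_fft(waveform, sample_rate, cutoff_freq, num_taps=101):
    """Low-pass each row with its own windowed-sinc FIR, applied by FFT convolution.

    waveform:    [batch, time]
    cutoff_freq: [batch, 1] in Hz
    num_taps:    odd FIR length (hypothetical default)
    """
    n = torch.arange(num_taps, dtype=waveform.dtype) - (num_taps - 1) / 2  # centered tap index
    fc = cutoff_freq / sample_rate                                        # cycles per sample, [batch, 1]
    h = 2.0 * fc * torch.sinc(2.0 * fc * n)                               # ideal lowpass, [batch, taps]
    h = h * torch.hann_window(num_taps, periodic=False, dtype=waveform.dtype)
    h = h / h.sum(dim=-1, keepdim=True)                                   # unity gain at DC

    # Zero-pad to time + taps - 1 so the FFT product is a linear, not circular, convolution.
    T = waveform.shape[-1]
    n_fft = T + num_taps - 1
    spec = torch.fft.rfft(waveform, n=n_fft) * torch.fft.rfft(h, n=n_fft)
    y = torch.fft.irfft(spec, n=n_fft)
    # Drop the (num_taps - 1) / 2 samples of group delay so the output aligns with the input.
    delay = (num_taps - 1) // 2
    return y[..., delay:delay + T]
```

The FIR is linear-phase, which is part of why it behaves differently from the biquad; its roll-off, however, is limited by `num_taps`.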

However, the quality of this filtering technique is not very good, and lowpass_biquad seems to do much better.

Thanks!

vincentqb commented 4 years ago

The PyTorch DataLoader loads and transforms data asynchronously and in parallel, e.g. here. The num_workers option specifies the number of worker processes that load the data (the default, 0, means that the main process loads the data and waits).

If you make the transforms part of the dataset, the workers also apply the transforms/functionals in parallel, e.g. here. Does that help?
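To illustrate the DataLoader suggestion: the transform is applied inside `__getitem__`, so with `num_workers > 0` it runs in the worker processes. `TransformedNoise`, `moving_average` and `make_loader` are hypothetical names, and the moving average is only a stand-in for a torchaudio functional.

```python
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset


def moving_average(waveform, k=5):
    """Placeholder transform (k-tap moving average) on a [channels=1, time] tensor."""
    kernel = waveform.new_ones(1, 1, k) / k
    return F.conv1d(waveform.unsqueeze(0), kernel, padding=k // 2).squeeze(0)


class TransformedNoise(Dataset):
    """Hypothetical dataset: each item is a noise clip passed through `transform`."""

    def __init__(self, n_items, length, transform=moving_average):
        self.n_items = n_items
        self.length = length
        self.transform = transform

    def __len__(self):
        return self.n_items

    def __getitem__(self, idx):
        g = torch.Generator().manual_seed(idx)                      # reproducible per item
        waveform = torch.rand(1, self.length, generator=g) * 2 - 1  # [channels=1, time]
        return self.transform(waveform)                             # runs inside the worker


def make_loader(n_items=8, length=16000, batch_size=4, num_workers=2):
    # With num_workers > 0, __getitem__ (and therefore the transform) runs in worker processes.
    dataset = TransformedNoise(n_items, length)
    return DataLoader(dataset, batch_size=batch_size, num_workers=num_workers)
```

To use a torchaudio filter instead, one could pass e.g. `transform=functools.partial(torchaudio.functional.lowpass_biquad, sample_rate=16000, cutoff_freq=2000.0)`. A `partial` can be pickled, unlike a lambda, which matters when workers are started with the `spawn` method. Note that this only parallelizes data loading; it does not help when the filter parameters come from the model's output.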

adrienchaton commented 4 years ago

Thank you for the tip.

However, this is not meant as a preprocessing function. I need it at the network output: filtering parameters such as cutoff_freq are predicted by the model and used in its generation process.

On a side note, the biquad filters are quite clean on sufficiently long audio, but on short windows (e.g. 1024 samples or fewer) they create artifacts, as does spectral filtering. For now, spectral filtering with overlap-add seems to give decent results on these signal windows.

Say I have a minibatch of audio [bs, n_wins x win_size] and cutoff_freq [bs, n_wins] (varying over time). I filter the non-overlapping windows [bs x n_wins, win_size] with cutoff_freq [bs x n_wins, 1]. I also filter the (n_wins-1) in-between windows (hop = win_size//2) with the interpolated cutoff_freq. Then I window each block and overlap-add.
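A sketch of the overlap-add scheme described above, under a few assumptions not stated in the thread: each frame is filtered by masking its FFT with a soft sigmoid edge (so gradients reach the cutoff), in-between frames use the mean of the two neighbouring cutoffs, and the synthesis window is a periodic Hann, which sums to 1 at 50% overlap. Masking an un-windowed frame's FFT is a circular convolution, one plausible source of short-window artifacts; the Hann cross-fade hides the frame edges. The first and last half-window are covered by a single frame, so they come out attenuated. The function name and `transition_hz` are hypothetical.

```python
import torch
import torch.nn.functional as F


def ola_spectral_lowpass(audio, cutoff_per_win, sample_rate, win_size, transition_hz=50.0):
    """Time-varying low-pass by per-frame spectral masking and 50%-overlap-add.

    audio:          [bs, n_wins * win_size]
    cutoff_per_win: [bs, n_wins] in Hz
    win_size:       even frame length; hop = win_size // 2
    """
    bs, length = audio.shape
    hop = win_size // 2
    # 2 * n_wins - 1 frames: the original windows plus the in-between ones.
    frames = audio.unfold(-1, win_size, hop)                    # [bs, n_frames, win_size]

    # Cutoffs for the in-between frames: mean of the two neighbouring windows.
    mid = 0.5 * (cutoff_per_win[:, :-1] + cutoff_per_win[:, 1:])
    cut = torch.stack([cutoff_per_win[:, :-1], mid], dim=-1).reshape(bs, -1)
    cut = torch.cat([cut, cutoff_per_win[:, -1:]], dim=1)       # [bs, n_frames]

    # Soft (sigmoid) spectral mask so the result stays differentiable w.r.t. the cutoff.
    freqs = torch.fft.rfftfreq(win_size, d=1.0 / sample_rate)    # [win_size // 2 + 1]
    mask = torch.sigmoid((cut.unsqueeze(-1) - freqs) / transition_hz)
    filtered = torch.fft.irfft(torch.fft.rfft(frames, dim=-1) * mask, n=win_size, dim=-1)

    # A periodic Hann window at 50% overlap sums to 1 wherever two frames overlap.
    window = torch.hann_window(win_size, periodic=True, dtype=audio.dtype)
    blocks = (filtered * window).transpose(1, 2).contiguous()   # [bs, win_size, n_frames]
    out = F.fold(blocks, output_size=(1, length), kernel_size=(1, win_size), stride=(1, hop))
    return out.reshape(bs, length)
```

`F.fold` performs the overlap-add in one call by treating the signal as a 1 x length image.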

One last question: if I wanted a 6th-order IIR filter, would it make sense to cascade the torchaudio biquad filter 3 times? Or do the biquad parameters need some adaptation in between?

The application would be to keep the same cutoff but get a steeper roll-off above it. Cascading does this more or less, but I noticed that the waveform shape gets distorted (not symmetrical between positive and negative amplitudes).
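On the cascading question, one standard recipe (not something confirmed in this thread): a 6th-order Butterworth lowpass is three second-order sections with the same cutoff but different Q values, Q_k = 1 / (2 cos((2k-1)π / 12)) for k = 1, 2, 3, i.e. about 0.518, 0.707 and 1.932. Three identical Q = 0.707 sections are each -3 dB at the cutoff, so the cascade is about -9 dB there. The helpers below compute those Qs and evaluate a cascade's gain, assuming RBJ Audio EQ Cookbook lowpass sections; all function names are hypothetical.

```python
import math

import numpy as np


def butterworth_section_qs(order):
    """Q factors of the second-order sections of an even-order Butterworth lowpass."""
    assert order % 2 == 0, "only even orders map onto whole biquads"
    return [1.0 / (2.0 * math.cos((2 * k - 1) * math.pi / (2 * order)))
            for k in range(1, order // 2 + 1)]


def rbj_lowpass(fc, sample_rate, Q):
    """RBJ-cookbook lowpass biquad coefficients (b, a), normalized so a[0] == 1."""
    w0 = 2 * math.pi * fc / sample_rate
    alpha = math.sin(w0) / (2 * Q)
    c = math.cos(w0)
    b = np.array([(1 - c) / 2, 1 - c, (1 - c) / 2])
    a = np.array([1 + alpha, -2 * c, 1 - alpha])
    return b / a[0], a / a[0]


def cascade_gain_db(f, fc, sample_rate, qs):
    """Magnitude in dB at frequency f of a cascade of RBJ lowpass sections."""
    zinv = np.exp(-1j * 2 * math.pi * f / sample_rate)  # z^-1 on the unit circle
    h = 1.0
    for Q in qs:
        b, a = rbj_lowpass(fc, sample_rate, Q)
        h *= (b[0] + b[1] * zinv + b[2] * zinv ** 2) / (a[0] + a[1] * zinv + a[2] * zinv ** 2)
    return 20 * math.log10(abs(h))
```

With torchaudio, the cascade would then be a loop such as `for q in butterworth_section_qs(6): x = torchaudio.functional.lowpass_biquad(x, sample_rate, fc, q)`. Any causal IIR cascade still has a nonlinear phase response, so it will still change the waveform shape.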

vincentqb commented 4 years ago

> Thank you for the tip.

Thanks for being active in the community :)

> On a side note, the biquad filters are pretty clean on audio with enough length but on short windows (e.g. 1024 samples or less) they create artifacts, as well as the spectral filtering does. Right now, for the purpose of filtering these signal windows, it seems giving decent results to perform spectral filtering with overlap add.

> Let's say, I have an minibatch of audio [bs,n_wins x win_size] and cutoff_freq [bs,n_wins] (varying over time). I filter non-overlapping windows [bs x n_wins,win_size] with cutoff_freq [bs x n_wins,1]. And I filter the (n_wins-1) in between windows (hop=win_size//2) with the interpolated cutoff_freq. Then I window each block and overlap-add.

Could you open a separate issue for this? It would also be useful to have a short code example so that we can reproduce it on our side. In particular, we'd like to see whether this also occurs with sox. Can you provide such code?

> One last question, if I would want to do a 6th order IIR filter. Would it make sense to cascade 3 times the same biquad filter as implemented in torchaudio ? Or is there some required adaptation in between the biquad parameters ?

> Application would be to have the same filter cutoff but steeper in the off frequencies. It does it more or less, but I noticed that the waveform shape gets distorted (not symmetrical between positive and negative amplitudes).

Again, it would be great to have a separate issue for this, with sample code so that we can reproduce your issue precisely. Would you be able to do that? Thanks!

adrienchaton commented 4 years ago

Hi Vincent,

I made a code snippet that generates a sawtooth and low-pass filters it 3 times in series on the full-length audio with the torchaudio biquad.

The waveform plot shows that the waveform is not symmetrical and that the asymmetry becomes more pronounced with each filtering pass. The spectrogram plot shows that the cutoff indeed becomes stronger (as expected) with each pass.

Then I do the same on separate signal windows, by splitting the same audio into chunks of 640 samples (at 16 kHz). Both the spectrogram and the waveform show artifacts that are not present when the full-length audio is filtered in one pass.

Does this answer your questions? Do you want me to open separate issues for both observations and put the respective code snippets there?

import numpy as np
import torch
import scipy
import librosa
import librosa.display
from matplotlib import pyplot as plt
import torchaudio

def remove_above_nyquist(frequency_envelopes,amplitude_envelopes,sample_rate):
    """
    frequency_envelopes: [batch_size, n_samples, n_sinusoids] (>=0)
    amplitude_envelopes: [batch_size, n_samples, n_sinusoids] (>=0)
    """
    # Silence partials at or above Nyquist (hence ge rather than gt).
    amplitude_envelopes = torch.where(
        torch.ge(frequency_envelopes, sample_rate / 2.0),
        torch.zeros_like(amplitude_envelopes), amplitude_envelopes)
    return amplitude_envelopes

def customwave_synth(f0_envelope,amplitude_envelope,sample_rate,overtones,mode,duty=0.2):
    """
    f0_envelope: [batch_size, n_samples, 1] (>=0)
    amplitude_envelope: [batch_size, n_samples, 1] (>=0)
    sawtooth = all overtones up to Nyquist with linear decay amplitude
    square = all odd overtones up to Nyquist with linear decay amplitude
    pulse = fourier expansion with duty cycle https://en.wikipedia.org/wiki/Pulse_wave
    e.g. f0=20Hz *400 (overtones) = 8000Hz (Nyquist)
    f0 should be pre-scaled in range [20,1200]
    """
    bs = f0_envelope.shape[0]
    n_overtones = overtones.shape[-1]
    frequency_envelopes = f0_envelope.expand(bs,-1,n_overtones)*overtones
    if mode=='sawtooth' or mode=='square' or mode=='pulse':
        amplitude_envelope = amplitude_envelope.expand(bs,-1,n_overtones)/overtones
    # Don't exceed Nyquist.
    amplitude_envelopes = remove_above_nyquist(frequency_envelopes,amplitude_envelope,sample_rate)
    # Angular frequency, Hz -> radians per sample.
    omegas = frequency_envelopes * (2.0 * np.pi)  # rad / sec
    omegas = omegas / float(sample_rate)  # rad / sample
    # Accumulate phase and synthesize.
    phases = torch.cumsum(omegas, axis=1)
    if mode=='sawtooth' or mode=='square':
        wavs = torch.sin(phases)
        audio = amplitude_envelopes * wavs  # [mb, n_samples, n_sinusoids]
        audio = torch.sum(audio, axis=-1)  # [mb, n_samples]
        if mode=='sawtooth':
            audio = audio/2 # empirically it seems to give a good amplitude in [-0.9,0.9] in f0 range [40,400]
        if mode=='square':
            audio = audio
    if mode=='pulse':
        wavs = torch.cos(phases)
        audio = amplitude_envelopes * wavs  # [mb, n_samples, n_sinusoids]
        audio = audio * 2.0 / np.pi
        audio = audio * torch.sin(duty*np.pi*overtones)
        audio = torch.sum(audio, axis=-1)+duty
    return audio

sample_rate = 16000
bs = 100
window_size = 640
length = bs*window_size
n_wins = length//window_size
fc = 2000.
Q = 1.

f0_min = 1000.
f0_max = 5000.
n_overtones = 400
f0 = np.linspace(f0_min,f0_max,num=length)
f0_envelope = torch.from_numpy(f0).unsqueeze(0).unsqueeze(-1).float()
amplitude_envelope = torch.ones(1,length,1)

overtones_saw = torch.arange(1,n_overtones+1).reshape(1,1,-1).float()
audio = customwave_synth(f0_envelope,amplitude_envelope,sample_rate,overtones_saw,'sawtooth')

## source audio
plt.figure()
plt.suptitle('source audio')
plt.subplot(2,1,1)
D = librosa.amplitude_to_db(np.abs(librosa.stft(audio.view(-1).numpy(),n_fft=1024)),ref=np.max)
librosa.display.specshow(D, sr=sample_rate, y_axis='linear')
plt.colorbar(format='%+2.0f dB')
plt.subplot(2,1,2)
plt.plot(audio.view(-1).numpy())
plt.show()

## filtering audio in full-length
filtered_audio_1 = torchaudio.functional.lowpass_biquad(audio,sample_rate, fc, Q)
filtered_audio_2 = torchaudio.functional.lowpass_biquad(filtered_audio_1,sample_rate, fc, Q)
filtered_audio_3 = torchaudio.functional.lowpass_biquad(filtered_audio_2,sample_rate, fc, Q)

plt.figure()
plt.suptitle('BIQUAD filtering LPF at '+str(fc))
plt.subplot(2,1,1)
D = librosa.amplitude_to_db(np.abs(librosa.stft(filtered_audio_1.view(-1).numpy(),n_fft=1024)),ref=np.max)
librosa.display.specshow(D, sr=sample_rate, y_axis='linear')
plt.colorbar(format='%+2.0f dB')
plt.subplot(2,1,2)
plt.plot(filtered_audio_1.view(-1).numpy())
plt.show()

plt.figure()
plt.suptitle('BIQUADx2 filtering LPF at '+str(fc))
plt.subplot(2,1,1)
D = librosa.amplitude_to_db(np.abs(librosa.stft(filtered_audio_2.view(-1).numpy(),n_fft=1024)),ref=np.max)
librosa.display.specshow(D, sr=sample_rate, y_axis='linear')
plt.colorbar(format='%+2.0f dB')
plt.subplot(2,1,2)
plt.plot(filtered_audio_2.view(-1).numpy())
plt.show()

plt.figure()
plt.suptitle('BIQUADx3 filtering LPF at '+str(fc))
plt.subplot(2,1,1)
D = librosa.amplitude_to_db(np.abs(librosa.stft(filtered_audio_3.view(-1).numpy(),n_fft=1024)),ref=np.max)
librosa.display.specshow(D, sr=sample_rate, y_axis='linear')
plt.colorbar(format='%+2.0f dB')
plt.subplot(2,1,2)
plt.plot(filtered_audio_3.view(-1).numpy())
plt.show()

## filtering audio chunks of window_size
audio = audio.view(bs,window_size)

filtered_audio_1 = torchaudio.functional.lowpass_biquad(audio,sample_rate, fc, Q)
filtered_audio_2 = torchaudio.functional.lowpass_biquad(filtered_audio_1,sample_rate, fc, Q)
filtered_audio_3 = torchaudio.functional.lowpass_biquad(filtered_audio_2,sample_rate, fc, Q)

plt.figure()
plt.suptitle('BIQUAD filtering LPF at '+str(fc))
plt.subplot(2,1,1)
D = librosa.amplitude_to_db(np.abs(librosa.stft(filtered_audio_1.view(-1).numpy(),n_fft=1024)),ref=np.max)
librosa.display.specshow(D, sr=sample_rate, y_axis='linear')
plt.colorbar(format='%+2.0f dB')
plt.subplot(2,1,2)
plt.plot(filtered_audio_1.view(-1).numpy())
plt.show()

plt.figure()
plt.suptitle('BIQUADx2 filtering LPF at '+str(fc))
plt.subplot(2,1,1)
D = librosa.amplitude_to_db(np.abs(librosa.stft(filtered_audio_2.view(-1).numpy(),n_fft=1024)),ref=np.max)
librosa.display.specshow(D, sr=sample_rate, y_axis='linear')
plt.colorbar(format='%+2.0f dB')
plt.subplot(2,1,2)
plt.plot(filtered_audio_2.view(-1).numpy())
plt.show()

plt.figure()
plt.suptitle('BIQUADx3 filtering LPF at '+str(fc))
plt.subplot(2,1,1)
D = librosa.amplitude_to_db(np.abs(librosa.stft(filtered_audio_3.view(-1).numpy(),n_fft=1024)),ref=np.max)
librosa.display.specshow(D, sr=sample_rate, y_axis='linear')
plt.colorbar(format='%+2.0f dB')
plt.subplot(2,1,2)
plt.plot(filtered_audio_3.view(-1).numpy())
plt.show()

adrienchaton commented 4 years ago

If that's more convenient, you can download the code here: https://github.com/adrienchaton/misc/blob/master/check_torchaudio_biquad.py

vincentqb commented 4 years ago

Thanks a lot for the detailed code! Why do you expect the signal to be symmetrical? It looks like torchaudio does match sox. Indeed, you can replace the torchaudio functional calls with the following, and you should get the same results.

filtered_audio_1, _ = torchaudio.sox_effects.apply_effects_tensor(audio, sample_rate, [["lowpass", str(fc), str(Q)]])
filtered_audio_2, _ = torchaudio.sox_effects.apply_effects_tensor(filtered_audio_1, sample_rate, [["lowpass", str(fc), str(Q)]])
filtered_audio_3, _ = torchaudio.sox_effects.apply_effects_tensor(filtered_audio_2, sample_rate, [["lowpass", str(fc), str(Q)]])

These call sox directly instead of torchaudio's own implementation; sox is the reference that torchaudio compares against.

vincentqb commented 4 years ago

As for the artifacts, they seem to come from the views in the code below.

## filtering audio chunks of window_size
audio = audio.view(bs,window_size)

filtered_audio_1 = torchaudio.functional.lowpass_biquad(audio,sample_rate, fc, Q)
filtered_audio_2 = torchaudio.functional.lowpass_biquad(filtered_audio_1,sample_rate, fc, Q)
filtered_audio_3 = torchaudio.functional.lowpass_biquad(filtered_audio_2,sample_rate, fc, Q)

plt.figure()
plt.suptitle('BIQUAD filtering LPF at '+str(fc))
plt.subplot(2,1,1)
D = librosa.amplitude_to_db(np.abs(librosa.stft(filtered_audio_1.view(-1).numpy(),n_fft=1024)),ref=np.max)
librosa.display.specshow(D, sr=sample_rate, y_axis='linear')
plt.colorbar(format='%+2.0f dB')
plt.subplot(2,1,2)
plt.plot(filtered_audio_1.view(-1).numpy())
plt.show()

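The code reshapes with audio.view(bs,window_size), filters each row independently, and then flattens again with filtered_audio_1.view(-1).numpy(). The flattened signal therefore concatenates independently filtered chunks, so the plot straddles many chunk boundaries, which shows up as artifacts. What was your intention here?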

adrienchaton commented 4 years ago

Thank you for checking.

Regarding the artifacts from filtering batches of audio chunks: are these caused by phase inconsistency between the filter outputs? The input chunks are aligned in phase, but does each filtering pass create phase jumps at the concatenated edges?

As for the waveform shape, if it follows the reference sox behaviour, I don't have an argument against it. I just wondered whether it is correct, since filtering in the DFT domain does not affect the waveform symmetry.
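Scaling DFT bins by a real-valued mask changes magnitudes only (zero phase), whereas a causal biquad also shifts each frequency by a different phase, which reshapes the waveform. When the whole signal is available, a common way to get a zero-phase IIR result is forward-backward filtering (as in `scipy.signal.filtfilt`), at the cost of a squared magnitude response and non-causality. A minimal sketch, assuming RBJ cookbook lowpass coefficients; both function names are hypothetical. With a time-symmetric input pulse, the forward-backward output stays symmetric, while the causal output is delayed and skewed.

```python
import math

import torch


def lowpass_biquad_ref(x, sample_rate, fc, Q=0.707):
    """Causal RBJ-cookbook lowpass biquad on a 1-D tensor (plain Python loop, for illustration)."""
    w0 = 2 * math.pi * fc / sample_rate
    alpha = math.sin(w0) / (2 * Q)
    c = math.cos(w0)
    a0 = 1 + alpha
    b0, b1, b2 = (1 - c) / 2 / a0, (1 - c) / a0, (1 - c) / 2 / a0
    a1, a2 = -2 * c / a0, (1 - alpha) / a0
    x1 = x2 = y1 = y2 = 0.0
    out = []
    for x0 in x.tolist():
        y0 = b0 * x0 + b1 * x1 + b2 * x2 - a1 * y1 - a2 * y2
        out.append(y0)
        x2, x1, y2, y1 = x1, x0, y1, y0
    return torch.tensor(out, dtype=x.dtype)


def zero_phase_lowpass(x, sample_rate, fc, Q=0.707):
    """Forward-backward filtering: zero phase, squared magnitude (about -6 dB at fc for Q=0.707)."""
    forward = lowpass_biquad_ref(x, sample_rate, fc, Q)
    return lowpass_biquad_ref(forward.flip(0), sample_rate, fc, Q).flip(0)
```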

vincentqb commented 4 years ago

Thanks for engaging with torchaudio :)

> Regarding the artefacts from filtering in batches of audio chunks, these would be caused by phase inconsistency between filter outputs ? The input chunks are aligned in phase but the filtering pass would create phase jumps at the concatenated edges ?

Yes, the artifacts are edge effects at the boundaries between chunks of audio. You will of course see fewer of them with fewer, longer chunks, e.g.:

bs = 10
audio = audio.view(bs, -1)
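Fewer, longer chunks only reduce the number of boundaries. The edge effects come from each chunk starting the recursion from a zero state (as far as I can tell, each `lowpass_biquad` call starts from zero). If chunks must be processed separately, carrying the filter state (the last two inputs and outputs) from one chunk to the next reproduces the full-length result exactly, at the cost of processing the chunks in order rather than as one parallel batch. A minimal sketch in plain PyTorch, assuming RBJ cookbook coefficients; all names are hypothetical.

```python
import math

import torch


def rbj_lowpass_coeffs(sample_rate, fc, Q=0.707):
    """RBJ-cookbook lowpass normalized so a0 == 1: returns (b0, b1, b2, a1, a2)."""
    w0 = 2 * math.pi * fc / sample_rate
    alpha = math.sin(w0) / (2 * Q)
    c = math.cos(w0)
    a0 = 1 + alpha
    return (1 - c) / 2 / a0, (1 - c) / a0, (1 - c) / 2 / a0, -2 * c / a0, (1 - alpha) / a0


def biquad_with_state(chunk, coeffs, state):
    """Filter one chunk [batch, time]; state = (x1, x2, y1, y2), each [batch]."""
    b0, b1, b2, a1, a2 = coeffs
    x1, x2, y1, y2 = state
    outputs = []
    for n in range(chunk.shape[1]):
        x0 = chunk[:, n]
        y0 = b0 * x0 + b1 * x1 + b2 * x2 - a1 * y1 - a2 * y2
        outputs.append(y0)
        x2, x1, y2, y1 = x1, x0, y1, y0
    return torch.stack(outputs, dim=1), (x1, x2, y1, y2)


def filter_in_chunks(audio, chunk_size, coeffs, carry_state=True):
    """Filter [batch, time] chunk by chunk, optionally carrying the state across chunks."""
    zeros = audio.new_zeros(audio.shape[0])
    state = (zeros, zeros, zeros, zeros)
    pieces = []
    for start in range(0, audio.shape[1], chunk_size):
        if not carry_state:
            # Resetting the state is what filtering each chunk independently amounts to.
            state = (zeros, zeros, zeros, zeros)
        out, state = biquad_with_state(audio[:, start:start + chunk_size], coeffs, state)
        pieces.append(out)
    return torch.cat(pieces, dim=1)
```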