audeering / audtorch

Utils and data sets for audio and PyTorch
https://audeering.github.io/audtorch/

Invertible transforms #37

Open ATriantafyllopoulos opened 5 years ago

ATriantafyllopoulos commented 5 years ago

Feature

I would like to introduce invertible transforms. Every transform that can be inverted would get an extra method, e.g. named inverse, that undoes the operation it applied to the last signal it processed. For our Normalize transform, this would look like:


import numpy as np


class Normalize(object):
    def __init__(self, *, axis=-1):
        super().__init__()
        self.axis = axis
        # Peak of the last processed signal, needed by inverse()
        self.peak = None

    def __call__(self, signal):
        if self.axis is not None:
            self.peak = np.expand_dims(np.amax(np.abs(signal), axis=self.axis),
                                       axis=self.axis)
        else:
            self.peak = np.amax(np.abs(signal))
        return signal / np.maximum(self.peak, 1e-7)

    def inverse(self, signal):
        # Only valid after __call__ has stored a peak
        assert self.peak is not None
        return signal * np.maximum(self.peak, 1e-7)
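
A round trip through such a transform could be checked as follows. This is a minimal sketch, not the audtorch implementation: the class is re-declared in compact form so the snippet runs on its own, and using keepdims=True in place of expand_dims is an assumption made for brevity.

```python
import numpy as np


class Normalize:
    """Peak-normalize a signal and remember the peak for inversion."""

    def __init__(self, *, axis=-1):
        self.axis = axis
        self.peak = None

    def __call__(self, signal):
        # keepdims=True keeps the reduced axis so the peak broadcasts back
        peak = np.amax(np.abs(signal), axis=self.axis, keepdims=True)
        self.peak = np.maximum(peak, 1e-7)
        return signal / self.peak

    def inverse(self, signal):
        if self.peak is None:
            raise RuntimeError('call the transform before inverse()')
        return signal * self.peak


transform = Normalize()
signal = np.array([[0.1, -0.5, 0.25], [2.0, 1.0, -4.0]])
normalized = transform(signal)            # each row now peaks at 1
restored = transform.inverse(normalized)  # undoes the scaling
```

Note that inverse only undoes the most recent call, since the stored peak is overwritten every time the transform is applied.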

Motivation

For specific use cases, it would be nice to make our transforms invertible. The ones I have in mind are those where some kind of reconstruction is required after an architecture has processed the signal (e.g. denoising or source separation in the spectrogram domain).

It might be a limited use case, but our API currently does not support it.

Problems

Any opinions on this? Am I the only one who needs it?

phtephanx commented 5 years ago

For deployment purposes it could be useful.

Perhaps we could change the nature of the functional's return to a state_dict

I didn't quite understand this proposed interface. Could you draft a function skeleton, e.g. for Normalize?

hagenw commented 5 years ago

For Spectrogram, would it look like this?

class Spectrogram(object):

    def __init__(self, window_size, hop_size, *, fft_size=None,
                 window='hann', axis=-1):
        super().__init__()
        self.window_size = window_size
        self.hop_size = hop_size
        self.fft_size = fft_size
        self.window = window
        self.axis = axis
        # Complex STFT of the last processed signal, used by inverse()
        self.spectrogram = None

    def __call__(self, signal):
        self.spectrogram = F.stft(signal, self.window_size, self.hop_size,
                                  fft_size=self.fft_size, window=self.window,
                                  axis=self.axis)
        magnitude, _ = librosa.magphase(self.spectrogram)
        return magnitude

    def inverse(self):
        return F.istft(self.spectrogram, self.window_size, self.hop_size,
                       window=self.window, axis=self.axis)

As this is limited to the last processed signal only, I also see no need for changing the functionals. Could you provide an example of that?
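
For the denoising use case from the motivation, inverse would presumably need to accept a processed magnitude and recombine it with the stored phase. The following is a rough sketch of that idea only: MagnitudeSpectrum is a hypothetical name, and a plain FFT over the whole signal is swapped in for the STFT so that the example depends on nothing but NumPy.

```python
import numpy as np


class MagnitudeSpectrum:
    """Hypothetical transform: returns the FFT magnitude, keeps the phase."""

    def __init__(self):
        self.phase = None
        self.num_samples = None

    def __call__(self, signal):
        spectrum = np.fft.rfft(signal, axis=-1)
        # Remember what is needed to get back to the time domain
        self.phase = np.angle(spectrum)
        self.num_samples = signal.shape[-1]
        return np.abs(spectrum)

    def inverse(self, magnitude):
        if self.phase is None:
            raise RuntimeError('call the transform before inverse()')
        # Recombine the (possibly modified) magnitude with the stored phase
        spectrum = magnitude * np.exp(1j * self.phase)
        return np.fft.irfft(spectrum, n=self.num_samples, axis=-1)


rng = np.random.default_rng(0)
signal = rng.standard_normal(16)
transform = MagnitudeSpectrum()
magnitude = transform(signal)
reconstructed = transform.inverse(magnitude)     # unmodified round trip
attenuated = transform.inverse(0.5 * magnitude)  # stand-in for a model output
```

The same limitation applies here: the stored phase belongs to the last processed signal, so inverse has to be called before the transform is applied to another signal.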

ATriantafyllopoulos commented 5 years ago

This one was easy. In Standardize, for example, you would have to do:

import numpy as np


class Standardize(object):
    def __init__(self, *, mean=True, std=True, axis=-1):
        super().__init__()
        self.axis = axis
        self.mean = mean
        self.std = std

    def __call__(self, signal):
        if self.mean:
            # Store the mean unconditionally so it also exists for axis=None
            self.signal_mean = np.mean(signal, axis=self.axis)
            if self.axis is not None:
                self.signal_mean = np.expand_dims(
                    self.signal_mean, axis=self.axis)
            signal = signal - self.signal_mean
        if self.std:
            self.signal_std = np.std(signal, axis=self.axis)
            if self.axis is not None:
                self.signal_std = np.expand_dims(
                    self.signal_std, axis=self.axis)
            signal = signal / np.maximum(self.signal_std, 1e-7)
        return signal

    def inverse(self, signal):
        # Undo the operations in reverse order: scale first, then shift
        if self.std:
            signal = signal * np.maximum(self.signal_std, 1e-7)
        if self.mean:
            signal = signal + self.signal_mean
        return signal

because inverse somehow needs the internal variables mean and std that are computed inside __call__.
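
The alternative raised earlier in the thread, letting the functional return its state, might look roughly like the sketch below. The two-value return and the inverse_standardize helper are hypothetical and not part of the current audtorch API; the functional is shown with NumPy only.

```python
import numpy as np


def standardize(signal, *, axis=-1):
    """Hypothetical functional: returns the result plus a state dict."""
    mean = np.mean(signal, axis=axis, keepdims=True)
    std = np.maximum(np.std(signal, axis=axis, keepdims=True), 1e-7)
    return (signal - mean) / std, {'mean': mean, 'std': std}


def inverse_standardize(signal, state):
    """Hypothetical inverse that consumes the state dict from above."""
    return signal * state['std'] + state['mean']


signal = np.array([[1.0, 2.0, 3.0], [10.0, 20.0, 60.0]])
standardized, state = standardize(signal)
restored = inverse_standardize(standardized, state)
```

With this shape, the transform class would only have to keep the returned state around, at the cost of changing the return type of the functional.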

hagenw commented 5 years ago

OK, I see. In this case it would indeed be easiest to remove F.standardize to solve the issue.

Could you compile a list of the functionals that would be affected by this? If only a few are affected, we could think about removing those.