KinWaiCheuk / nnAudio

Audio processing by using pytorch 1D convolution network
MIT License

Gammatone Filterbank waveform outputs #128

Open astrocyted opened 1 year ago

astrocyted commented 1 year ago

Hi, I'm interested in having an nn.Module Gammatone filterbank that produces the filtered outputs directly, [N_filters x Signal length]. Would it be possible to achieve this within your framework, without having to loop over the number of filters?

KinWaiCheuk commented 1 year ago

Isn't Gammatonegram already like this? The output is already [Batch, N_filters X Signal length]. Or am I misunderstanding your question?

astrocyted commented 1 year ago

> Isn't Gammatonegram already like this? The output is already [Batch, N_filters X Signal length]. Or am I misunderstanding your question?

Just saw your reply now. No, that's clearly not what Gammatonegram returns. Check the docs of your code:

Returns

spectrogram : torch.tensor
    It returns a tensor of spectrograms.  shape = ``(num_samples, freq_bins,time_steps)``. 

time_steps is not the signal length but signal_length / hop_length. I want the per-channel IIR-filtered waveform, not the binned FFT.
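
To make the shape difference concrete, here is a rough sketch (the import path and keyword arguments for Gammatonegram are from memory and may differ between nnAudio versions):

import torch
from nnAudio import Spectrogram   # newer releases expose the same classes under nnAudio.features

sr, hop = 16000, 160
x = torch.randn(1, sr)                        # 1 s of audio: (batch, signal_length)

gtgram_layer = Spectrogram.Gammatonegram(sr=sr, hop_length=hop)
print(gtgram_layer(x).shape)                  # framed output: (1, freq_bins, signal_length // hop)

# What I want instead is (1, n_filters, signal_length): the filtered waveforms themselves.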

KinWaiCheuk commented 1 year ago

I understand your question now. I am not familiar with gammatone and gammatonegram. This feature is implemented by @WangHelin1997. Maybe he can comment more on it?

Alternatively, can you recommend any Python library that produces the filtered waveforms? I will check whether I can implement it under the current nnAudio framework. It would be a great help to have something to refer to, just to check that I implement it correctly.

astrocyted commented 1 year ago

https://github.com/detly/gammatone/blob/master/gammatone/filters.py

This is one example of such an implementation; the output of its erb_filterbank() function is what I'm asking for. It's quite slow, though. I also tried to do it in torch myself, but it didn't really speed things up:

import math

import torch
import torchaudio.functional as F   # F.lfilter below is torchaudio.functional.lfilter


class GammatoneFilterbank(torch.nn.Module):
    def __init__(self,
                 num_filters=64,
                 sample_rate=16000,
                 fmin=50,
                 fmax=None,
                 gtgram=False,
                 frame_length=400,
                 hop_length=160
                 ):

        super(GammatoneFilterbank, self).__init__()
        self.num_filters = num_filters
        self.sample_rate = sample_rate

        self.gtgram = gtgram
        self.frame_length = frame_length
        self.hop_length = hop_length

        self.fmin = fmin
        if fmax:
            self.fmax = fmax
        else:
            self.fmax = self.sample_rate/2

        self.centre_freqs = self.centre_frequencies()
        self.filter_coefs = self.make_erb_filters()

    @staticmethod
    def erb_point(low_freq, high_freq, fraction):
        ear_q = 9.26449  # Glasberg and Moore Parameters
        min_bw = 24.7
        order = 1

        low_freq = torch.tensor(low_freq)
        high_freq = torch.tensor(high_freq)

        erb_point = (
            -ear_q * min_bw
            + torch.exp(
                fraction * (
                    -torch.log(high_freq + ear_q * min_bw)
                    + torch.log(low_freq + ear_q * min_bw)
                )
            ) *
            (high_freq + ear_q * min_bw)
        )

        return erb_point

    @staticmethod
    def erb_space(
        low_freq=50,
        high_freq=8000,
        num_bands=64):
        """
        This function computes an array of ``num`` frequencies uniformly spaced
        between ``high_freq`` and ``low_freq`` on an ERB scale.

        For a definition of ERB, see Moore, B. C. J., and Glasberg, B. R. (1983).
        "Suggested formulae for calculating auditory-filter bandwidths and
        excitation patterns," J. Acoust. Soc. Am. 74, 750-753.
        """
        return GammatoneFilterbank.erb_point(
            low_freq,
            high_freq,
            torch.arange(1, num_bands + 1) / num_bands
            )

    def centre_frequencies(self):
        """
        Calculates an array of centre frequencies (for :func:`make_erb_filters`)
        from a sampling frequency, lower cutoff frequency and the desired number of
        filters.

        :param fs: sampling rate
        :param num_freqs: number of centre frequencies to calculate
        :type num_freqs: int
        :param cutoff: lower cutoff frequency
        :return: same as :func:`erb_space`
        """
        return GammatoneFilterbank.erb_space(low_freq= self.fmin, high_freq= self.fmax, num_bands=self.num_filters)

    def make_erb_filters(self, width=1.0):
        # Second-order-section coefficients for the gammatone filterbank,
        # following the ERB filter design used in the linked gammatone package.
        T = 1 / self.sample_rate
        ear_q = 9.26449  # Glasberg and Moore Parameters
        min_bw = 24.7
        order = 1

        if not torch.is_tensor(self.centre_freqs):
            self.centre_freqs = torch.Tensor(self.centre_freqs)

        erb = width*((self.centre_freqs / ear_q) ** order + min_bw ** order) ** (1 / order)
        B = 1.019 * 2 * torch.Tensor([math.pi]) * erb

        arg = 2 * self.centre_freqs * torch.Tensor([math.pi]) * T
        vec = torch.exp(2j * arg)

        A0 = T
        A2 = 0
        B0 = 1
        B1 = -2 * torch.cos(arg) / torch.exp(B * T)
        B2 = torch.exp(-2 * B * T)

        rt_pos = torch.sqrt(torch.tensor(3 + 2 ** 1.5))
        rt_neg = torch.sqrt(torch.tensor(3 - 2 ** 1.5))

        common = -T * torch.exp(-(B * T))

        k11 = torch.cos(arg) + rt_pos * torch.sin(arg)
        k12 = torch.cos(arg) - rt_pos * torch.sin(arg)
        k13 = torch.cos(arg) + rt_neg * torch.sin(arg)
        k14 = torch.cos(arg) - rt_neg * torch.sin(arg)

        A11 = common * k11
        A12 = common * k12
        A13 = common * k13
        A14 = common * k14

        gain_arg = torch.exp(1j * arg - B * T)

        gain = torch.abs(
            (vec - gain_arg * k11)
            * (vec - gain_arg * k12)
            * (vec - gain_arg * k13)
            * (vec - gain_arg * k14)
            * (T * torch.exp(B * T)
                / (-1 / torch.exp(B * T) + 1 + vec * (1 - torch.exp(B * T)))
            )**4
        )

        allfilts = torch.ones_like(self.centre_freqs)

        fcoefs = torch.stack([
            A0 * allfilts, A11, A12, A13, A14, A2*allfilts,
            B0 * allfilts, B1, B2,
            gain
        ], dim=1)

        return fcoefs

    def erb_filterbank(self, waveform):
        #Batch x Time
        if waveform.ndim==1:
            waveform = waveform[None,:]

        #output = torch.zeros((self.filter_coefs[:,9].shape[0], waveform.shape[-1]))

        gain = self.filter_coefs[:, 9]
        # A0, A11, A2
        As1 = self.filter_coefs[:, (0, 1, 5)]
        # A0, A12, A2
        As2 = self.filter_coefs[:, (0, 2, 5)]
        # A0, A13, A2
        As3 = self.filter_coefs[:, (0, 3, 5)]
        # A0, A14, A2
        As4 = self.filter_coefs[:, (0, 4, 5)]
        # B0, B1, B2
        Bs = self.filter_coefs[:, 6:9]

        # Broadcast the waveform across all filters, then apply each gammatone
        # filter as a cascade of four second-order IIR sections.
        # torchaudio.functional.lfilter takes (waveform, a_coeffs, b_coeffs), so
        # Bs are the feedback coefficients and As1..As4 the feed-forward ones.
        stacked_waveforms = waveform.expand(self.filter_coefs.shape[0], *waveform.shape[1:])

        y1 = F.lfilter(stacked_waveforms, Bs, As1, clamp=False)
        y2 = F.lfilter(y1, Bs, As2, clamp=False)
        y3 = F.lfilter(y2, Bs, As3, clamp=False)
        y4 = F.lfilter(y3, Bs, As4, clamp=False)

        return y4 / gain.unsqueeze(-1)

    def forward(self, x):
        if self.gtgram:
            # Gammatone spectrogram: filter, frame the filtered waveforms, and
            # take the square root of the per-frame energy.
            x = self.erb_filterbank(x)
            x = torch.nn.functional.pad(x, (self.frame_length // 2, self.frame_length - self.frame_length // 2))
            x = torch.sum(x.unfold(-1, self.frame_length, self.hop_length) ** 2, dim=-1)
            return torch.sqrt(x)
        else:
            # Raw per-channel filtered waveforms, full signal length.
            return self.erb_filterbank(x)
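
For completeness, this is roughly how I call it (output shapes follow from the code above):

fb = GammatoneFilterbank(num_filters=64, sample_rate=16000)
x = torch.randn(16000)        # 1 s mono waveform
y = fb(x)                     # gtgram=False -> (64, 16000): one filtered waveform per centre frequency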

I guess the fastest implementations are the ones written directly in C.
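
For reference, the call pattern in the linked library is roughly the following (written from memory, so treat the exact function names and signatures as approximate):

import numpy as np
from gammatone.filters import centre_freqs, make_erb_filters, erb_filterbank

fs = 16000
wave = np.random.randn(fs)                # 1 s mono signal
cfs = centre_freqs(fs, 64, 50)            # 64 centre frequencies above a 50 Hz cutoff
coefs = make_erb_filters(fs, cfs)
filtered = erb_filterbank(wave, coefs)    # -> (64, len(wave)): per-channel filtered waveforms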