pytorch / audio

Data manipulation and transformation for audio signal processing, powered by PyTorch
https://pytorch.org/audio
BSD 2-Clause "Simplified" License

Backprop support for lfilter #704

Closed daniel-p-gonzalez closed 3 years ago

daniel-p-gonzalez commented 4 years ago

🚀 Feature

It is currently not possible to backpropagate gradients through an lfilter because of this inplace operation: https://github.com/pytorch/audio/blob/master/torchaudio/functional.py#L661

Motivation

Without backprop support, lfilter is not worth the PyTorch overhead (it is much faster when implemented with, e.g., numba). When I saw that it was implemented here, I was hoping to use it in place of my own implementation (a custom RNN), which is honestly too slow.

Pitch

I would love to see that inplace operation replaced with something that would allow supporting backprop. I'm not sure what the most efficient way to do this is.
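To make the request concrete, here is a minimal sketch (not torchaudio's code; the tensor names are made up for illustration) of one way an in-place write can block autograd, next to the out-of-place form that works. It assumes a recent PyTorch release, where writing in place into a view returned by unbind raises a RuntimeError.

```python
import torch

w = torch.ones(3, 4, requires_grad=True)

# unbind returns several views of one tensor; autograd refuses in-place
# writes into such views because it cannot track them safely.
rows = (w * 2.0).unbind(0)
try:
    rows[0].add_(1.0)
    inplace_error = None
except RuntimeError as err:
    inplace_error = str(err)

# Out-of-place alternative: build a new tensor instead of mutating the view.
rows = (w * 2.0).unbind(0)
updated = rows[0] + 1.0
updated.sum().backward()

print("in-place failed:", inplace_error is not None)
print("gradient w.r.t. w:", w.grad)
```

The out-of-place form allocates a new tensor at every step, so the price of differentiability here is extra memory traffic, which is presumably why the in-place version was used in the first place.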

Alternatives

I implemented transposed direct form II digital filters as custom RNNs, but the performance is pretty poor (which seems to be a problem with the fuser). This is the simplest version I tried, which works, but as I said it's quite slow.

class DigitalFilterModel(jit.ScriptModule):
  def __init__(self):
    super(DigitalFilterModel, self).__init__()

  @jit.script_method
  def forward(self, x, coeffs, v1, v2, v3):
    # type: (Tensor, Tensor, Tensor, Tensor, Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]
    seq_len = x.shape[1]
    output = torch.jit.annotate(List[Tensor], [])
    x = x.unbind(1)
    coeffs = coeffs.unbind(1)
    for i in range(seq_len):
      sample = x[i]
      out = coeffs[0] * sample + v1
      output.append(out)

      v1 = coeffs[1] * sample - coeffs[4] * out + v2
      v2 = coeffs[2] * sample - coeffs[5] * out + v3
      v3 = coeffs[3] * sample - coeffs[6] * out

    return torch.stack(output, 1), v1, v2, v3
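As a sanity check on the recursion above, here is a plain (unscripted) sketch that compares a third-order transposed direct form II loop against the textbook difference equation and confirms that gradients reach the coefficients. The function names tdf2_filter and difference_equation are hypothetical, and the coefficient layout (b0..b3 for the numerator, a1..a3 for the denominator, with a0 = 1) is an assumption read off the snippet rather than something the post states.

```python
import torch

def tdf2_filter(x, b, a):
    """Third-order transposed direct form II.
    x: (batch, time); b, a: length-4 tensors with a[0] == 1 (assumed layout)."""
    v1 = v2 = v3 = torch.zeros(x.shape[0], dtype=x.dtype)
    out = []
    for n in range(x.shape[1]):
        s = x[:, n]
        y = b[0] * s + v1
        out.append(y)
        # State updates are out-of-place, so autograd can follow them.
        v1 = b[1] * s - a[1] * y + v2
        v2 = b[2] * s - a[2] * y + v3
        v3 = b[3] * s - a[3] * y
    return torch.stack(out, 1)

def difference_equation(x, b, a):
    """Reference: y[n] = sum_k b[k] x[n-k] - sum_{k>=1} a[k] y[n-k]."""
    ys = []
    for n in range(x.shape[1]):
        acc = torch.zeros(x.shape[0], dtype=x.dtype)
        for k in range(4):
            if n - k >= 0:
                acc = acc + b[k] * x[:, n - k]
                if k > 0:
                    acc = acc - a[k] * ys[n - k]
        ys.append(acc)
    return torch.stack(ys, 1)

torch.manual_seed(0)
x = torch.rand(2, 32, dtype=torch.float64)
b = torch.tensor([0.2, 0.3, 0.1, 0.05], dtype=torch.float64, requires_grad=True)
a = torch.tensor([1.0, -0.5, 0.2, -0.1], dtype=torch.float64, requires_grad=True)

y = tdf2_filter(x, b, a)
same = torch.allclose(y, difference_equation(x, b, a))
y.sum().backward()
print("matches difference equation:", same)
print("b.grad:", b.grad)
```

Both loops run one Python iteration per sample, so this is only a correctness reference and says nothing about the speed problem described above.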

Another alternative I've used when I only need to backprop through the filter, but not optimize the actual coefficients, is to take advantage of the fact that tanh is close to linear for very small inputs and design a standard RNN to be equivalent to the digital filter. Crushing the input, then rescaling the output to keep it linear gives a result very close to the original filter, but this is obviously quite a hack:

class RNNTDFWrapper(nn.Module):
  def __init__(self, eps=0.000000001):
    super(RNNTDFWrapper, self).__init__()
    self.eps = eps
    self.rnn = nn.RNN(1, 4, 1, False, True)

  def set_coefficients(self, coeffs):
    self.rnn.weight_ih_l0.data[:,:] = torch.tensor(coeffs[:4]).view(-1,1)
    self.rnn.weight_hh_l0.data[:,:] = 0.0
    self.rnn.weight_hh_l0.data[0,1] = 1.0
    self.rnn.weight_hh_l0.data[1,2] = 1.0
    self.rnn.weight_hh_l0.data[2,3] = 1.0
    self.rnn.weight_hh_l0.data[:3,0] = -1.0 * torch.tensor(coeffs[4:])

  def forward(self, x):
    batch_size = x.shape[0]
    x = self.eps * x.view(batch_size, -1, 1)
    x, _ = self.rnn.forward(x)
    x = (1.0/self.eps) * x[:,:,0]
    return x
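The "crush, then rescale" idea can be checked numerically without the RNN. The sketch below is only an illustration: the value of eps and the input range are arbitrary choices, and it uses float64 to keep rounding out of the picture. Since tanh(u) = u - u^3/3 + ..., the rescaled output tanh(eps * x) / eps should differ from x by roughly eps^2 * |x|^3 / 3.

```python
import torch

eps = 1e-3  # illustrative value, much larger than the 1e-9 default above
x = torch.linspace(-1.0, 1.0, 101, dtype=torch.float64)

# Crush the input into the near-linear region of tanh, then undo the scaling.
y = torch.tanh(eps * x) / eps

max_err = (y - x).abs().max().item()
approx = eps ** 2 / 3  # leading-order error estimate at |x| = 1
print(f"max error {max_err:.3e}, leading-order estimate {approx:.3e}")
```

A smaller eps shrinks this approximation error, but in float32 it also pushes the hidden state toward the limits of floating-point resolution, so the default of 1e-9 is probably best treated as a tunable rather than a safe constant.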

mthrok commented 4 years ago

Hi @FBMachine

Can you provide a snippet that demonstrates lfilter not supporting backward?

I was trying to reproduce the issue, but putting lfilter into an nn.Module and backpropagating through it seems to work fine. I might be doing something wrong, though.

Or are you saying that you would like to compute gradients for the lfilter coefficients?

import torch
import torchaudio.functional as F

class Net(torch.nn.Module):
    def __init__(self, a, b):
        super().__init__()
        self.a = torch.nn.Parameter(a, requires_grad=True)
        self.b = torch.nn.Parameter(b, requires_grad=True)

    def forward(self, x):
        return F.lfilter(x, self.a, self.b)

def test(device, dtype):
    net = Net(
        a=torch.tensor([0., 0., 0., 1.]),
        b=torch.tensor([1., 0., 0., 0.]),
    ).to(device=device, dtype=dtype)
    x = torch.rand(2, 8000, dtype=dtype, device=device, requires_grad=True)
    y = net(x)
    net.zero_grad()
    y.backward(torch.randn_like(y))
    print('a_grad:', net.a.grad)
    print('b_grad:', net.b.grad)
    print('x_grad:', x.grad)

for device in ['cpu', 'cuda']:
    for dtype in [torch.float32, torch.float64]:
        print(f'Running {device}, {dtype}')
        test(device, dtype)

$ python foo.py
Running cpu, torch.float32
a_grad: None
b_grad: None
x_grad: tensor([[ 2.2886,  0.6447, -0.0231,  ..., -2.0171,  0.3783,  3.4622],
        [ 0.0278, -0.1908, -0.8077,  ...,  0.9406, -0.0560, -0.6732]])
Running cpu, torch.float64
a_grad: None
b_grad: None
x_grad: tensor([[ 0.3589, -0.4542,  0.2553,  ...,  0.0147, -1.3429,  0.7961],
        [-1.2413,  1.7650, -0.3808,  ...,  1.8582, -1.2257, -0.2102]],
       dtype=torch.float64)
Running cuda, torch.float32
a_grad: None
b_grad: None
x_grad: tensor([[-0.8601,  1.1020,  0.7039,  ...,  0.7219, -0.0040, -1.4189],
        [ 1.6594, -0.5011,  1.3873,  ...,  1.1267,  0.8386,  0.9974]],
       device='cuda:0')
Running cuda, torch.float64
a_grad: None
b_grad: None
x_grad: tensor([[-0.5385, -1.4356, -0.9297,  ..., -1.2368, -0.7705, -0.5666],
        [-1.3023,  0.4728, -1.9034,  ...,  0.7344, -0.2552, -1.9788]],
       device='cuda:0', dtype=torch.float64)

vincentqb commented 4 years ago

Do we know that the derivative is correct though? A check we can do is with gradcheck.
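For reference, a gradcheck call looks roughly like the sketch below. It does not test torchaudio's lfilter; one_pole is a hypothetical, deliberately naive out-of-place filter used as a stand-in so the example runs with PyTorch alone. gradcheck compares analytical gradients against finite differences and expects float64 inputs.

```python
import torch
from torch.autograd import gradcheck

def one_pole(x, a1, b0):
    """Hypothetical stand-in filter: y[n] = b0 * x[n] - a1 * y[n-1]."""
    prev = torch.zeros_like(x[..., 0])
    ys = []
    for n in range(x.shape[-1]):
        prev = b0 * x[..., n] - a1 * prev  # out-of-place, so autograd-friendly
        ys.append(prev)
    return torch.stack(ys, dim=-1)

torch.manual_seed(0)
x = torch.rand(2, 16, dtype=torch.float64, requires_grad=True)
a1 = torch.tensor(-0.5, dtype=torch.float64, requires_grad=True)
b0 = torch.tensor(0.3, dtype=torch.float64, requires_grad=True)

# Returns True on success and raises an error describing the mismatch otherwise.
ok = gradcheck(one_pole, (x, a1, b0))
print("gradcheck passed:", ok)
```

The same call pattern should apply to any filter function whose inputs and coefficients are double-precision tensors with requires_grad=True.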

turian commented 4 years ago

@mthrok Here is a small code example showing that you cannot backprop through an lfilter parameter:

import torch
import torchaudio
noise = torch.rand(16000)
fp = torch.tensor((440.0), requires_grad=True)
filtered_noise = torchaudio.functional.lowpass_biquad(noise, sample_rate=16000, cutoff_freq=fp)
dist = torch.mean(torch.abs(filtered_noise - noise))
dist.backward(retain_graph=False)

gives

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-9-c658d88b5d27> in <module>
----> 1 dist.backward(retain_graph=False)
      2 print(fp.grad)

/usr/local/lib/python3.8/site-packages/torch/tensor.py in backward(self, gradient, retain_graph, create_graph)
    183                 products. Defaults to ``False``.
    184         """
--> 185         torch.autograd.backward(self, gradient, retain_graph, create_graph)
    186 
    187     def register_hook(self, hook):

/usr/local/lib/python3.8/site-packages/torch/autograd/__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables)
    123         retain_graph = create_graph
    124 
--> 125     Variable._execution_engine.run_backward(
    126         tensors, grad_tensors, retain_graph, create_graph,
    127         allow_unreachable=True)  # allow_unreachable flag

RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

mthrok commented 4 years ago

Hi @turian

Thanks for the snippet. I confirm that I am seeing the same error.
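One plausible way to end up with "does not require grad" is that the cutoff tensor passes through Python's math module while the coefficients are computed, which turns it into a plain float and drops it from the graph. Whether the installed torchaudio version did exactly this is an assumption, not something established in this thread; the sketch below only demonstrates the mechanism, and the value 0.707 for Q is an illustrative choice.

```python
import math
import torch

sample_rate = 16000
fp = torch.tensor(440.0, requires_grad=True)
w0 = 2 * math.pi * fp / sample_rate  # still a tensor attached to the graph

# math.sin converts its argument with float(), so the result is a plain
# Python float with no link back to fp.
alpha_float = math.sin(w0) / 2 / 0.707

# torch.sin keeps the computation inside autograd.
alpha_tensor = torch.sin(w0) / 2 / 0.707
alpha_tensor.backward()

print(type(alpha_float).__name__)       # plain float, graph is lost
print(alpha_tensor.grad_fn is not None)
print(fp.grad)
```

If coefficients are built from floats like alpha_float and the input signal itself does not require grad, the filter output has no grad_fn at all, which matches the error message above.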

I updated my previous snippet to use lowpass_biquad and found that even the forward pass fails.

import torch
import torchaudio.functional as F

class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fp = torch.tensor((440.0), requires_grad=True)

    def forward(self, x):
        return F.lowpass_biquad(x, sample_rate=16000, cutoff_freq=self.fp)

def test(device, dtype):
    net = Net().to(device=device, dtype=dtype)
    x = torch.rand(2, 8000, dtype=dtype, device=device, requires_grad=True)
    y = net(x)
    net.zero_grad()
    y.backward(torch.randn_like(y))
    print('fp_grad:', net.fp.grad)
    print('x_grad:', x.grad)

for device in ['cpu', 'cuda']:
    for dtype in [torch.float32, torch.float64]:
        print(f'Running {device}, {dtype}')
        test(device, dtype)

Traceback (most recent call last):
  File "bar.py", line 27, in <module>
    test(device, dtype)
  File "bar.py", line 17, in test
    y = net(x)
  File "/home/moto/conda/envs/PY3.8-cuda101/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "bar.py", line 11, in forward
    return F.lowpass_biquad(x, sample_rate=16000, cutoff_freq=self.fp)
  File "/scratch/moto/torchaudio/torchaudio/functional.py", line 703, in lowpass_biquad
    return biquad(waveform, b0, b1, b2, a0, a1, a2)
  File "/scratch/moto/torchaudio/torchaudio/functional.py", line 636, in biquad
    output_waveform = lfilter(
  File "/scratch/moto/torchaudio/torchaudio/functional.py", line 594, in lfilter
    o0.addmv_(windowed_output_signal, a_coeffs_flipped, alpha=-1)
RuntimeError: Output 0 of UnbindBackward is a view and is being modified inplace. This view is the output of a function that returns multiple views. Such functions do not allow the output views to be modified inplace. You should replace the inplace operation by an out-of-place one.

turian commented 3 years ago

@FBMachine what is the inplace operation in the latest master?

https://github.com/pytorch/audio/blob/5e54c770b41bbdb7b228fe511b364f3f2aa96a88/torchaudio/functional/__init__.py

Can you please copy and paste the offending lines here?

turian commented 3 years ago

@FBMachine It appears nnAudio has a lowpass filter that is differentiable:

import nnAudio.utils
import torch
from torch.nn.functional import conv1d, fold
lowpass_filter = torch.tensor(nnAudio.utils.create_lowpass_filter(
                                                    band_center = 0.5,
                                                    kernelLength=256,
                                                    transitionBandwidth=0.001
                                                    )
                             )
lowpass_filter = lowpass_filter[None,None,:]
x = torch.rand(10000)[None,None,:]
y = conv1d(x,lowpass_filter,stride=1, padding=(lowpass_filter.shape[-1]-1)//2)
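The snippet above builds the kernel outside of autograd, so it is differentiable with respect to the input but not the cutoff. As a rough sketch of how a convolution-based (FIR) lowpass could also expose a gradient for the cutoff, the kernel can be built from torch operations. The helper windowed_sinc_lowpass below is hypothetical and is not nnAudio's create_lowpass_filter; it assumes a Hann-windowed sinc design with the cutoff given as a fraction of Nyquist.

```python
import math
import torch
import torch.nn.functional as F

def windowed_sinc_lowpass(cutoff, kernel_length=256):
    """Hypothetical helper: Hann-windowed sinc taps.
    cutoff is a 0-dim tensor, expressed as a fraction of Nyquist (0..1)."""
    # An even length places the taps at half-integer offsets, so n is never 0.
    n = torch.arange(kernel_length, dtype=cutoff.dtype) - (kernel_length - 1) / 2
    taps = torch.sin(math.pi * cutoff * n) / (math.pi * n)
    taps = taps * torch.hann_window(kernel_length, periodic=False, dtype=cutoff.dtype)
    return taps / taps.sum()  # normalize to unit gain at DC

torch.manual_seed(0)
cutoff = torch.tensor(0.5, requires_grad=True)
x = torch.rand(1, 1, 1000)

taps = windowed_sinc_lowpass(cutoff)
# Pad 128 samples on the left and 127 on the right so the output keeps length 1000.
y = F.conv1d(F.pad(x, (128, 127)), taps.view(1, 1, -1))

loss = (y - x).abs().mean()
loss.backward()
print("output shape:", tuple(y.shape))
print("d(loss)/d(cutoff):", cutoff.grad)
```

This is an FIR approximation, so it sidesteps the recursion in lfilter entirely rather than making an IIR biquad differentiable.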

yoyolicoris commented 3 years ago

Hi folks~ I also encountered this issue recently and want to share my solution. The approach I chose is to implement a custom autograd function for lfilter.

Here's my implementation:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchaudio.functional import lfilter as torch_lfilter

from torch.autograd import Function, gradcheck

class lfilter(Function):

    @staticmethod
    def forward(ctx, x, a, b) -> torch.Tensor:
        with torch.no_grad():
            dummy = torch.zeros_like(a)
            dummy[0] = 1

            xh = torch_lfilter(x, a, dummy, False)

            y = xh.view(-1, 1, xh.shape[-1])
            y = F.pad(y, [b.numel() - 1, 0])
            y = F.conv1d(y, b.flip(0).view(1, 1, -1)).view(*xh.shape)

        ctx.save_for_backward(x, a, b, xh)
        return y

    @staticmethod
    def backward(ctx, dy) -> (torch.Tensor, torch.Tensor, torch.Tensor):
        x, a, b, xh = ctx.saved_tensors
        with torch.no_grad():
            dxh = F.conv1d(F.pad(dy.view(-1, 1, dy.shape[-1]), [0, b.numel() - 1]),
                           b.view(1, 1, -1)).view(*dy.shape)

            dummy = torch.zeros_like(a)
            dummy[0] = 1
            dx = torch_lfilter(dxh.flip(-1), a, dummy, False).flip(-1)

            batch = x.numel() // x.shape[-1]
            db = F.conv1d(F.pad(xh.view(1, -1, xh.shape[-1]), [b.numel() - 1, 0]),
                          dy.view(-1, 1, dy.shape[-1]),
                          groups=batch).sum((0, 1)).flip(0)
            dummy[0] = -1
            dxhda = torch_lfilter(F.pad(xh, [b.numel() - 1, 0]), a, dummy, False)
            da = F.conv1d(dxhda.view(1, -1, dxhda.shape[-1]),
                          dxh.view(-1, 1, dy.shape[-1]),
                          groups=batch).sum((0, 1)).flip(0)

        return dx, da, db

The filter form I chose is Direct-Form-II. I just wrap torchaudio.functional.lfilter inside the custom function, so no extra dependency is needed.

Some comparisons between a simple for-loop approach and gradient checks: https://gist.github.com/yoyololicon/f63f601d62187562070a61377cec9bf8

It has passed the gradcheck using a simple second-order filter model, and I'm planning to do more tests on higher order model.
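To show the core trick in isolation, here is a much smaller sketch of the same pattern for a first-order all-pole filter. It is an illustrative reduction, not the implementation above: the names OnePole and _one_pole are hypothetical, and the recursion is a plain Python loop rather than torchaudio's lfilter. The gradient with respect to the input is obtained by running the same recursion backward in time (flip, filter, flip), and the gradient with respect to the coefficient reuses that result.

```python
import torch
import torch.nn.functional as F
from torch.autograd import Function, gradcheck

def _one_pole(x, a1):
    """y[n] = x[n] - a1 * y[n-1]; writes in place, so call it without autograd."""
    y = torch.empty_like(x)
    prev = torch.zeros_like(x[..., 0])
    for n in range(x.shape[-1]):
        prev = x[..., n] - a1 * prev
        y[..., n] = prev
    return y

class OnePole(Function):
    @staticmethod
    def forward(ctx, x, a1):
        with torch.no_grad():
            y = _one_pole(x, a1)
        ctx.save_for_backward(a1, y)
        return y

    @staticmethod
    def backward(ctx, dy):
        a1, y = ctx.saved_tensors
        with torch.no_grad():
            # Input gradient: the same filter applied in reversed time.
            dx = _one_pole(dy.flip(-1), a1).flip(-1)
            # Coefficient gradient: dy[n]/da1 is the filter applied to -y[n-1].
            y_prev = F.pad(y, [1, 0])[..., :-1]
            da1 = -(dx * y_prev).sum()
        return dx, da1

torch.manual_seed(0)
x = torch.rand(2, 12, dtype=torch.float64, requires_grad=True)
a1 = torch.tensor(0.5, dtype=torch.float64, requires_grad=True)
ok = gradcheck(OnePole.apply, (x, a1))
print("gradcheck passed:", ok)
```

Because the forward and backward passes both run under no_grad, the in-place writes inside the recursion never touch the autograd graph, which is what makes wrapping an in-place filter this way workable.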

yoyolicoris commented 3 years ago

@FBMachine does it meet your requirement?

vincentqb commented 3 years ago

Thanks for writing this and sharing it with the community! If TorchScriptability is not a concern, then this is a great way to bind the forward and the backward pass :) This is in fact how we (temporarily) bind the prototype RNN transducer here in torchaudio.

Such custom autograd functions (both in Python and C++) are not currently supported by TorchScript though. Using this within torchaudio directly in place of the current lfilter (which is TorchScriptable) would unfortunately be BC-breaking. In the long term, we'll need to register the backward pass with autograd. Here's a tutorial for how to do this in a TorchScriptable manner.

yoyolicoris commented 3 years ago

@vincentqb thanks, I'll take a look.