pytorch / audio

Data manipulation and transformation for audio signal processing, powered by PyTorch
https://pytorch.org/audio
BSD 2-Clause "Simplified" License

Resample kernel creation uses loops... #2414

Open xvdp opened 2 years ago

xvdp commented 2 years ago

🐛 Describe the bug

torchaudio/functional/functional.py's def _get_sinc_resample_kernel() is slower than it should be because it uses loops instead of taking advantage of torch broadcasting. A simple rewrite makes it ~9x faster.

There are other issues with resampling kernels that still cause extreme slowness when the gcd of the sample rates does not reduce them; this issue does not address those.

Below is sample code for the fix, and a paraphrase of the existing code. Besides removing the loops, a few observations:

a. using torch.float64 does nothing
b. cuda is slower than cpu in both cases, so if a cuda kernel is required it should be cast at the end of kernel building
c. in-place operations are faster if no grad is required, and there is no reason grad should be applied to these kernels

1. faster

```python
import math
import time
import torch

def r_kernel(orig_freq=96000, new_freq=41000, lowpass_filter_width=6, rolloff=0.99,
             dtype=torch.float32, device=None, resampling_method="sinc_interpolation",
             beta=14.769656459379492):
    """Loop-free kernel build, ~9x faster. If cuda is wanted, cast at the end:
    kernel building involves too many small assignments to benefit from cuda."""
    _time = time.time()

    with torch.no_grad():
        gcd = math.gcd(orig_freq, new_freq)
        orig_freq //= gcd
        new_freq //= gcd

        _base_freq = min(orig_freq, new_freq) * rolloff
        _scale = _base_freq / orig_freq
        _one = torch.tensor(1.0, dtype=dtype)
        width = math.ceil(lowpass_filter_width / _scale)
        # broadcast phases against taps: (new_freq, 1, 1) + (1, 1, K) -> (new_freq, 1, K)
        _idx = torch.arange(-width, width + orig_freq, dtype=dtype)[None, None].div(orig_freq)
        _t = torch.arange(0, -new_freq, -1, dtype=dtype)[:, None, None].div(new_freq) + _idx
        _t.mul_(_base_freq).clamp_(-lowpass_filter_width, lowpass_filter_width)

        if resampling_method == "sinc_interpolation":
            _t.mul_(torch.pi)
            _window = _t.div(2 * lowpass_filter_width).cos().pow_(2.0)
        else:  # kaiser window: I0(beta * sqrt(1 - (t/width)^2)) / I0(beta)
            _beta = torch.as_tensor(beta, dtype=dtype)
            _window = _beta.mul((1 - _t.div(lowpass_filter_width).pow(2.0)).sqrt()).i0().div(_beta.i0())
            _t.mul_(torch.pi)
        kernel = torch.where(_t == 0, _one, _t.sin().div(_t))
        kernel.mul_(_window)
        kernel = kernel.mul_(_scale).to(device=device)  # assign: .to() is not in-place

        print(f"elapsed {1e3 * (time.time() - _time):.3f} ms")
        return kernel, width
```
2. paraphrase of existing code

```python
import math
import time
import torch

def old_kernel(orig_freq=96000, new_freq=41000, lowpass_filter_width=6, rolloff=0.99,
               dtype=torch.float32, device=None):
    _time = time.time()
    gcd = math.gcd(orig_freq, new_freq)
    orig_freq //= gcd
    new_freq //= gcd
    base_freq = min(orig_freq, new_freq) * rolloff
    width = math.ceil(lowpass_filter_width * orig_freq / base_freq)
    idx = torch.arange(-width, width + orig_freq, device=device, dtype=dtype)
    kernels = []
    for i in range(new_freq):  # one python-level pass per output phase
        t = ((-i / new_freq + idx / orig_freq) * base_freq).clamp(-lowpass_filter_width, lowpass_filter_width) * torch.pi
        kernel = torch.where(t == 0, torch.tensor(1.0).to(t), torch.sin(t) / t) * torch.cos(t / lowpass_filter_width / 2) ** 2
        kernels.append(kernel)
    scale = base_freq / orig_freq
    kernels = torch.stack(kernels).view(new_freq, 1, -1).mul_(scale)
    if dtype is None:
        kernels = kernels.to(dtype=torch.float32)
    if device == "cuda":
        torch.cuda.synchronize()
    print(f"elapsed {1e3 * (time.time() - _time):.3f} ms")
    return kernels, width
```
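
A quick way to compare the two, assuming both functions above are in scope (each prints its own elapsed time; exact numbers are machine dependent, ~9x on mine):

```python
# build the same kernel both ways and compare widths, shapes, and values
for freqs in [(96000, 41000), (48000, 44100)]:
    k_new, w_new = r_kernel(*freqs)
    k_old, w_old = old_kernel(*freqs)
    print(freqs, tuple(k_new.shape), w_new == w_old, torch.allclose(k_new, k_old))
```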


### Versions

Collecting environment information...
PyTorch version: 1.13.0.dev20220526
Is debug build: False
CUDA used to build PyTorch: 11.3
ROCM used to build PyTorch: N/A

OS: Ubuntu 18.04.5 LTS (x86_64)
GCC version: (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0
Clang version: Could not collect
CMake version: version 3.22.4
Libc version: glibc-2.27

Python version: 3.10.4 | packaged by conda-forge | (main, Mar 24 2022, 17:39:04) [GCC 10.3.0] (64-bit runtime)
Python platform: Linux-5.4.0-113-generic-x86_64-with-glibc2.27
Is CUDA available: True
CUDA runtime version: 11.6.55
GPU models and configuration: GPU 0: NVIDIA TITAN RTX
Nvidia driver version: 510.39.01
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.0.5
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.0.5
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.0.5
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.0.5
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.0.5
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.0.5
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.0.5
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] numpy==1.22.3
[pip3] torch==1.13.0.dev20220526
[pip3] torchaudio==0.12.0a0+b7624c6
[pip3] torchvision==0.14.0.dev20220526
[conda] blas                      1.0                         mkl  
[conda] cudatoolkit               11.3.1               h2bc3f7f_2  
[conda] mkl                       2021.4.0           h06a4308_640  
[conda] mkl-service               2.4.0           py310ha2c4b55_0    conda-forge
[conda] mkl_fft                   1.3.1           py310h2b4bcf5_1    conda-forge
[conda] mkl_random                1.2.2           py310h00e6091_0  
[conda] numpy                     1.22.3          py310hfa59a62_0  
[conda] numpy-base                1.22.3          py310h9585f30_0  
[conda] pytorch                   1.13.0.dev20220526 py3.10_cuda11.3_cudnn8.3.2_0    pytorch-nightly
[conda] pytorch-mutex             1.0                        cuda    pytorch-nightly
[conda] torchaudio                0.12.0a0+b7624c6           dev_0    <develop>
[conda] torchvision               0.14.0.dev20220526     py310_cu113    pytorch-nightly
nateanl commented 2 years ago

Hi @xvdp, thanks for sharing. The improvement looks good to me. Regarding the in-place operations, @mthrok do you think grad is a potential issue? I didn't find a gradient check unit test for resample.

xvdp commented 2 years ago

@nateanl I did do a PR but didn't add any tests. a. I figured that any existing checks would cover this. b. Locally I compared both versions of the resample to ensure torch.allclose() holds under various dtypes, devices, and from/to sampling rates. I'm not sure how to add those tests here. One difference from the sketch code up there is that in the PR I didn't set torch.no_grad(), because that seems like it should be handled equally for all operators.

As I also mention, resampling kernel building can be woefully inefficient if, for instance, one is resampling to a prime number. For most augmentation cases that precision is meaningless, so in my local projects I instead use the closest fraction to the rate change, capping the max kernel size. Alternatively, for y'all, porting this code to C may be helpful; python doesn't play nice with this kind of operation. But for now, this works. A standard-library sketch of the capping idea (just the concept, not what my PR does):
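
```python
from fractions import Fraction

# cap the reduced ratio's denominator instead of insisting on the exact gcd;
# 96000 -> 44102 only reduces to 22051/48000, which would mean a huge kernel
print(Fraction(44102, 96000).limit_denominator(320))
# -> 147/320, about 2e-5 away from the exact ratio
```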

carolineechen commented 2 years ago

@xvdp thanks for the suggestions and PR, broadcasting sounds like a good idea! I did try running your code here/on your PR and did not get the same results as the existing implementation though. Seems like CircleCI was broken when you created your PR, so let me leave some comments there as well.

For CUDA device, in prior benchmarking experiments we have seen improvements using CUDA for certain resampling rates, and moving a tensor from CPU to CUDA would take time as well, so we would need to get more comprehensive numbers on this before adding the change.

@nateanl

> I didn't find a gradient check unit test for resample

gradcheck for resample is in the transforms autograd test (and not functional), and it is passing both when torch.no_grad is added and not.
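
Roughly, that test exercises something like this (a sketch of the idea, not the actual test code):

```python
import torch
from torch.autograd import gradcheck
from torchaudio.transforms import Resample

# gradcheck needs float64, so cast the transform (and its cached kernel) too
transform = Resample(orig_freq=4, new_freq=3).to(torch.float64)
waveform = torch.randn(1, 64, dtype=torch.float64, requires_grad=True)
assert gradcheck(transform, (waveform,))
```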

cc @adefossez

mthrok commented 2 years ago

> Hi @xvdp, thanks for sharing. The improvement looks good to me. Regarding the in-place operations, @mthrok do you think grad is a potential issue? I didn't find a gradient check unit test for resample.

Autograd should be fine as long as the existing tests pass. I agree with what @xvdp described about the kernel not being differentiable.

> c. in-place operations are faster if no grad is required, and there is no reason grad should be applied to these kernels

adefossez commented 2 years ago

Thanks for the improvements :) Regarding grad computation, I personally use resampling on neural network outputs, and in that case you must backpropagate the gradient to the network. For the issue of slowness when the gcd is large, I don't see a fix using this algorithm. The resampy one is much better in that case, but also harder to parallelize. Concretely, something like this (a dummy tensor standing in for a network output):
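
```python
import torch
from torchaudio.transforms import Resample

# the gradient must flow through the resampling back to the network output,
# even though the kernel itself is a constant
net_out = torch.randn(1, 16000, requires_grad=True)
Resample(16000, 8000)(net_out).sum().backward()
print(net_out.grad.shape)  # torch.Size([1, 16000])
```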

xvdp commented 2 years ago

@adefossez larger kernels are an issue if you need exact resampling, but if you want to randomize it, one doesn't need the exact gcd... There's more than one way to compute an approximate GCD with kernels below some reasonable max. For example:

```python
import math
import torch

def approx_kernel(from_freq, to_freq, max_kernel=320, verbose_error=True):
    """ (from_freq, to_freq) -> approx(from_freq, to_freq)
    """
    _min, _max = sorted((from_freq, to_freq))
    _upres = to_freq > from_freq
    ratio = _min / _max
    # try every denominator up to max_kernel; keep the one whose scaled
    # numerator lands closest to an integer
    denominator = torch.arange(2, max_kernel + 1)
    numerator = denominator * ratio
    closest = torch.fmod(numerator, 1)
    argmin = (closest - torch.round(closest)).abs().argmin()
    _min = torch.round(numerator[argmin]).long().item()
    _max = denominator[argmin].long().item()
    if verbose_error:
        _ratio = _min / _max
        _gcd = math.gcd(from_freq, to_freq)
        _ex_from = from_freq // _gcd
        _ex_to = to_freq // _gcd
        _ap_from, _ap_to = [(_max, _min), (_min, _max)][_upres]
        print(f"exact ratio  {ratio}; exact kernel {_ex_from, _ex_to}")
        print(f"approx ratio {_ratio}; approx kernel {_ap_from, _ap_to}")
        print(f"error {abs(ratio - _ratio)}")
    return [(_max, _min), (_min, _max)][_upres]

k_from, k_to = approx_kernel(96000, 44102)
# exact ratio  0.45939583333333334; exact kernel (48000, 22051)
# approx ratio 0.4594594594594595; approx kernel (37, 17)
# error 6.362612612614837e-05
```

Transforms like pitch shift, or others, may require resampling, are not transparent to kernel size, and do not always require an exact shift. If you step through +6 to -6 semitones some kernels are ridiculous, but if you initialize them just once it's no problem. So in my augmentation set, every transform that requires resampling can be passed a superclass Resampling that caches all the kernels you'll probably need for a project. Incidentally, I use a more primitive version of the approximate gcd there.

```python
from typing import Optional
import math
import torch
from torch import nn
from torchaudio.transforms import Resample

class Resampling(nn.Module):
    """ caches Resample transforms to avoid repeat init
    implicitly assigns device on forward
    Args
        max_kernel      int [320] default rationale: 320 = 96000 // gcd(96000, 44100)
                        None: exact # can be very costly if gcd == 1
    Example caching approximate kernel of max size 30
        R = Resampling(max_kernel=30)
        x, sr = torchaudio.load(<fname>)
        y = R(x, sr, sr//3)
    """
    def __init__(self, max_kernel: Optional[int] = 320) -> None:
        super().__init__()
        self.res = {}
        self.max_kernel = max_kernel

    def forward(self, x: torch.Tensor, fromsr: int, tosr: int, **kwargs) -> torch.Tensor:
        """ returns resampled tensor, caches Resample kernel
        Args:
            x           Tensor      (N,C,L) | (C,L) | (L,)
            fromsr      int         source sample rate
            tosr        int         destination sample rate
        kwargs:
            max_kernel  int, None   if assigned, temporarily overrides self.max_kernel
        kwargs from torchaudio.transforms.Resample
            lowpass_filter_width
            rolloff
            beta
            resampling_method
        """
        max_kernel = self.max_kernel if 'max_kernel' not in kwargs else kwargs['max_kernel']
        fromsr, tosr = self._get_gcd(fromsr, tosr, max_kernel)
        kwargs = {key: kwargs[key] for key in ['lowpass_filter_width', 'rolloff', 'beta', 'resampling_method']
                  if key in kwargs}
        if (fromsr, tosr) not in self.res:
            self.res[(fromsr, tosr)] = Resample(fromsr, tosr, **kwargs)
        return self.res[(fromsr, tosr)].to(device=x.device)(x)

    def _get_gcd(self, a: int, b: int, max_kernel: Optional[int] = 320, _step: int = 1) -> tuple:
        """ approximate gcd to max size
        Args:
            a           int
            b           int
            max_kernel  int | None
        """
        _gcd = math.gcd(a, b)
        a = int(a // _gcd)
        b = int(b // _gcd)
        if max_kernel is not None and max(a, b) > max_kernel:  # was `(a or b) > max_kernel`, a bug
            # nudge the odd rate by +/-1 (direction alternates with _step) and retry
            if b % 2:
                b = b + 2 * (_step % 2) - 1
            elif a % 2:
                a = a + 2 * (_step % 2) - 1
            else:
                b = b + 2 * (_step % 2) - 1
            return self._get_gcd(a, b, max_kernel=max_kernel, _step=_step + 1)
        return a, b
```
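
Usage sketch (dummy data; these rates reduce to a small kernel):

```python
R = Resampling(max_kernel=320)
x = torch.randn(2, 48000)                    # (C, L) dummy waveform
y = R(x, 48000, 44100)                       # builds and caches Resample(160, 147)
z = R(torch.randn(2, 24000), 48000, 44100)   # cache hit, kernel reused
print(y.shape, z.shape)                      # torch.Size([2, 44100]) torch.Size([2, 22050])
```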

I'm not sure what you mean about the gradients. The current resample is differentiable only insofar as its forward(), not the kernel creation. I suppose one could write a project to learn kernels, but that's a different animal.

xvdp commented 2 years ago

Also, sure, I understand that the version using arrays instead of a loop to create the kernel is NOT the solution for creating lots of huge kernels. But it is 9x faster than what's in torchaudio now. Broadcasting is the bulk of the speed improvement, not the in-place operations. To spell it out, these are the shapes that replace the loop (reduced rates for 96000 -> 44100):
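
```python
import torch

# reduced rates for 96000 -> 44100: orig_freq=320, new_freq=147, width=14
orig_freq, new_freq, width = 320, 147, 14
idx = torch.arange(-width, width + orig_freq)[None, None] / orig_freq  # (1, 1, 348)
t = torch.arange(0, -new_freq, -1)[:, None, None] / new_freq + idx     # broadcast
print(t.shape)  # torch.Size([147, 1, 348]) -- every phase in one tensor op
```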

xvdp commented 2 years ago

@adefossez you do have a point, quality does matter, but both audio resampling and pitch shifting degrade in quality quite quickly with range. I will look at float64, but I don't know if it has any effect there; it seems to me the result gets quantized anyway. In a wider context, other than band-limited interpolation and low-pass filtering, what do the various mastering packages do? e.g. http://src.infinitewave.ca/ It would be cool to have mastering-quality transforms.

adefossez commented 2 years ago

Hey @xvdp sorry for the confusion with the gradient, I hadn't looked at the PR in detail; indeed no gradient is needed for the kernel creation.

Regarding the GCD approximation, in some cases it could indeed be really nice to do it for the end user, but it is also a bit of a dangerous change and not backward compatible. It would also require a more fine-grained analysis of what rounding levels are acceptable for all possible sample rates. A rough sketch of that kind of sweep, using the stdlib Fraction (illustrative only):
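
```python
from fractions import Fraction

# how much ratio error does a denominator cap introduce for common rate pairs?
for orig, new in [(96000, 44100), (48000, 44100), (44100, 16000), (96000, 44102)]:
    exact = Fraction(new, orig)
    capped = exact.limit_denominator(320)
    print(f"{orig}->{new}: exact {exact}, capped {capped}, "
          f"error {float(abs(exact - capped)):.1e}")
```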

Another potential improvement is one I noted in a comment I left: https://github.com/pytorch/audio/blob/main/torchaudio/functional/functional.py#L1443. I think one way to limit this is to use multiple convolutions with different kernel sizes, but that requires testing to decide at what point splitting the kernel into multiple convolutions becomes beneficial. Although this would not be as efficient as rounding the sample rates.