pytorch / audio

Data manipulation and transformation for audio signal processing, powered by PyTorch
https://pytorch.org/audio
BSD 2-Clause "Simplified" License

torchaudio.functional.convolve with torch.complex32 and interleaved fake complex torch.bfloat16 #3739

Open pfeatherstone opened 5 months ago

pfeatherstone commented 5 months ago

🚀 The feature

I have a use case for convolving two tensors of dtype torch.float16 or torch.bfloat16 that hold interleaved complex data (a trailing dimension of size 2 for the real and imaginary parts).

Motivation, pitch

I have a couple of DSP functions that are applied inside a training loop running in float16 or bfloat16.

Alternatives

I have the following code, which works:

import torch
import torch.nn.functional as F
from einops import rearrange

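def convolve_cplx(
    x: torch.Tensor,  # [B, T, 2] interleaved (real, imag)
    y: torch.Tensor,  # [B, H, 2] interleaved (real, imag)
) -> torch.Tensor:
    """Same as torch.view_as_real(torchaudio.functional.convolve(torch.view_as_complex(x),
    torch.view_as_complex(y), mode='same')), but also works with torch.float16 and torch.bfloat16."""
    B, T, m = x.shape[0], x.shape[1], y.shape[1] // 2
    y       = y.flip(-2)                    # conv1d is cross-correlation, so flip the kernel in time
    yr, yi  = y[..., 0], y[..., 1]          # [B, H] each
    x       = rearrange(x, 'b t re -> 1 (b re) t')
    # Per batch element: real out = [yr, -yi] . [xr, xi], imag out = [yi, yr] . [xr, xi]
    w       = rearrange([yr, -yi, yi, yr], '(o re) b h -> (b o) re h', re=2)  # [2B, 2, H]
    y       = F.conv1d(x, w, padding=m, groups=B)  # one group per batch element
    y       = rearrange(y, '1 (b re) t -> b t re', re=2)
    return y[:, :T]                         # drop the extra trailing sample produced when H is even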
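The [yr, -yi, yi, yr] weight block in convolve_cplx is the complex product rule (a + ib)(c + id) = (ac - bd) + i(ad + bc) applied per tap: the real output channel correlates [xr, xi] with [yr, -yi] and the imaginary channel with [yi, yr]. As a dependency-free sanity check of that identity (not the PyTorch implementation), the sketch below does 'same' convolution on interleaved [re, im] lists using only real arithmetic and compares it with Python's built-in complex type. The helper names conv_full, conv_same and convolve_interleaved are hypothetical, and the centered slice starting at (H - 1) // 2 is an assumed definition of 'same' mode that only matters for even kernel lengths.

```python
import random

def conv_full(x, h):
    """Full linear convolution of two sequences (real or complex entries)."""
    out = [0] * (len(x) + len(h) - 1)
    for i, xv in enumerate(x):
        for j, hv in enumerate(h):
            out[i + j] += xv * hv
    return out

def conv_same(x, h):
    """Centered slice of the full convolution, same length as x (assumed 'same' convention)."""
    full = conv_full(x, h)
    start = (len(full) - len(x)) // 2  # equals (len(h) - 1) // 2
    return full[start:start + len(x)]

def convolve_interleaved(x, h):
    """'same' convolution of interleaved [re, im] data using only real arithmetic.

    real = xr*hr - xi*hi and imag = xr*hi + xi*hr, where each '*' is a real
    convolution; this is the same split encoded by the weights in convolve_cplx.
    """
    xr, xi = [p[0] for p in x], [p[1] for p in x]
    hr, hi = [p[0] for p in h], [p[1] for p in h]
    real = [a - b for a, b in zip(conv_same(xr, hr), conv_same(xi, hi))]
    imag = [a + b for a, b in zip(conv_same(xr, hi), conv_same(xi, hr))]
    return [[r, i] for r, i in zip(real, imag)]

# Compare against a direct complex-valued convolution (odd kernel length).
random.seed(0)
x = [[random.gauss(0, 1), random.gauss(0, 1)] for _ in range(16)]
h = [[random.gauss(0, 1), random.gauss(0, 1)] for _ in range(5)]
ref = conv_same([complex(*p) for p in x], [complex(*p) for p in h])
out = convolve_interleaved(x, h)
max_err = max(abs(complex(*o) - r) for o, r in zip(out, ref))
print(f"max abs error: {max_err:.2e}")
```

Because every operation in convolve_interleaved is a real multiply-add, the same decomposition maps onto real-dtype kernels such as conv1d, which is why the interleaved approach sidesteps the need for a complex dtype.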

I have a unit test:

import unittest
import time
import torchaudio.functional as aF

class TestDSP(unittest.TestCase):
    def test_conv_cplx(self):
        perf0 = 0
        perf1 = 0
        first = True
        for T in [2000, 1000, 500, 200]:
            for H in [21, 5, 9, 161]:
                x = torch.randn(32, T, 2)
                h = torch.randn(32, H, 2)
                t0 = time.perf_counter()
                y1 = torch.view_as_real(aF.convolve(torch.view_as_complex(x), torch.view_as_complex(h), mode='same'))
                t1 = time.perf_counter()
                y2 = convolve_cplx(x, h)
                t2 = time.perf_counter()
                self.assertTrue(torch.allclose(y1,y2,rtol=1e-4,atol=1e-4))
                if first:
                    first = False  # skip the first (warm-up) iteration in the timings
                else:
                    perf0 += (t1-t0)
                    perf1 += (t2-t1)
        print("torchaudio time {:.5f} custom {:.5f} improvement {:.2f}%".format(perf0, perf1, 100*(perf0-perf1)/perf0))

if __name__ == '__main__':
    unittest.main()

On my machine, this test shows that my "dummy" method takes about 47% less total time than torchaudio (the "improvement" figure printed by the test).
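The "improvement" printed by the test is 100 * (perf0 - perf1) / perf0, i.e. the percentage reduction in total time rather than a speed ratio. A small sketch for converting such a figure into a speedup factor, using the 47% value reported above (the helper name speedup_from_reduction is hypothetical):

```python
def speedup_from_reduction(reduction_pct: float) -> float:
    """Convert a percentage reduction in run time into a speedup ratio.

    reduction_pct = 100 * (t_old - t_new) / t_old, as printed by the test,
    so t_new = t_old * (1 - reduction_pct / 100) and speedup = t_old / t_new.
    """
    if not 0 <= reduction_pct < 100:
        raise ValueError("reduction must be in [0, 100)")
    return 1.0 / (1.0 - reduction_pct / 100.0)

print(f"{speedup_from_reduction(47):.2f}x")  # → 1.89x
```

So a 47% reduction in time corresponds to the custom method running roughly 1.9 times as fast on that benchmark.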

Additional context

No response

pfeatherstone commented 5 months ago

Now, if you try the same comparison with torch.float16 or torch.bfloat16 inputs, the torchaudio method doesn't work.