🚀 The feature

I have a use case for convolving two tensors with dtype torch.float16 or torch.bfloat16 that contain interleaved complex data.
Motivation, pitch
I have a couple of DSP functions that are applied in a training loop running in float16 or bfloat16.
Alternatives
I have the following code, which works:

```python
import torch
import torch.nn.functional as F
from einops import rearrange

def convolve_cplx(
    x: torch.Tensor,  # [B, T, 2] interleaved (real, imag)
    y: torch.Tensor,  # [B, H, 2] interleaved (real, imag)
) -> torch.Tensor:
    """Same as torch.view_as_real(torchaudio.functional.convolve(
    torch.view_as_complex(x), torch.view_as_complex(y), mode='same')),
    but works with torch.float16 and torch.bfloat16."""
    B, m = x.shape[0], int(y.shape[1] // 2)
    y = y.flip(-2)  # conv1d is cross-correlation, so flip the kernel along H
    yr, yi = y[..., 0], y[..., 1]
    # fold the batch into channels so one grouped conv handles all B pairs
    x = rearrange(x, 'b t re -> 1 (b re) t')
    # per group: 2-in/2-out weights encoding the complex multiply [[yr, -yi], [yi, yr]]
    w = rearrange([yr, -yi, yi, yr], '(o re) b h -> (b o) re h', re=2)
    y = F.conv1d(x, w, padding=m, groups=B)
    y = rearrange(y, '1 (b re) t -> b t re', re=2)
    return y
```
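For intuition, the weight layout `[yr, -yi, yi, yr]` packs, per filter tap, the 2×2 real-matrix form of complex multiplication, (a+bi)(c+di) = (ac−bd) + (ad+bc)i. A stdlib-only sketch of that identity for a single tap (the function name here is illustrative, not part of the code above):

```python
def cplx_mul_as_matrix(x_re, x_im, y_re, y_im):
    """Complex multiply written as the 2x2 real matrix product
    [[y_re, -y_im], [y_im, y_re]] @ [x_re, x_im]^T -- the same row
    pairing the grouped-conv weight tensor encodes per tap."""
    out_re = y_re * x_re - y_im * x_im
    out_im = y_im * x_re + y_re * x_im
    return out_re, out_im

# matches Python's built-in complex arithmetic: (1+2j)*(3+4j) == -5+10j
print(cplx_mul_as_matrix(1.0, 2.0, 3.0, 4.0))  # -> (-5.0, 10.0)
```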
I have a unit test:

```python
import time
import unittest

import torch
import torchaudio.functional as aF

class TestDSP(unittest.TestCase):
    def test_conv_cplx(self):
        perf0 = 0.0
        perf1 = 0.0
        first = True  # skip the first iteration as warm-up
        for T in [2000, 1000, 500, 200]:
            for H in [21, 5, 9, 161]:
                x = torch.randn(32, T, 2)
                h = torch.randn(32, H, 2)
                t0 = time.perf_counter()
                y1 = torch.view_as_real(aF.convolve(
                    torch.view_as_complex(x), torch.view_as_complex(h), mode='same'))
                t1 = time.perf_counter()
                y2 = convolve_cplx(x, h)
                t2 = time.perf_counter()
                self.assertTrue(torch.allclose(y1, y2, rtol=1e-4, atol=1e-4))
                if first:
                    first = False
                else:
                    perf0 += (t1 - t0)
                    perf1 += (t2 - t1)
        print("torchaudio time {:.5f} custom {:.5f} improvement {:.2f}%".format(
            perf0, perf1, 100 * (perf0 - perf1) / perf0))

if __name__ == '__main__':
    unittest.main()
```
On my machine, this test shows that my "dummy" method is about 47% faster than torchaudio.functional.convolve.
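For anyone wanting a framework-free reference for what both paths compute, 'same'-mode complex convolution can be sketched in a few lines of stdlib Python (a semantic reference only, nowhere near the fast path; the midpoint cropping below matches mode='same' for the odd kernel lengths used in the test):

```python
def convolve_same(x, h):
    """Full linear convolution of two complex sequences, cropped to
    len(x) around the kernel midpoint ('same' mode)."""
    T, H = len(x), len(h)
    full = [0j] * (T + H - 1)  # full convolution accumulator
    for i, xi in enumerate(x):
        for j, hj in enumerate(h):
            full[i + j] += xi * hj
    start = (H - 1) // 2  # centre crop back to the input length
    return full[start:start + T]

# a centred unit-imaginary tap rotates each sample by 90 degrees
print(convolve_same([1 + 2j, 0j, 0j], [0j, 1j, 0j]))  # -> [(-2+1j), 0j, 0j]
```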
Additional context
No response