pytorch / audio

Data manipulation and transformation for audio signal processing, powered by PyTorch
https://pytorch.org/audio
BSD 2-Clause "Simplified" License

Fix view size error when backpropagating through lfilter #3794

Closed. yoyolicoris closed this pull request 5 months ago.

yoyolicoris commented 6 months ago

This commit fixes the view size error raised when the gradient tensor propagated back to the filter is not contiguous.
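For readers unfamiliar with this error: `.view()` never copies data, so it only succeeds when the requested shape is compatible with the tensor's existing strides. The sketch below uses plain PyTorch, independent of `lfilter`, to trigger the same message on a transposed tensor and to show the two usual remedies, `.reshape()` and `.contiguous().view()`. It illustrates the general failure mode only and is not taken from this PR's diff.

```python
import torch

# A transpose swaps strides without moving data, so the result is not
# laid out contiguously in memory.
x = torch.arange(12.0).reshape(3, 4)
xt = x.t()  # shape (4, 3), strides (1, 4)
print(xt.is_contiguous())  # False

# .view() never copies, so flattening xt is impossible without a copy.
view_failed = False
try:
    xt.view(12)
except RuntimeError as err:
    view_failed = True
    print(err)  # "view size is not compatible with input tensor's size and stride ..."

# .reshape() falls back to a copy when a view is impossible;
# .contiguous() makes that copy explicit before calling .view().
flat_a = xt.reshape(12)
flat_b = xt.contiguous().view(12)
print(torch.equal(flat_a, flat_b))  # True
```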

The bug can be reproduced by the following script (provided by @forgi86) using the latest version of torchaudio.

```python
import torch
import torchaudio.functional

order = 5
in_channels = 10
seq_len = 1000
batch_size = 32

# Per-channel filter coefficients; both tensors are trainable.
b_coeff = torch.randn((in_channels, order)) * 0.05
b_coeff = b_coeff.requires_grad_(True)
a_coeff = torch.randn((in_channels, order - 1)) * 0.05
a_coeff = a_coeff.requires_grad_(True)

# Prepend a0 = 1 so the denominator has the same length as b_coeff.
ones = torch.ones_like(a_coeff[..., :1])
a_coeffs_with_ones = torch.cat((ones, a_coeff), dim=-1)

u = torch.randn((batch_size, in_channels, seq_len))

# batching=True filters channel i of u with the i-th row of coefficients.
y = torchaudio.functional.lfilter(u, a_coeffs_with_ones, b_coeff, clamp=False, batching=True)
# The transpose makes the gradient flowing back into lfilter non-contiguous.
y = y.transpose(-2, -1).reshape(batch_size, in_channels // 2, 2, seq_len)

loss = torch.sum(y**2)
loss.backward()
```

```
RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
```
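A possible user-side workaround, separate from the fix in this PR, is to force the gradient into a contiguous layout before it reaches `lfilter`. `ContiguousGrad` below is a hypothetical helper written for this illustration (not a PyTorch or torchaudio API), built on the standard `torch.autograd.Function` mechanism. It assumes, as the description above states, that the failure comes from a non-contiguous incoming gradient; it has not been run against torchaudio here.

```python
import torch


class ContiguousGrad(torch.autograd.Function):
    """Hypothetical helper: identity in the forward pass, but hands a
    contiguous gradient to whatever op produced its input."""

    @staticmethod
    def forward(ctx, x):
        # Return a copy so autograd treats the output as a fresh tensor.
        return x.clone()

    @staticmethod
    def backward(ctx, grad_output):
        # Whatever layout the downstream ops produced, make it contiguous.
        return grad_output.contiguous()


# Hypothetical usage with the repro script (untested against torchaudio):
# y = torchaudio.functional.lfilter(u, a_coeffs_with_ones, b_coeff, clamp=False, batching=True)
# y = ContiguousGrad.apply(y)
# y = y.transpose(-2, -1).reshape(batch_size, in_channels // 2, 2, seq_len)

# Self-contained check: the transpose's backward yields a non-contiguous
# gradient, but the gradient that reaches `x` is contiguous.
x = torch.randn(4, 3, requires_grad=True)
seen_contiguous = []
x.register_hook(lambda g: seen_contiguous.append(g.is_contiguous()))

z = ContiguousGrad.apply(x).t().reshape(-1)
(z ** 2).sum().backward()
print(seen_contiguous)  # [True]
```

The extra `clone()` costs one copy of the activation; placing the helper directly after `lfilter` keeps the rest of the model unchanged.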
pytorch-bot[bot] commented 6 months ago

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/audio/3794

Note: Links to docs will display an error until the docs builds have been completed.

❌ 4 New Failures

As of commit 8745e8c5644d7a0867ae3ff37980ce0a5fc37160 with merge base b829e936f7cc61b48149f5f957a451a38bf2a178:

NEW FAILURES - The following jobs have failed:

* [Unit-tests on Linux GPU / tests (3.8, 11.8) / linux-job](https://hud.pytorch.org/pr/pytorch/audio/3794#26598430899) ([gh](https://github.com/pytorch/audio/actions/runs/9645040321/job/26598430899)) `##[error]The operation was canceled.`
* [Unit-tests on Linux GPU / tests (3.9, 11.8) / linux-job](https://hud.pytorch.org/pr/pytorch/audio/3794#26598431335) ([gh](https://github.com/pytorch/audio/actions/runs/9645040321/job/26598431335)) `##[error]The operation was canceled.`
* [Unittests on Windows CPU / unittests-windows-cpu / windows-job](https://hud.pytorch.org/pr/pytorch/audio/3794#26598429892) ([gh](https://github.com/pytorch/audio/actions/runs/9645040317/job/26598429892)) `The process 'C:\Program Files\Git\cmd\git.exe' failed with exit code 128`
* [Unittests on Windows GPU / unittests-windows-gpu / windows-job](https://hud.pytorch.org/pr/pytorch/audio/3794#26598429914) ([gh](https://github.com/pytorch/audio/actions/runs/9645040351/job/26598429914)) `The process 'C:\Program Files\Git\cmd\git.exe' failed with exit code 128`

This comment was automatically generated by Dr. CI and updates every 15 minutes.