v0lta / PyTorch-Wavelet-Toolbox

Differentiable fast wavelet transforms in PyTorch with GPU support.
https://pytorch-wavelet-toolbox.readthedocs.io
European Union Public License 1.2

SWT: Make circular padding wrap more than once if needed #84

Closed. NiclasPi closed this pull request 3 months ago.

NiclasPi commented 5 months ago

Some combinations of wavelet and number of decomposition levels cause the following RuntimeError: "Padding value causes wrapping around more than once". In circular mode, torch.nn.functional.pad does not wrap around more than once, so it raises this error whenever a padding value is greater than the input length. In that case we need to wrap around manually, as many times as needed. See the corresponding feature request in PyTorch: https://github.com/pytorch/pytorch/issues/57911.
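To see why only some wavelet/level combinations fail: in an undecimated (à trous) transform the analysis filter is dilated by 2**j at level j, so the required padding roughly doubles at every level while the signal length stays fixed. The sketch below assumes even-length filters and one particular left/right split of the padding; the helper name swt_circular_padding is hypothetical and the toolbox's exact split may differ.

```python
def swt_circular_padding(filt_len: int, level: int) -> list[tuple[int, int]]:
    """Per-level (left, right) circular padding for an undecimated transform.

    Assumptions (illustrative, not taken from the toolbox): the filter has an
    even number of taps, and at level j (0-based) it is dilated by 2**j.
    """
    pads = []
    for j in range(level):
        dilation = 2**j
        # A dilated filter spans dilation * (filt_len - 1) + 1 samples, so
        # dilation * (filt_len - 1) samples of padding keep the output length.
        pads.append((dilation * (filt_len // 2 - 1), dilation * (filt_len // 2)))
    return pads


signal_len = 32
# db4 has 8 filter taps (dbN has 2N).
for j, (left, right) in enumerate(swt_circular_padding(filt_len=8, level=5), start=1):
    exceeds = max(left, right) > signal_len
    print(f"level {j}: pad=({left}, {right}), exceeds signal length: {exceeds}")
```

With a 32-sample signal and db4, only level 5 asks for more padding (64 samples on the right) than the signal has, which is exactly the case a single circular wrap cannot cover.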

NiclasPi commented 4 months ago

Thank you for the review, @cthoyt! I added explanatory comments and some tests comparing the function with NumPy's wrap padding.
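A minimal sketch of such a helper, assuming only the last axis needs padding. It is not the PR's code: instead of concatenating copies of the signal, it builds an index with a Python-style modulus, which wraps any number of times and stays differentiable because it is plain indexing. The name wrap_pad_1d is hypothetical; the comparison against numpy.pad(mode="wrap") mirrors the kind of test described above.

```python
import numpy as np
import torch


def wrap_pad_1d(x: torch.Tensor, left: int, right: int) -> torch.Tensor:
    """Circularly pad the last axis of ``x`` by ``left``/``right`` samples.

    Hypothetical helper: unlike torch.nn.functional.pad(mode="circular"),
    the padding may exceed the signal length and then wraps several times.
    """
    n = x.shape[-1]
    # Positions -left .. n + right - 1, folded back into [0, n) by the
    # Python-style modulus (torch.remainder), which never goes negative.
    idx = torch.arange(-left, n + right, device=x.device) % n
    # Plain indexing, so autograd flows back to x.
    return x[..., idx]


x = torch.arange(3.0)
padded = wrap_pad_1d(x, left=5, right=4)  # both pads exceed len(x) == 3
reference = np.pad(x.numpy(), (5, 4), mode="wrap")
print(padded.tolist())  # [1.0, 2.0, 0.0, 1.0, 2.0, 0.0, 1.0, 2.0, 0.0, 1.0, 2.0, 0.0]
print(np.array_equal(padded.numpy(), reference))  # True
```

When the padding fits within one wrap, the result matches torch.nn.functional.pad in circular mode, so a helper like this could fall back to the built-in for the common case.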

v0lta commented 4 months ago

For posterity, the docs on checking out forked code: https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/reviewing-changes-in-pull-requests/checking-out-pull-requests-locally#modifying-an-inactive-pull-request-locally

v0lta commented 4 months ago

I see failed tests on my machine: tests/test_swt.py .....FFF.........FFF.........FFF.........FFF................................................................... [ 98%] ........

test_swt_1d[db1-size0-1], test_swt_1d[db1-size0-2], test_swt_1d[db1-size0-None], test_swt_1d[db2-size0-1], test_swt_1d[db2-size0-2], test_swt_1d[db2-size0-None] and others don't pass. I will look into this tomorrow; perhaps I can figure out what's going on. @NiclasPi, does everything pass on your machine?

v0lta commented 4 months ago
FAILED tests/test_swt.py::test_swt_1d[db1-size0-1] - assert False
FAILED tests/test_swt.py::test_swt_1d[db1-size0-2] - assert False
FAILED tests/test_swt.py::test_swt_1d[db1-size0-None] - assert False
FAILED tests/test_swt.py::test_swt_1d[db2-size0-1] - assert False
FAILED tests/test_swt.py::test_swt_1d[db2-size0-2] - assert False
FAILED tests/test_swt.py::test_swt_1d[db2-size0-None] - assert False
FAILED tests/test_swt.py::test_swt_1d[db3-size0-1] - assert False
FAILED tests/test_swt.py::test_swt_1d[db3-size0-2] - assert False
FAILED tests/test_swt.py::test_swt_1d[db3-size0-None] - assert False
FAILED tests/test_swt.py::test_swt_1d[db4-size0-1] - assert False
FAILED tests/test_swt.py::test_swt_1d[db4-size0-2] - assert False
FAILED tests/test_swt.py::test_swt_1d[db4-size0-None] - assert False

A partial log from nox -s test.

cthoyt commented 4 months ago

@v0lta, we need to update this line https://github.com/v0lta/PyTorch-Wavelet-Toolbox/blob/1b75acdbe7c9cbfb526fd1a7c19ad260b935c6ec/.github/workflows/tests.yml#L3 to make tests run on external PRs. It should look like this: https://github.com/biopragmatics/curies/blob/470f71a69264c17260823d485de113695076a38f/.github/workflows/tests.yml#L3-L7

v0lta commented 4 months ago

Done in the main branch.

cthoyt commented 4 months ago

@v0lta @NiclasPi the tests are now showing up properly in the PR

v0lta commented 4 months ago

TODO: Properly port the shape fold/unfold code to the SWT.

NiclasPi commented 3 months ago

Sorry, I forgot to mention that I added test cases that weren't covered yet. Since I am working with 1-dimensional signals, I added cases with shape = (N,); the code base currently only tests 1-dimensional signals with shape = (1, N). On closer inspection, the failures turned out to come from the fold/unfold helper functions, which should also work with shape = (N,). @v0lta is going to look into this issue and fix it. Afterwards, the failing SWT tests should pass.
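The toolbox's actual fold/unfold helpers are not shown in this thread, so the two functions below are hypothetical. They sketch the general idea under that assumption: collapse all leading axes into one batch axis before the transform, remember the original leading shape, and restore it afterwards, so that an (N,) input comes back without an extra leading axis.

```python
import torch


def fold_leading_dims(x: torch.Tensor) -> tuple[torch.Tensor, torch.Size]:
    """Collapse all leading axes into one batch axis: (..., T) -> (B, T).

    A 1-D input of shape (T,) becomes (1, T). The original leading shape is
    returned so it can be restored after the transform.
    """
    lead_shape = x.shape[:-1]
    return x.reshape(-1, x.shape[-1]), lead_shape


def unfold_leading_dims(y: torch.Tensor, lead_shape: torch.Size) -> torch.Tensor:
    """Undo fold_leading_dims: (B, T') -> (*lead_shape, T').

    For an original (T,) input, lead_shape is empty, so the batch axis is
    dropped and the result is (T',) rather than (1, T').
    """
    return y.reshape(*lead_shape, y.shape[-1])


for shape in [(16,), (2, 3, 16)]:
    x = torch.randn(shape)
    folded, lead = fold_leading_dims(x)
    coeffs = folded[..., ::2]  # stand-in for a transform that halves T
    restored = unfold_leading_dims(coeffs, lead)
    print(tuple(x.shape), "->", tuple(folded.shape), "->", tuple(restored.shape))
# (16,) -> (1, 16) -> (8,)
# (2, 3, 16) -> (6, 16) -> (2, 3, 8)
```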

v0lta commented 3 months ago

Thanks @NiclasPi. I have a fix in https://github.com/v0lta/PyTorch-Wavelet-Toolbox/tree/fix-padding which starts to address this. Across the entire toolbox, we currently get (1, T) output shapes for (N,) inputs, where N is the number of input samples and T the length at a given transform level. This is not what we want, since the toolbox should always respect the user's choice of input dimensions. I will fix this across the board. @cthoyt, how did you make the [Merge branch 'main' into pr/84] commit? Unfortunately, my changes ended up in https://github.com/v0lta/PyTorch-Wavelet-Toolbox/tree/fix-padding. I think I should know this, but here we are ;-).

v0lta commented 3 months ago

It's not a catastrophic problem: users can currently fix this by calling out.squeeze() themselves. But since we want to be pywt-compatible, the shapes should be identical, so we'll fix this.
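A small sketch of that interim workaround, using a random tensor as a stand-in for one (1, T) output coefficient. Plain squeeze() drops every size-1 axis, so if T itself happens to be 1 it collapses the result to a 0-d tensor; squeeze(0) removes only the spurious leading axis.

```python
import torch

# Stand-in for a transform output: an (N,) input currently comes back as (1, T).
out = torch.randn(1, 8)
print(out.squeeze().shape)   # torch.Size([8])
print(out.squeeze(0).shape)  # torch.Size([8])

# If T == 1, squeeze() also removes the coefficient axis itself.
tiny = torch.randn(1, 1)
print(tiny.squeeze().shape)   # torch.Size([])
print(tiny.squeeze(0).shape)  # torch.Size([1])
```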

cthoyt commented 3 months ago

You need to add the remote (see https://docs.github.com/en/get-started/getting-started-with-git/managing-remote-repositories):

git remote add NiclasPi https://github.com/NiclasPi/PyTorch-Wavelet-Toolbox.git
git fetch --all

Then you can switch to the right fix-padding branch.

v0lta commented 3 months ago

Thanks @cthoyt! I think we are almost there now. The last thing I would like to do here is get rid of _conv_transpose_dedilate and use group convolution instead.
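The thread doesn't show what _conv_transpose_dedilate did or exactly how group convolution replaced it. As a hedged illustration of the general idea only: applying a filter dilated by d is equivalent to applying the undilated filter to each of the d interleaved sub-signals x[k::d], and a grouped convolution can process all of those sub-signals in one call. The helper name is hypothetical, and the sketch uses a forward conv1d rather than a transposed convolution.

```python
import torch
import torch.nn.functional as F


def dilated_conv_via_groups(x: torch.Tensor, w: torch.Tensor, dilation: int) -> torch.Tensor:
    """Dilated 1-D correlation computed on the polyphase components.

    x: (B, 1, N) with N divisible by ``dilation``; w: (K,) filter.
    Samples x[..., k::dilation] form independent sub-signals, each filtered by
    the *undilated* w; a grouped conv1d handles all of them in a single call.
    """
    b, _, n = x.shape
    # (B, 1, N) -> (B, dilation, N // dilation); channel k holds x[..., k::dilation].
    phases = x.reshape(b, n // dilation, dilation).transpose(1, 2)
    weight = w.view(1, 1, -1).repeat(dilation, 1, 1)  # (dilation, 1, K)
    y = F.conv1d(phases, weight, groups=dilation)
    # Re-interleave: output index q * dilation + k comes from channel k, position q.
    return y.transpose(1, 2).reshape(b, 1, -1)


x = torch.randn(2, 1, 32)
w = torch.randn(4)
direct = F.conv1d(x, w.view(1, 1, -1), dilation=4)
print(torch.allclose(dilated_conv_via_groups(x, w, 4), direct, atol=1e-5))
```

Both paths return the same (2, 1, 20) result; the grouped version trades the dilation for a reshape, which is the kind of restructuring that can make a separate de-dilation step unnecessary.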

v0lta commented 3 months ago

Hi Team, I removed the _conv_transpose_dedilate; the code is simpler and faster now. I also made the SWT module public by removing the underscore. From my point of view, this PR is ready for a merge.