Closed. NiclasPi closed this pull request 3 months ago.
Thank you for the review @cthoyt ! I added explaining comments and some tests comparing the function with NumPy's wrap padding.
I see failed tests on my machine:
tests/test_swt.py .....FFF.........FFF.........FFF.........FFF................................................................... [ 98%] ........
test_swt_1d[db1-size0-1], test_swt_1d[db1-size0-2], test_swt_1d[db1-size0-None], test_swt_1d[db2-size0-1], test_swt_1d[db2-size0-2], test_swt_1d[db2-size0-None], and others don't pass.
I will look into this tomorrow. Perhaps I can figure out what's going on. @NiclasPi does everything pass on your machine?
FAILED tests/test_swt.py::test_swt_1d[db1-size0-1] - assert False
FAILED tests/test_swt.py::test_swt_1d[db1-size0-2] - assert False
FAILED tests/test_swt.py::test_swt_1d[db1-size0-None] - assert False
FAILED tests/test_swt.py::test_swt_1d[db2-size0-1] - assert False
FAILED tests/test_swt.py::test_swt_1d[db2-size0-2] - assert False
FAILED tests/test_swt.py::test_swt_1d[db2-size0-None] - assert False
FAILED tests/test_swt.py::test_swt_1d[db3-size0-1] - assert False
FAILED tests/test_swt.py::test_swt_1d[db3-size0-2] - assert False
FAILED tests/test_swt.py::test_swt_1d[db3-size0-None] - assert False
FAILED tests/test_swt.py::test_swt_1d[db4-size0-1] - assert False
FAILED tests/test_swt.py::test_swt_1d[db4-size0-2] - assert False
FAILED tests/test_swt.py::test_swt_1d[db4-size0-None] - assert False
The above is a partial log from nox -s test.
@v0lta we need to update this line https://github.com/v0lta/PyTorch-Wavelet-Toolbox/blob/1b75acdbe7c9cbfb526fd1a7c19ad260b935c6ec/.github/workflows/tests.yml#L3 to make tests run on external PRs; it should look like this: https://github.com/biopragmatics/curies/blob/470f71a69264c17260823d485de113695076a38f/.github/workflows/tests.yml#L3-L7
Done in the main branch.
@v0lta @NiclasPi the tests are now showing up properly in the PR
TODO: Port shape fold-unfolding code to swt, properly.
Sorry, I forgot to mention that I added test cases that weren't covered yet. Specifically, since I am working with 1-dimensional signals, I added cases with shape = (N,). The code base currently only tests 1-dimensional signals with shape = (1, N).
After closer inspection, it turned out to be a problem in the folding-unfolding helper functions. These should also work with shape = (N,). @v0lta is going to look into this issue and will fix it. Afterwards, the failing tests for the SWT should pass.
Thanks @NiclasPi . I have a fix in https://github.com/v0lta/PyTorch-Wavelet-Toolbox/tree/fix-padding which starts to address this. We currently get (1, T) output shapes for (N,) inputs across the entire toolbox, with N the number of input measurements and T the length at a given transform level. This is not what we want, since the toolbox should always respect the user's choice of input dimension. I will fix this across the board.
@cthoyt how did you [Merge branch 'main' into pr/84]? My changes ended up in https://github.com/v0lta/PyTorch-Wavelet-Toolbox/tree/fix-padding , unfortunately. I think I should know this, but here we are ;-).
It's not a catastrophic problem; users can currently fix this by running out.squeeze() themselves. But since we want to be pywt-compatible, the shapes should be identical, so we'll fix this.
You need to add the remote (see https://docs.github.com/en/get-started/getting-started-with-git/managing-remote-repositories):
git remote add NiclasPi https://github.com/NiclasPi/PyTorch-Wavelet-Toolbox.git
git fetch --all
Then you can switch to the right fix-padding branch.
Thanks @cthoyt ! I think we are almost there now. The last thing I would like to do here is get rid of _conv_transpose_dedilate and use group convolution instead.
Hi Team, I removed _conv_transpose_dedilate; the code is simpler and faster now. I also made the SWT module public by removing the underscore. From my point of view, this PR is ready for a merge.
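For readers unfamiliar with the technique mentioned above: grouped convolution splits the input channels into independent groups, each convolved with its own set of filters. The following is only an illustrative NumPy sketch of that idea (not the toolbox's implementation), using the cross-correlation convention that torch.nn.functional.conv1d follows.

```python
import numpy as np

def grouped_conv1d(x, weight, groups):
    """Valid 1-D grouped convolution, cross-correlation convention.

    Illustrative sketch only. x: (C_in, T), weight: (C_out, C_in/groups, K).
    """
    c_in, t = x.shape
    c_out, c_in_g, k = weight.shape
    assert c_in == c_in_g * groups and c_out % groups == 0
    c_out_g = c_out // groups
    out = np.zeros((c_out, t - k + 1))
    for g in range(groups):
        xs = x[g * c_in_g:(g + 1) * c_in_g]   # input channels of this group
        for j in range(c_out_g):
            w = weight[g * c_out_g + j]       # one (C_in/groups, K) filter
            for tt in range(t - k + 1):
                out[g * c_out_g + j, tt] = np.sum(xs[:, tt:tt + k] * w)
    return out

x = np.array([[1., 2, 3, 4], [5, 6, 7, 8]])
w = np.array([[[1., 0]], [[0., 1]]])          # groups=2: one filter per channel
out = grouped_conv1d(x, w, groups=2)
assert np.allclose(out[0], [1, 2, 3])         # picks current sample of channel 0
assert np.allclose(out[1], [6, 7, 8])         # picks next sample of channel 1
```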
Some combinations of wavelets and numbers of decomposition steps cause the following RuntimeError: "Padding value causes wrapping around more than once". If the padding values are greater than the input length, torch.nn.functional.pad does not wrap around more than once in circular mode. We need to manually wrap around more than once, if needed. See feature request https://github.com/pytorch/pytorch/issues/57911 in PyTorch.
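A minimal NumPy sketch of the manual multi-wrap idea (function name hypothetical): tile the signal often enough that a single slice covers the requested padding. This matches NumPy's wrap padding even when the pad exceeds the input length, which is exactly the case where torch's circular mode raises the error above.

```python
import numpy as np

def wrap_pad_1d(x, pad_left, pad_right):
    """Circularly pad a 1-D array, wrapping around more than once if needed.

    Sketch only: tile enough full copies of the signal, then slice out the
    padded window from the periodic extension.
    """
    n = len(x)
    # enough full copies so the slice below always fits
    reps = pad_left // n + pad_right // n + 3
    tiled = np.tile(x, reps)
    # index of the first padded sample inside the tiled array
    start = (pad_left // n + 1) * n - pad_left
    return tiled[start : start + pad_left + n + pad_right]

x = np.arange(4)
# pad of 9 on each side wraps around a length-4 signal more than twice
assert np.array_equal(wrap_pad_1d(x, 9, 9), np.pad(x, (9, 9), mode="wrap"))
```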