v0lta / PyTorch-Wavelet-Toolbox

Differentiable fast wavelet transforms in PyTorch with GPU support.
https://pytorch-wavelet-toolbox.readthedocs.io
European Union Public License 1.2
279 stars, 36 forks

Fix non-default axis in packets #94

Closed. felixblanke closed this pull request 3 months ago.

felixblanke commented 3 months ago

Until now, the logic in the WaveletPacket objects ignored the axis argument and always accessed the last axes, for example when looking up which transform to use.
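To make the failure mode concrete, here is a minimal pure-Python sketch (assumptions: 2D nested lists instead of tensors, a single Haar filter pair, and the hypothetical names `haar_split` and `haar_packet_tree`, which are not part of the toolbox's API). It is not the code from this PR. It only illustrates one common way to honour a user-supplied `axis`: move the requested axis to the last position, transform there, and move it back, using the same `axis` for every node of the packet tree. With PyTorch tensors, the manual transposes would typically become something like `torch.movedim`.

```python
import math

SQRT2 = math.sqrt(2.0)


def transpose(m):
    """Swap the two axes of a 2D nested list."""
    return [list(col) for col in zip(*m)]


def haar_split(x, axis=-1):
    """One Haar analysis step along `axis` of a 2D nested list (hypothetical helper).

    Returns (approx, detail). The requested axis is moved to the last
    position, filtered there, and moved back, so the caller's layout is kept.
    """
    if axis not in (-2, -1, 0, 1):
        raise ValueError(f"axis must be one of -2, -1, 0, 1, got {axis}")
    along_last = axis % 2 == 1  # axis 1 / -1 already is the last axis
    work = x if along_last else transpose(x)
    approx, detail = [], []
    for row in work:
        if len(row) % 2:
            raise ValueError("length along the transformed axis must be even")
        approx.append([(row[i] + row[i + 1]) / SQRT2 for i in range(0, len(row), 2)])
        detail.append([(row[i] - row[i + 1]) / SQRT2 for i in range(0, len(row), 2)])
    if not along_last:
        # Move the transformed axis back to where the caller put it.
        approx, detail = transpose(approx), transpose(detail)
    return approx, detail


def haar_packet_tree(x, max_level, axis=-1):
    """Full packet tree keyed by paths such as 'a', 'd', 'ad' (hypothetical helper).

    Every node is split along the same user-supplied `axis` instead of
    silently falling back to the last axis.
    """
    tree = {"": x}
    for level in range(max_level):
        for key in [k for k in tree if len(k) == level]:
            tree[key + "a"], tree[key + "d"] = haar_split(tree[key], axis=axis)
    return tree
```

For a 2 x 4 input, `axis=-1` halves the row length at each level, while `axis=0` halves the number of rows, so a single level along axis 0 yields 1 x 4 coefficient arrays.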

This PR fixes that and extends the unit tests to cover different choices of axes.
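One way to write such axis-coverage tests is a consistency check: transforming along axis `a` should match transforming along the last axis after moving `a` there (and moving it back). The sketch below assumes a toy pairwise-sum transform on 2D nested lists; `pair_sums`, `buggy_pair_sums` and `respects_axis` are hypothetical names, not the tests added in this PR. The buggy variant accepts `axis` but ignores it, like the behaviour described above, and the check catches it.

```python
def transpose(m):
    """Swap the two axes of a 2D nested list."""
    return [list(col) for col in zip(*m)]


def pair_sums(x, axis=-1):
    """Toy transform: sum neighbouring pairs along `axis` of a 2D nested list."""
    if axis not in (-2, -1, 0, 1):
        raise ValueError(f"axis must be one of -2, -1, 0, 1, got {axis}")
    along_last = axis % 2 == 1
    work = x if along_last else transpose(x)
    out = [[row[i] + row[i + 1] for i in range(0, len(row), 2)] for row in work]
    return out if along_last else transpose(out)


def buggy_pair_sums(x, axis=-1):
    """Same transform, but `axis` is ignored and the last axis is always used."""
    return [[row[i] + row[i + 1] for i in range(0, len(row), 2)] for row in x]


def respects_axis(fn, x):
    """Compare fn(x, axis=a) with the last-axis result on an input whose axis a was moved last."""
    for axis in (0, 1, -1, -2):
        moved_last = axis % 2 == 1  # nothing to move for axis 1 / -1
        reference = fn(x if moved_last else transpose(x), axis=-1)
        expected = reference if moved_last else transpose(reference)
        if fn(x, axis=axis) != expected:
            return False
    return True
```

The check trusts the last-axis code path as the reference, so it targets exactly the class of bug where non-default axes are mishandled.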

v0lta commented 3 months ago

It works. It fixes a bug. I am merging.