v0lta / PyTorch-Wavelet-Toolbox

Differentiable fast wavelet transforms in PyTorch with GPU support.
https://pytorch-wavelet-toolbox.readthedocs.io
European Union Public License 1.2

Avoid adding a batch dim to 2d signals. #92

Closed: felixblanke closed this issue 3 months ago

felixblanke commented 3 months ago

For 1d signals, the coefficients and the reconstructed signals keep the same number of dimensions as the original input signal. However, for 2d signals with only 2 axes, we currently always add a superfluous batch dimension, i.e. [H, W] -> [1, H, W].

This seems inconsistent.

We can address this by introducing functionality similar to ptwt.conv_transform._postprocess_result_list_dec1d for higher dimensions.
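A minimal sketch of the idea, not the actual ptwt code: add the batch axis only internally and strip it from every coefficient tensor before returning. The helpers `_add_batch_dim` and `_remove_batch_dim` and the transform `toy_dec2d` are hypothetical names made up for illustration; `toy_dec2d` is a single-level Haar-like stand-in and not the library's wavelet transform.

```python
import torch


def _add_batch_dim(data):
    """Hypothetical helper: give a [H, W] input a leading batch axis."""
    if data.dim() == 2:
        return data.unsqueeze(0), True
    return data, False


def _remove_batch_dim(coeffs, added):
    """Hypothetical helper: strip the internal batch axis from all tensors."""
    if not added:
        return coeffs
    if isinstance(coeffs, torch.Tensor):
        return coeffs.squeeze(0)
    stripped = [_remove_batch_dim(c, added) for c in coeffs]
    # Preserve the container type (list of levels, tuple of detail bands).
    return tuple(stripped) if isinstance(coeffs, tuple) else stripped


def toy_dec2d(batched):
    """Stand-in transform (assumption): expects [B, H, W] with even H and W."""
    a = batched[:, 0::2, 0::2]
    b = batched[:, 0::2, 1::2]
    c = batched[:, 1::2, 0::2]
    d = batched[:, 1::2, 1::2]
    approx = (a + b + c + d) / 2
    horizontal = (a + b - c - d) / 2
    vertical = (a - b + c - d) / 2
    diagonal = (a - b - c + d) / 2
    return [approx, (horizontal, vertical, diagonal)]


def dec2d_keep_dims(data):
    """Run the toy transform without changing the number of input axes."""
    batched, added = _add_batch_dim(data)
    return _remove_batch_dim(toy_dec2d(batched), added)


unbatched = dec2d_keep_dims(torch.rand(8, 8))
batched = dec2d_keep_dims(torch.rand(3, 8, 8))
print(unbatched[0].shape, batched[0].shape)
```

With this pattern, a [H, W] input yields coefficients with two axes, while a [B, H, W] input keeps its batch axis.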

v0lta commented 3 months ago

https://github.com/v0lta/PyTorch-Wavelet-Toolbox/pull/93 is merged.