v0lta / PyTorch-Wavelet-Toolbox

Differentiable fast wavelet transforms in PyTorch with GPU support.
https://pytorch-wavelet-toolbox.readthedocs.io
European Union Public License 1.2

Make preprocessing and postprocessing consistent across transforms #93

Closed by felixblanke 3 months ago

felixblanke commented 3 months ago

This addresses #92.

For all discrete transforms, the preprocessing and postprocessing of coefficients and tensors are very similar (e.g. folding and swapping axes, adding batch dimensions). This PR moves this functionality into shared functions built on _map_result.
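As a rough sketch of the idea (illustrative only, not the exact signature in the code): a helper that applies a function to every coefficient tensor, special-casing the leading approximation tensor:

```python
from typing import Callable
import torch

# Illustrative sketch only; the real _map_result differs in detail.
# Coefficients are assumed laid out as [approx_tensor, details_1, ...],
# where each details_i is a tuple or dict of detail tensors.
def _map_result(coeffs, fn: Callable[[torch.Tensor], torch.Tensor]):
    approx, *details = coeffs
    mapped = [fn(approx)]
    for level in details:
        if isinstance(level, dict):
            mapped.append({key: fn(value) for key, value in level.items()})
        else:
            mapped.append(tuple(fn(tensor) for tensor in level))
    return mapped
```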

Also, the check for consistent devices and dtypes across coefficient tensors is moved into a function in _utils.
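The check itself is simple; roughly (hypothetical name and error messages, the actual _utils function may differ):

```python
import torch

# Hypothetical helper in the spirit of the check described above.
def _check_same_device_dtype(tensors: list[torch.Tensor]) -> None:
    if not tensors:
        return
    device, dtype = tensors[0].device, tensors[0].dtype
    for tensor in tensors[1:]:
        if tensor.device != device:
            raise ValueError(f"coefficient devices differ: {tensor.device} vs {device}")
        if tensor.dtype != dtype:
            raise ValueError(f"coefficient dtypes differ: {tensor.dtype} vs {dtype}")
```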

Lastly, since it only took a few lines of code, I also added the $n$-dimensional fully separable transforms (fswavedecn, fswaverecn). If this is unwanted, I can revert their addition.
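For context, PyWavelets ships the same pair of functions, so the fully separable idea can be illustrated in a few lines of NumPy (the new ptwt functions follow this naming):

```python
import numpy as np
import pywt

data = np.random.randn(8, 8, 8)

# Fully separable: each axis is decomposed independently to the given
# depth, instead of interleaving the axes level by level as wavedecn does.
coeffs = pywt.fswavedecn(data, "haar", levels=2)
recon = pywt.fswaverecn(coeffs)
assert np.allclose(recon, data)
```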

Further, I did some minor refactorings along the way.

v0lta commented 3 months ago

I left out the n-dimensional separable transforms on purpose, because I was thinking people would then ask for n-dimensional versions everywhere else, too, where they are trickier to deliver.

felixblanke commented 3 months ago

@v0lta I made the n-dim transform private. Does that work?

v0lta commented 3 months ago

Yes, that works. However, we won't be able to support n-dimensional transforms across the board, because PyTorch does not provide the interfaces we would need. Padding, for example, only works up to 3D (https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html). We have the same problem with isotropic convolution. So I think we should communicate that nd-transforms are out of scope for this project.
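Concretely, the limitation shows up like this (PyTorch's non-constant padding modes only cover up to three trailing dimensions):

```python
import torch
import torch.nn.functional as F

# Reflect padding is implemented for at most the last 3 dims
# (3D/4D/5D inputs); only constant padding works for any rank.
x3d = torch.randn(1, 1, 8, 8, 8)              # batch, channel, 3 spatial dims
ok = F.pad(x3d, (1, 1, 1, 1, 1, 1), mode="reflect")   # works

x4d = torch.randn(1, 1, 8, 8, 8, 8)           # 4 spatial dims
try:
    F.pad(x4d, (1, 1) * 4, mode="reflect")    # not supported by PyTorch
except (RuntimeError, NotImplementedError) as err:
    print(err)
```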

v0lta commented 3 months ago

In general, I am a big fan of this pull request! Thanks, @felixblanke. I am going to clean up the docs for _map_result and commit here.

v0lta commented 3 months ago

Our coeff_tree_map is not a general tree map, but it does not have to be, since we know the approximation tensor will always be the first entry. I ran the tests not marked as slow, and everything checked out. The code is cleaner now. If everyone is on board, I would be ready to merge.
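As a usage sketch (reusing the illustrative _map_result from above, whatever we end up naming it): because the approximation entry is handled explicitly, a single call maps over the whole coefficient list:

```python
import torch

# Coefficients in the assumed layout: approximation tensor first,
# then a tuple of detail tensors per level.
coeffs = [
    torch.randn(1, 4, 4),
    tuple(torch.randn(1, 4, 4) for _ in range(3)),
]

# One call covers the plain approximation tensor and the detail
# containers alike, e.g. to cast every coefficient to float64.
coeffs = _map_result(coeffs, lambda t: t.to(torch.float64))
```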

felixblanke commented 3 months ago

I think so far we only refer to the Packet data structure as a tree. Maybe we can add a link to the JAX discussion as a reference?

v0lta commented 3 months ago

I am not sure users need to know. I think this is more for us internally. Unlike JAX's tree map, ours is coefficient-specific, hence the proposed coeff_tree_map function name.
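For reference, JAX's version is fully general; it maps over arbitrary nested containers, which is more than we need here:

```python
import jax
import jax.numpy as jnp

# jax.tree_util.tree_map walks any registered pytree (lists, tuples,
# dicts, nested arbitrarily) with no assumptions about entry order.
tree = {"a": jnp.ones(3), "b": (jnp.zeros(2), [jnp.arange(4)])}
doubled = jax.tree_util.tree_map(lambda x: 2 * x, tree)
```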

v0lta commented 3 months ago

Actually, never mind. The user argument does not matter since it's a private function. If you think it helps to understand the idea, please add a link. I think it might help potential future contributors.

v0lta commented 3 months ago

Okay, I think we are ready to merge. @felixblanke @cthoyt is everyone on board?

v0lta commented 3 months ago

Okay, let's merge!