felixblanke closed this pull request 4 months ago.
I deliberately did not add n-dimensional separable transforms, because I was worried people would then ask for them in all the other cases too, where they are much trickier to deliver.
@v0lta I made the n-dim transform private. Does that work?
Yes, that works. However, we won't be able to support n-dimensional transforms across the board, because PyTorch does not provide the interfaces we would need for that. Padding, for example, only works up to 3D (https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html). We have the same problem with isotropic convolution. So I think we should communicate that nd-transforms are out of scope for this project.
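For context, a minimal sketch of the padding limitation (using reflect padding as an example of the boundary modes a wavelet transform needs; as far as I know, only constant padding generalizes beyond three spatial dimensions):

```python
import torch
import torch.nn.functional as F

# Reflect padding covers at most three spatial dimensions ...
x3d = torch.randn(1, 1, 8, 8, 8)  # (batch, channel, depth, height, width)
print(F.pad(x3d, (1, 1, 1, 1, 1, 1), mode="reflect").shape)  # torch.Size([1, 1, 10, 10, 10])

# ... but there is no built-in non-constant mode for a fourth spatial dimension,
# which an n-dimensional boundary-handled transform would need.
x4d = torch.randn(1, 1, 8, 8, 8, 8)
try:
    F.pad(x4d, (1, 1) * 4, mode="reflect")
except (NotImplementedError, RuntimeError) as err:
    print(err)
```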
In general, I am a big fan of this pull request! Thanks, @felixblanke. I am going to clean up the docs for _map_result and commit here.
Our coeff_tree_map is not a general tree map, but it does not have to be, since we know the approximation tensor will always be the first entry. I ran the not-slow tests. Everything checked out. The code is cleaner now. If everyone is on board, I would be ready to merge.
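For future readers, a rough sketch of the idea (the names and the exact coefficient layout here are illustrative; the actual helper in the code base is _map_result):

```python
from typing import Callable, Dict, List, Union

import torch

# Rough sketch: a wavedec result is a list whose first entry is the
# approximation tensor; the remaining entries hold detail coefficients
# (a plain tensor in 1d, a dict of subbands such as "ad", "da", "dd" in 2d).
CoeffEntry = Union[torch.Tensor, Dict[str, torch.Tensor]]


def _coeff_tree_map(
    fn: Callable[[torch.Tensor], torch.Tensor], coeffs: List[CoeffEntry]
) -> List[CoeffEntry]:
    mapped: List[CoeffEntry] = []
    for entry in coeffs:
        if isinstance(entry, torch.Tensor):
            mapped.append(fn(entry))
        else:
            mapped.append({key: fn(value) for key, value in entry.items()})
    return mapped


# Example: cast every coefficient tensor to float64.
coeffs = [torch.zeros(1, 4, 4), {"ad": torch.zeros(1, 4, 4)}]
coeffs64 = _coeff_tree_map(lambda t: t.to(torch.float64), coeffs)
```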
I think we have so far only referred to the Packet data structure as a tree. Maybe we can add a link to the JAX discussion as a reference?
I am not sure users need to know. I think this is more for us internally. Unlike JAX's tree map, ours is coefficient-specific, hence the proposed coeff_tree_map function name.
Actually, never mind. The argument about users does not matter since it is a private function. If you think it helps to understand the idea, please add a link. I think it might help potential future contributors.
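For reference, the contrast with JAX's general tree map in a nutshell (purely illustrative):

```python
import jax

# JAX's tree map recurses over arbitrary nested containers and treats every
# leaf the same way ...
nested = {"a": [1.0, 2.0], "b": (3.0,)}
doubled = jax.tree_util.tree_map(lambda x: 2.0 * x, nested)

# ... whereas a coefficient map only needs to handle the fixed wavedec
# layout with the approximation tensor as the first entry.
```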
Okay, I think we are ready to merge. @felixblanke @cthoyt is everyone on board?
okay let's merge!
This addresses #92.
For all discrete transforms, the preprocessing and postprocessing of coefficients and tensors is very similar (i.e., folding and swapping axes, adding batch dimensions, etc.). This PR moves this functionality into shared functions that use _map_result. Also, the check for consistent devices and dtypes between coefficient tensors is moved into a function in _utils. Last, as it was possible to add them with a few lines of code, I added the $n$-dimensional fully separable transforms (fswavedecn, fswaverecn). If this is not wanted, I can revert their addition. Further, I did some minor refactorings along the way.