v0lta / PyTorch-Wavelet-Toolbox

Differentiable fast wavelet transforms in PyTorch with GPU support.
https://pytorch-wavelet-toolbox.readthedocs.io
European Union Public License 1.2

Allow partial refinement in Wavelet Packets #91

Closed: felixblanke closed this issue 3 months ago

felixblanke commented 3 months ago

Currently, wavelet packets only support uniform "refinement": passing a maxlevel to the transform function expands the full packet tree up to that depth.
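To make the uniform case concrete, here is a minimal pure-Python sketch (not the toolbox's implementation) that expands every node of the packet tree down to maxlevel. It assumes a Haar filter, even-length signals, and pywt-style path keys ('a' for approximation, 'd' for detail, '' for the root). The names haar_step and full_packet_tree are hypothetical.

```python
import math


def haar_step(signal):
    """One Haar analysis level: returns (approx, detail). Assumes even length."""
    if len(signal) % 2:
        raise ValueError("signal length must be even")
    s = 1 / math.sqrt(2)
    approx = [s * (signal[i] + signal[i + 1]) for i in range(0, len(signal), 2)]
    detail = [s * (signal[i] - signal[i + 1]) for i in range(0, len(signal), 2)]
    return approx, detail


def full_packet_tree(signal, maxlevel):
    """Uniform refinement: split every node until the paths reach maxlevel."""
    tree = {"": list(signal)}  # root node, keyed by the empty path
    frontier = [""]
    for _ in range(maxlevel):
        next_frontier = []
        for key in frontier:
            # Each node is split into an approximation ('a') and detail ('d') child.
            approx, detail = haar_step(tree[key])
            tree[key + "a"] = approx
            tree[key + "d"] = detail
            next_frontier += [key + "a", key + "d"]
        frontier = next_frontier
    return tree
```

For a length-8 signal and maxlevel=2 this yields seven nodes: the root, 'a' and 'd', and the four level-2 nodes 'aa', 'ad', 'da' and 'dd'. There is no way to stop early on one branch, which is the limitation this issue is about.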

Partial expansions, as requested in #69, are not directly supported.

I can think of two complementary ways to implement this:

  1. Expose the expansion mechanism (currently implemented in _recursive_dwt) through a public interface.
  2. Add an optional callback parameter to the transform method. Given a tree path string (i.e. the coefficient key in the packet), the callback returns a bool indicating whether that node should be expanded.
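
Both options can be sketched in a few lines of pure Python. This is an illustration under assumptions, not the toolbox's code: expand_node stands in for a hypothetical public version of the expansion step (option 1), should_expand is a hypothetical callback that receives the path key (option 2), and a Haar filter with pywt-style 'a'/'d' path keys is assumed.

```python
import math


def haar_step(signal):
    # One Haar analysis level; assumes an even-length signal.
    s = 1 / math.sqrt(2)
    pairs = range(0, len(signal), 2)
    approx = [s * (signal[i] + signal[i + 1]) for i in pairs]
    detail = [s * (signal[i] - signal[i + 1]) for i in pairs]
    return approx, detail


def expand_node(tree, key):
    """Option 1 (hypothetical public interface): split one existing node."""
    approx, detail = haar_step(tree[key])
    tree[key + "a"] = approx
    tree[key + "d"] = detail


def partial_packet_tree(signal, maxlevel, should_expand):
    """Option 2: expand a node only if should_expand(key) returns True."""
    tree = {"": list(signal)}
    stack = [""]
    while stack:
        key = stack.pop()
        # Respect the depth limit, then ask the callback about this path.
        if len(key) < maxlevel and should_expand(key):
            expand_node(tree, key)
            stack += [key + "a", key + "d"]
    return tree
```

With should_expand=lambda key: "d" not in key and maxlevel=3, only the approximation branch is refined, which reproduces the shape of an ordinary three-level DWT tree.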
felixblanke commented 3 months ago

I decided against a callback; using lazy initialization instead seems less confusing.
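The lazy-initialization idea can be illustrated with a dict subclass whose __missing__ hook computes a node from its parent the first time that node is requested. This is a hypothetical sketch (LazyPacketTree, the Haar filter and the 'a'/'d' keys are assumptions), not the merged implementation.

```python
import math


def haar_step(signal):
    # One Haar analysis level; assumes an even-length signal.
    s = 1 / math.sqrt(2)
    pairs = range(0, len(signal), 2)
    return ([s * (signal[i] + signal[i + 1]) for i in pairs],
            [s * (signal[i] - signal[i + 1]) for i in pairs])


class LazyPacketTree(dict):
    """Packet nodes keyed by 'a'/'d' paths, computed only when first accessed."""

    def __init__(self, signal, maxlevel):
        super().__init__({"": list(signal)})  # only the root exists up front
        self.maxlevel = maxlevel

    def __missing__(self, key):
        # Reject non-path keys and paths deeper than maxlevel.
        if (not isinstance(key, str) or not key
                or len(key) > self.maxlevel or set(key) - {"a", "d"}):
            raise KeyError(key)
        parent = key[:-1]
        # self[parent] recursively materializes any missing ancestors.
        approx, detail = haar_step(self[parent])
        self[parent + "a"] = approx
        self[parent + "d"] = detail
        return self[key]
```

Compared with a callback, the caller never describes the tree shape up front: reading wp["ad"] creates only 'a', 'd', 'aa' and 'ad', and untouched branches such as 'da' are never computed.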

v0lta commented 3 months ago

I am closing this since we merged the PR.