v0lta / PyTorch-Wavelet-Toolbox

Differentiable fast wavelet transforms in PyTorch with GPU support.
https://pytorch-wavelet-toolbox.readthedocs.io
European Union Public License 1.2

Add lazy init to packets for partial tree expansion #96

Closed felixblanke closed 3 months ago

felixblanke commented 3 months ago

Adds lazy initialization to wavelet packet objects and addresses #91.

Adds the option to lazily initialize a wavelet packet object by passing lazy_init=True to __init__ or transform. In that case, we avoid expanding the full packet tree. If the user tries to access a key not yet contained in the dict, we calculate the missing coeffs on the fly.

This allows for significant speedups if we are interested only in a subset of the nodes rather than the full tree.
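The on-the-fly mechanism can be sketched with a toy 1D Haar packet in plain Python. This is only an illustration of the idea, not the ptwt code: the class name LazyHaarPacket1D, its data attribute, and the _expand helper are made up for this example, and the signal length is assumed to be divisible by 2**maxlevel.

```python
import math


class LazyHaarPacket1D:
    """Toy 1D Haar wavelet packet that expands nodes only on access.

    Hypothetical illustration; this is not the ptwt implementation.
    Assumes len(signal) is divisible by 2**maxlevel.
    """

    def __init__(self, signal, maxlevel):
        self.maxlevel = maxlevel
        # The root node (empty key) holds the input signal.
        self.data = {"": list(signal)}

    def __contains__(self, key):
        return key in self.data

    def _expand(self, parent):
        # One Haar analysis step. Both children are stored together,
        # because a single transform step yields both of them.
        x = self[parent]
        s = math.sqrt(2.0)
        self.data[parent + "a"] = [(x[i] + x[i + 1]) / s for i in range(0, len(x), 2)]
        self.data[parent + "d"] = [(x[i] - x[i + 1]) / s for i in range(0, len(x), 2)]

    def __getitem__(self, key):
        if len(key) > self.maxlevel or set(key) - {"a", "d"}:
            raise KeyError(f"Key {key} not found")
        if key not in self.data:
            # Compute the missing node from its parent, recursing upwards
            # until an already computed ancestor is reached.
            self._expand(key[:-1])
        return self.data[key]


packet = LazyHaarPacket1D([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], maxlevel=3)
print("aad" in packet)  # nothing below the root has been computed yet
coeffs = packet["aad"]  # expands the root, "a" and "aa" on the fly
print("aad" in packet, "da" in packet)
print(sorted(packet.data))
```

Only the nodes on the path to the requested key (and their siblings) end up in the dict; the "d" branch is never expanded unless someone asks for one of its children.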

The example of issue #69 is presented below:

import torch, ptwt
max_lev = 4
shape = (512, 512)

test_signal = torch.randn(shape)
full_packet = ptwt.WaveletPacket2D(test_signal, "haar", maxlevel=max_lev)
partial_packet = ptwt.WaveletPacket2D(test_signal, "haar", maxlevel=max_lev, lazy_init=True)

# Full expansion of the wavelet packet tree
wp_keys = ptwt.WaveletPacket2D.get_natural_order(max_lev)

# Partial expansion
keys = ['aaaa', 'aaad', 'aaah', 'aaav', 'aad', 'aah', 'aava', 'aavd',
        'aavh', 'aavv', 'ad', 'ah', 'ava', 'avd', 'avh', 'avv', 'd', 'h',
        'vaa', 'vad', 'vah', 'vav', 'vd', 'vh', 'vv']

print("Partial expansion: keys contained?", all(key in partial_packet for key in keys))

print("Init...")
# lazy initialization
[partial_packet[key] for key in keys]

print("Partial expansion: keys contained?", all(key in partial_packet for key in keys))
print("Partial expansion: wp_keys contained?", all(key in partial_packet for key in wp_keys))

print()
diffs = [((partial_packet[key] - full_packet[key]) ** 2).sum() for key in keys]
print(f"Squared difference: {sum(diffs)=}")

print(f"# Partial keys: {len(partial_packet.keys())}")
print(f"# Full keys: {len(full_packet.keys())}")

which outputs

Partial expansion: keys contained? False
Init...
Partial expansion: keys contained? True
Partial expansion: wp_keys contained? False

Squared difference: sum(diffs)=tensor(0.)
# Partial keys: 33
# Full keys: 341
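One reading that reproduces both key counts, sketched in plain Python. It rests on two assumptions that the output above is consistent with but does not state: the root is stored under the empty key and counted by keys(), and expanding a node stores all four of its children (a, h, v, d) at once. The helper parents is made up for this example.

```python
def parents(key):
    """Yield every proper prefix of key, starting with the root ("")."""
    for end in range(len(key)):
        yield key[:end]


max_lev = 4
branching = 4  # a, h, v and d children per 2D node

keys = ['aaaa', 'aaad', 'aaah', 'aaav', 'aad', 'aah', 'aava', 'aavd',
        'aavh', 'aavv', 'ad', 'ah', 'ava', 'avd', 'avh', 'avv', 'd', 'h',
        'vaa', 'vad', 'vah', 'vav', 'vd', 'vh', 'vv']

# Full tree: one root plus 4**level nodes on each level.
full = sum(branching**level for level in range(max_lev + 1))

# Partial tree: every ancestor of a requested key gets expanded once,
# and each expansion stores four children; the root itself adds one.
expanded = {parent for key in keys for parent in parents(key)}
partial = 1 + branching * len(expanded)

print(f"# Partial keys: {partial}")
print(f"# Full keys: {full}")
```

Under these assumptions the 25 requested keys trigger 8 expansions, which gives 33 stored nodes against 341 for the full tree.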

v0lta commented 3 months ago

I was merging, but then I thought that not computing the entire tree immediately should never be a problem. We eagerly computed the entire tree because we knew we needed it for the deepfake detection project. But generally speaking, we don't know which parts of the tree a user might want to expand. So I am thinking we should never compute the entire tree by default.

v0lta commented 3 months ago

If that's true, we don't need the lazy_init argument and can just change the behaviour under the hood. I am putting a PR together.

felixblanke commented 3 months ago

I merged main into this PR. The product of the tested parameters gets quite large: running the full test suite shows 11227 tests for the packets module alone. We might want to consider reducing this.
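To illustrate why the count grows so quickly: stacked test parameters multiply, so a boolean flag such as lazy_init doubles every existing combination. The parameter axes below are hypothetical and much smaller than the real test suite; only the wavelet and mode names are taken from this thread.

```python
import itertools

# Hypothetical parameter axes, not the actual ones in tests/test_packets.py.
params = {
    "wavelet": ["haar", "db4"],
    "mode": ["zero", "reflect", "constant", "boundary"],
    "max_lev": [1, 2, 3, 4],
    "lazy_init": [False, True],
}

# Each test case is one element of the Cartesian product of all axes.
cases = list(itertools.product(*params.values()))

# Dropping the lazy_init axis removes one factor from the product.
without_flag = len(cases) // len(params["lazy_init"])

print(f"with lazy_init: {len(cases)} cases")
print(f"without lazy_init: {without_flag} cases")
```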

v0lta commented 3 months ago

I think we should not have the lazy_init argument and make computation on request the default. That way this PR does not add extra tests.

felixblanke commented 3 months ago

We could change the default value of lazy_init to True. Then users could decide to opt in to eager initialization.

v0lta commented 3 months ago

But then we have to keep all of the old eager code. I don't see why we would do that. It just adds extra complexity.

v0lta commented 3 months ago

Wait, let's not duplicate work. @felixblanke, are you doing this already?

v0lta commented 3 months ago

okay I am doing it.

felixblanke commented 3 months ago

Ah, wait. I am on it :D

v0lta commented 3 months ago

Okay I am done and running the tests.

v0lta commented 3 months ago

It looks like we can just remove the old recursive code without too much of a hassle.

v0lta commented 3 months ago

Turns out it works in most cases:

FAILED tests/test_packets.py::test_partial_expansion_1d[zero-haar] - assert False
FAILED tests/test_packets.py::test_partial_expansion_1d[zero-db4] - assert False
FAILED tests/test_packets.py::test_partial_expansion_1d[reflect-haar] - assert False
FAILED tests/test_packets.py::test_partial_expansion_1d[reflect-db4] - assert False
FAILED tests/test_packets.py::test_partial_expansion_1d[constant-haar] - assert False
FAILED tests/test_packets.py::test_partial_expansion_1d[constant-db4] - assert False
FAILED tests/test_packets.py::test_partial_expansion_1d[boundary-haar] - assert False
FAILED tests/test_packets.py::test_partial_expansion_1d[boundary-db4] - assert False
FAILED tests/test_packets.py::test_partial_expansion_2d[zero-haar] - assert False
FAILED tests/test_packets.py::test_partial_expansion_2d[zero-db4] - assert False
FAILED tests/test_packets.py::test_partial_expansion_2d[reflect-haar] - assert False
FAILED tests/test_packets.py::test_partial_expansion_2d[reflect-db4] - assert False
FAILED tests/test_packets.py::test_partial_expansion_2d[constant-haar] - assert False
FAILED tests/test_packets.py::test_partial_expansion_2d[constant-db4] - assert False
FAILED tests/test_packets.py::test_partial_expansion_2d[boundary-haar] - assert False
FAILED tests/test_packets.py::test_partial_expansion_2d[boundary-db4] - assert False
FAILED tests/test_packets.py::test_inverse_boundary_packet_1d - KeyError: 'Key da not found'
FAILED tests/test_packets.py::test_inverse_boundary_packet_2d - KeyError: 'Key aa not found'

failed with the current code.

v0lta commented 3 months ago

I am leaving for today, but am happy to look at this again tomorrow.

felixblanke commented 3 months ago

I added the function initialize to the packet objects to initialize all coefficients as described by a set of keys. This feels less clumsy than using a list comprehension (and is possibly more memory efficient).
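A rough sketch of the difference, using a toy lazy mapping rather than the packet classes. LazyNodes and its compute callback are hypothetical; the only thing taken from this PR is the idea of an initialize method that accepts a collection of keys.

```python
class LazyNodes:
    """Toy lazy mapping; hypothetical, not the ptwt packet class."""

    def __init__(self, compute):
        self._compute = compute
        self.data = {}

    def __getitem__(self, key):
        # Compute and cache the value on first access.
        if key not in self.data:
            self.data[key] = self._compute(key)
        return self.data[key]

    def initialize(self, keys):
        """Compute and cache the node for every key; return nothing."""
        for key in keys:
            self[key]


nodes = LazyNodes(compute=str.upper)  # stand-in for the real transform

# Instead of the throwaway list [nodes[key] for key in keys]:
nodes.initialize(["aa", "ad"])
print(sorted(nodes.data))
```

The loop caches the same values as the list comprehension but does not build a temporary list of results that is discarded right away.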

v0lta commented 3 months ago

I am convinced this works. Let's merge.