v0lta / PyTorch-Wavelet-Toolbox

Differentiable fast wavelet transforms in PyTorch with GPU support.
https://pytorch-wavelet-toolbox.readthedocs.io
European Union Public License 1.2

3D doesn't work #82

Closed mahfuzalhasan closed 4 months ago

mahfuzalhasan commented 7 months ago

I have tried to pass a tensor of shape (B, C, D, H, W), but internally a dimension is added after B, so the tensor becomes (B, 1, C, D, H, W), and the conv3d call then fails. Is there any way to resolve this?
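A minimal sketch of why the extra axis is fatal. It uses only plain PyTorch, not ptwt: `torch.nn.functional.conv3d` accepts only 4D (unbatched) or 5D (batched) inputs, so a (B, 1, C, D, H, W) tensor is rejected. The 2x2x2 averaging kernel and the small 8-voxel volume are illustrative assumptions, not the library's actual filters.

```python
import torch
import torch.nn.functional as F

# (B, C, D, H, W) input, as described in the issue; sizes are illustrative.
x = torch.randn(1, 2, 8, 8, 8)

# conv3d weight layout: (out_channels, in_channels, kD, kH, kW).
# A single 2x2x2 averaging filter stands in for a wavelet filter here.
kernel = torch.full((1, 1, 2, 2, 2), 1.0 / 8)

# Mimic the extra axis reportedly inserted after B: (B, 1, C, D, H, W).
x6 = x.unsqueeze(1)

try:
    F.conv3d(x6, kernel, stride=2)
    rejected = False
except RuntimeError as err:
    # conv3d refuses 6D input; the thread shows a message of the form
    # "Expected 4D (unbatched) or 5D (batched) input to conv3d ...".
    rejected = True
    print(err)
```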

v0lta commented 7 months ago

Dear @mahfuzalhasan ,

Running

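import ptwt
import torch

data = torch.randn(1, 2, 64, 64, 64)  # (batch, channels, depth, height, width)
coeffs = ptwt.wavedec3(data, wavelet="haar", level=5)
# coeffs[-1] holds the finest-level detail coefficients as a dict.
print([(key, coeff.shape) for key, coeff in coeffs[-1].items()])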

prints

[('aad', torch.Size([1, 2, 32, 32, 32])), ('ada', torch.Size([1, 2, 32, 32, 32])), ('add', torch.Size([1, 2, 32, 32, 32])), ('daa', torch.Size([1, 2, 32, 32, 32])), ('dad', torch.Size([1, 2, 32, 32, 32])), ('dda', torch.Size([1, 2, 32, 32, 32])), ('ddd', torch.Size([1, 2, 32, 32, 32]))]

I don't see the extra dimension here. Could you please provide a minimal code example that reproduces your problem?
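The 32x32x32 shapes in the output above follow from each analysis level roughly halving every spatial axis; since `[-1]` is the 32-cube level, the returned list runs from coarsest to finest. A small sketch of that arithmetic, assuming the PyWavelets-style output-length rule `floor((n + filt_len - 1) / 2)` per level (for Haar, `filt_len = 2`). That ptwt's default padding yields exactly these lengths is an assumption here, not something the thread states, and `coeff_lengths` is a hypothetical helper.

```python
def coeff_lengths(n: int, filt_len: int, level: int) -> list:
    """Per-level detail-coefficient length along one axis, finest level first.

    Assumes len_out = floor((len_in + filt_len - 1) / 2) at every level,
    the rule PyWavelets uses for its padded extension modes.
    """
    lengths = []
    for _ in range(level):
        n = (n + filt_len - 1) // 2
        lengths.append(n)
    return lengths


# 64 samples, Haar (filter length 2), 5 levels -> [32, 16, 8, 4, 2]
print(coeff_lengths(64, 2, 5))
```

The first entry, 32, matches the finest-level shapes printed in the thread; under the same assumption the level-5 approximation has the length of the last entry, 2.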

mahfuzalhasan commented 4 months ago

Greetings Volta,

Sorry that I dropped out before the issue was resolved. Running the same code you provided, I get the following error:

../lib/python3.8/site-packages/ptwt/conv_transform_3.py", line 159, in wavedec3
    res = torch.nn.functional.conv3d(res_lll, dec_filt, stride=2)
RuntimeError: Expected 4D (unbatched) or 5D (batched) input to conv3d, but got input of size: [1, 1, 2, 64, 64, 64].

Can you please tell me what I should do to resolve this? I am using ptwt 0.1.6.

v0lta commented 4 months ago

Dear @mahfuzalhasan, version 0.1.7 introduced support for multi-dimensional inputs. Please update your installation to a more recent version by typing

pip install --upgrade ptwt

v0lta commented 4 months ago

To upgrade, you must also move to a more recent version of Python: with the 0.1.7 release, we dropped support for Python 3.8.
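Both requirements from the replies above can be checked before upgrading. A minimal sketch, assuming plain numeric `major.minor.patch` version strings (pre-release suffixes are not handled). The 0.1.7 threshold comes from this thread; "dropped support for Python 3.8" is taken here to mean Python 3.9 or newer is required. `parse_version` and `supports_channel_dim` are hypothetical helper names.

```python
import sys
from importlib.metadata import PackageNotFoundError, version
from typing import Tuple


def parse_version(text: str) -> Tuple[int, ...]:
    """'0.1.6' -> (0, 1, 6); assumes purely numeric components."""
    return tuple(int(part) for part in text.split("."))


def supports_channel_dim(ptwt_version: str) -> bool:
    """True for 0.1.7 or later, the release that added multi-dimensional inputs."""
    return parse_version(ptwt_version) >= (0, 1, 7)


# Assumed minimum after 0.1.7 dropped Python 3.8.
python_ok = sys.version_info >= (3, 9)

try:
    installed = version("ptwt")
    print(f"ptwt {installed}: channel axis supported = {supports_channel_dim(installed)}")
except PackageNotFoundError:
    print("ptwt is not installed")
print(f"Python {sys.version_info.major}.{sys.version_info.minor}: can upgrade = {python_ok}")
```

Comparing integer tuples instead of raw strings keeps a release such as 0.1.10 ordered after 0.1.7.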

v0lta commented 4 months ago

I am closing this due to a lack of activity. Feel free to reopen.