fbcotter / pytorch_wavelets

Pytorch implementation of 2D Discrete Wavelet (DWT) and Dual Tree Complex Wavelet Transforms (DTCWT) and a DTCWT based ScatterNet

maximum recursion depth exceeded in comparison #57

Open rachelglenn opened 3 months ago

rachelglenn commented 3 months ago

Hi. I am trying to run the provided example with pytorch_wavelets 1.3.0 and torch 2.1.0:

```python
import torch
from pytorch_wavelets import DWTForward, DWTInverse

xfm = DWTForward(J=3, wave='db3', mode='zero')
X = torch.randn(10,5,64,64)
Yl, Yh = xfm(X)
print(Yl.shape)     # expected: torch.Size([10, 5, 12, 12])
print(Yh[0].shape)  # expected: torch.Size([10, 5, 3, 34, 34])
print(Yh[1].shape)  # expected: torch.Size([10, 5, 3, 19, 19])
print(Yh[2].shape)  # expected: torch.Size([10, 5, 3, 12, 12])
ifm = DWTInverse(wave='db3', mode='zero')
Y = ifm((Yl, Yh))
```
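For reference, the expected shapes in the example follow from a simple per-level length rule. This is a minimal sketch, assuming the non-periodized DWT length formula floor((N + L - 1) / 2), which as far as I know is what PyWavelets' `pywt.dwt_coeff_len` returns for `mode='zero'`, with L = 6 filter taps for `db3`. The helper names are hypothetical and not part of pytorch_wavelets.

```python
def zero_mode_coeff_len(n: int, filter_len: int) -> int:
    """Length of one DWT output band for an input of length n.

    Assumption: the non-periodized rule floor((n + L - 1) / 2).
    """
    return (n + filter_len - 1) // 2


def expected_dwt_shapes(batch, channels, size, levels, filter_len):
    """Return (Yl_shape, [Yh_shapes]) for a square input, finest level first."""
    yh_shapes = []
    n = size
    for _ in range(levels):
        n = zero_mode_coeff_len(n, filter_len)
        # Each level keeps 3 oriented detail bands of size n x n.
        yh_shapes.append((batch, channels, 3, n, n))
    # The lowpass output has the size of the coarsest level.
    yl_shape = (batch, channels, n, n)
    return yl_shape, yh_shapes


# db3 has 6 taps; the example uses a 10x5x64x64 input and J=3.
yl, yh = expected_dwt_shapes(10, 5, 64, 3, filter_len=6)
print(yl)  # (10, 5, 12, 12)
for s in yh:
    print(s)  # (10, 5, 3, 34, 34), then (10, 5, 3, 19, 19), then (10, 5, 3, 12, 12)
```

It reproduces the 64 → 34 → 19 → 12 sequence printed in the example, so a working forward pass should return exactly these shapes.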

Running it fails at the forward transform (`Yl, Yh = xfm(X)`, line 6 of the cell), so the inverse transform is never reached:

```
RecursionError                            Traceback (most recent call last)
Cell In[4], line 6
      4 xfm = DWTForward(J=3, wave='db3', mode='zero')
      5 X = torch.randn(10,5,64,64)
----> 6 Yl, Yh = xfm(X)
      7 print(Yl.shape)
      8 torch.Size([10, 5, 12, 12])

File ~/home/lib/python3.9/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   16     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   17 else:
-> 18     return self._call_impl(*args, **kwargs)

File ~/home/lib/python3.9/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   22 # If we don't have any hooks, we want to skip the rest of the logic in
   23 # this function, and just call forward.
   24 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   25         or _global_backward_pre_hooks or _global_backward_hooks
   26         or _global_forward_hooks or _global_forward_pre_hooks):
-> 27     return forward_call(*args, **kwargs)
   29 try:
   30     result = None

File ~/home/lib/python3.9/site-packages/pytorch_wavelets/dwt/transform2d.py:70, in DWTForward.forward(self, x)
...
--> 276     if mode == 'zero':
    277         return 0
    278     elif mode == 'symmetric':

RecursionError: maximum recursion depth exceeded in comparison
```
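The traceback above skips its intermediate frames (the `...`), so it does not show which call is actually repeating. A minimal stdlib-only sketch that assumes nothing about pytorch_wavelets internals: `summarize_recursion` is a hypothetical helper that calls a function, catches `RecursionError`, and counts the repeated frames using `traceback.extract_tb`. A toy runaway function stands in for the failing call.

```python
import sys
import traceback
from collections import Counter


def summarize_recursion(fn, *args, top=5, **kwargs):
    """Call fn(*args, **kwargs). If it raises RecursionError, return the
    `top` most frequent (filename, lineno, function) frames with counts;
    otherwise return an empty list."""
    try:
        fn(*args, **kwargs)
    except RecursionError as err:
        # extract_tb walks the full traceback, including frames a notebook elides.
        frames = traceback.extract_tb(err.__traceback__)
        counts = Counter((f.filename, f.lineno, f.name) for f in frames)
        return counts.most_common(top)
    return []


# Hypothetical stand-in for the failing call: recurses until the limit is hit.
def runaway(n):
    return runaway(n + 1)


print("recursion limit:", sys.getrecursionlimit())
for (filename, lineno, name), count in summarize_recursion(runaway, 0, top=3):
    print(f"{name} ({filename}:{lineno}) appears in {count} frames")
```

With the repro's objects in scope, `summarize_recursion(xfm, X)` should list the pytorch_wavelets or torch frames that keep recurring, which is more informative in a bug report than the truncated notebook output.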