rachelglenn opened 3 months ago
Hi. I am trying to run the provided example (pytorch_wavelets 1.3.0, torch 2.1.0):
```python
import torch
from pytorch_wavelets import DWTForward, DWTInverse

xfm = DWTForward(J=3, wave='db3', mode='zero')
X = torch.randn(10, 5, 64, 64)
Yl, Yh = xfm(X)
print(Yl.shape)     # torch.Size([10, 5, 12, 12])
print(Yh[0].shape)  # torch.Size([10, 5, 3, 34, 34])
print(Yh[1].shape)  # torch.Size([10, 5, 3, 19, 19])
print(Yh[2].shape)  # torch.Size([10, 5, 3, 12, 12])

ifm = DWTInverse(wave='db3', mode='zero')
Y = ifm((Yl, Yh))
```

```
RecursionError                            Traceback (most recent call last)
Cell In[4], line 6
      4 xfm = DWTForward(J=3, wave='db3', mode='zero')
      5 X = torch.randn(10,5,64,64)
----> 6 Yl, Yh = xfm(X)
      7 print(Yl.shape)
      8 torch.Size([10, 5, 12, 12])

File ~/home/lib/python3.9/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
     16     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
     17 else:
-> 18     return self._call_impl(*args, **kwargs)

File ~/home/lib/python3.9/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
     22 # If we don't have any hooks, we want to skip the rest of the logic in
     23 # this function, and just call forward.
     24 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
     25         or _global_backward_pre_hooks or _global_backward_hooks
     26         or _global_forward_hooks or _global_forward_pre_hooks):
---> 27     return forward_call(*args, **kwargs)
     29 try:
     30     result = None

File ~/home/lib/python3.9/site-packages/pytorch_wavelets/dwt/transform2d.py:70, in DWTForward.forward(self, x)
...

--> 276 if mode == 'zero':
    277     return 0
    278 elif mode == 'symmetric':

RecursionError: maximum recursion depth exceeded in comparison
```
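For anyone checking the expected output in the example: the printed sizes follow the standard DWT output-length rule for non-periodized modes such as `mode='zero'`, i.e. floor((N + L - 1) / 2) per level, where L is the filter length (6 taps for db3). A small pure-Python sketch of that arithmetic, assuming this rule (the helper names `dwt_coeff_len` and `expected_sizes` are hypothetical and not part of pytorch_wavelets):

```python
from typing import List


def dwt_coeff_len(n: int, filt_len: int) -> int:
    # Output length along one axis for non-periodized modes such as 'zero'
    # (assumed rule: floor((n + filt_len - 1) / 2)).
    return (n + filt_len - 1) // 2


def expected_sizes(n: int, filt_len: int, levels: int) -> List[int]:
    # Spatial size of the detail coefficients at each level; the lowpass
    # output Yl has the same size as the last level.
    sizes = []
    for _ in range(levels):
        n = dwt_coeff_len(n, filt_len)
        sizes.append(n)
    return sizes


# db3 has 2 * 3 = 6 filter taps; 64x64 input with J=3 levels.
print(expected_sizes(64, 6, 3))  # [34, 19, 12]
```

This reproduces the 34, 19 and 12 sizes of `Yh[0]`, `Yh[1]` and `Yh[2]` (and 12 for `Yl`), so the shapes shown in the example are consistent with `mode='zero'`. The failure itself happens before any of these prints.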