fbcotter / pytorch_wavelets

Pytorch implementation of 2D Discrete Wavelet (DWT) and Dual Tree Complex Wavelet Transforms (DTCWT) and a DTCWT based ScatterNet

How to change default tensor dtype in pytorch_wavelets #47

Closed: zxyl1003 closed this issue 1 year ago

zxyl1003 commented 1 year ago

I want to use the torch.float16 data type when computing the DWT, but I got an error:

RuntimeError: expected scalar type Half but found Float

I think the reason is that the DWT function uses the torch.float32 data type, while my input is torch.float16. Is there any way to change the DWT function's default dtype from torch.float32 to torch.float16?

abeyang00 commented 1 year ago

Use pytorch_wavelets.DWTForward(wave='haar').to(torch.float16)
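Why the one-liner works: `nn.Module.to(dtype)` casts every floating-point parameter and buffer of a module, so any fixed filter taps the transform stores are converted along with it. The sketch below is self-contained and does not use pytorch_wavelets itself. `ToyFilterBank` and `cast_to_input_dtype` are hypothetical names standing in for a transform like `DWTForward`, on the assumption that such a transform keeps its filters as float32 buffers created with the default dtype. To stay portable, the sketch only compares dtypes and skips the convolution, since half-precision convolutions may not be available on CPU in every PyTorch version.

```python
import torch
import torch.nn as nn


class ToyFilterBank(nn.Module):
    """Hypothetical stand-in for a transform such as DWTForward: it keeps
    fixed filter taps in a buffer created with the default dtype (float32)."""

    def __init__(self):
        super().__init__()
        # Haar low-pass taps; torch.tensor of Python floats uses the default dtype
        self.register_buffer("h0", torch.tensor([0.5 ** 0.5, 0.5 ** 0.5]))

    def forward(self, x):
        # A real transform would convolve x with self.h0. Here we only check
        # that the dtypes agree, because a mismatch is what triggers errors
        # like "expected scalar type Half but found Float".
        if x.dtype != self.h0.dtype:
            raise RuntimeError(f"input is {x.dtype} but filters are {self.h0.dtype}")
        return x


def cast_to_input_dtype(module, x):
    """Hypothetical helper: cast all floating-point parameters and buffers
    of `module` to the dtype of `x`. Module.to works in place and returns the module."""
    return module.to(x.dtype)


x = torch.zeros(1, 3, 8, 8, dtype=torch.float16)
xfm = ToyFilterBank()
print(xfm.h0.dtype)  # torch.float32

xfm = cast_to_input_dtype(xfm, x)
print(xfm.h0.dtype)  # torch.float16
xfm(x)  # dtypes now match, so no error is raised
```

With the real library, the equivalent is the call given above, `DWTForward(wave='haar').to(torch.float16)`. Because `.to` is generic `nn.Module` behaviour, the same cast should presumably also work on the inverse transform module (`DWTInverse`), but this has not been verified here.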

zxyl1003 commented 1 year ago

Thank you!