fbcotter / pytorch_wavelets

Pytorch implementation of 2D Discrete Wavelet (DWT) and Dual Tree Complex Wavelet Transforms (DTCWT) and a DTCWT based ScatterNet
943 stars, 146 forks

Support for double precision #16

Closed: fsherry closed this issue 4 years ago

fsherry commented 4 years ago

Thank you for publishing this package; I have found it very useful. My application uses double-precision floating-point numbers, so I had to change some of the dtypes in your code to make it work. Would you consider adding an optional dtype argument to your module constructors so that other dtypes can be used more easily?
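For illustration, here is a minimal sketch of what an optional dtype constructor argument could look like. `ToyHaarLowpass` is a hypothetical module, not part of pytorch_wavelets; it computes only the Haar LL band and assumes that falling back to `torch.get_default_dtype()` when no dtype is given is the desired behaviour.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyHaarLowpass(nn.Module):
    """Hypothetical module, NOT part of pytorch_wavelets: one Haar LL band."""

    def __init__(self, dtype=None):
        super().__init__()
        # Assumption: with no explicit dtype, use the global default
        # (float32 unless changed), mirroring torch.zeros/torch.full.
        dtype = dtype if dtype is not None else torch.get_default_dtype()
        # 2x2 Haar LL filter: outer product of [1, 1]/sqrt(2) with itself,
        # i.e. every tap equals 0.5.
        filt = torch.full((1, 1, 2, 2), 0.5, dtype=dtype)
        # A buffer is saved with the module but is not a trainable parameter.
        self.register_buffer("filt", filt)

    def forward(self, x):
        # Non-overlapping 2x2 blocks; x must have the same dtype as the filter.
        return F.conv2d(x, self.filt, stride=2)


lp64 = ToyHaarLowpass(dtype=torch.float64)
x = torch.randn(2, 1, 6, 6, dtype=torch.float64)
y = lp64(x)
print(y.dtype, tuple(y.shape))  # torch.float64 (2, 1, 3, 3)
```

Because the dtype is fixed when the buffer is created, an explicit argument like this avoids touching any global state.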

fbcotter commented 4 years ago

Thanks @fsherry! It's been a while since I looked at the precision code, but I think my modules should accept both single and double precision. If you call torch.set_default_dtype(torch.float64), it might just work. Let me know if it doesn't and I'll look into it.
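A minimal sketch of why the default dtype can matter: a module that builds its filters with `torch.tensor()` at construction time picks up whatever `torch.get_default_dtype()` is at that moment. `ToyHaarDWT` below is a hypothetical one-level 2D Haar transform, not the pytorch_wavelets API, and whether the library builds its filters the same way is an assumption here. The sketch also shows `nn.Module.double()` as an alternative that casts floating-point parameters and buffers after construction.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyHaarDWT(nn.Module):
    """Hypothetical one-level 2D Haar DWT; NOT the pytorch_wavelets API."""

    def __init__(self):
        super().__init__()
        # torch.tensor() on Python floats uses torch.get_default_dtype(),
        # so the filters' dtype is fixed at construction time.
        lo = torch.tensor([1.0, 1.0]) / 2 ** 0.5
        hi = torch.tensor([1.0, -1.0]) / 2 ** 0.5
        # LL, LH, HL, HH analysis filters, shape (4, 1, 2, 2).
        filts = torch.stack([torch.outer(a, b)
                             for a in (lo, hi) for b in (lo, hi)]).unsqueeze(1)
        # Buffers are moved/cast by .to() and .double() like parameters.
        self.register_buffer("filts", filts)

    def forward(self, x):
        # (N, 1, H, W) with even H, W -> (N, 4, H/2, W/2)
        return F.conv2d(x, self.filts, stride=2)


x = torch.randn(1, 1, 8, 8, dtype=torch.float64)

# Option 1: raise the default dtype before building the module.
prev = torch.get_default_dtype()
torch.set_default_dtype(torch.float64)
try:
    dwt = ToyHaarDWT()
finally:
    torch.set_default_dtype(prev)  # restore the global setting
y = dwt(x)
print(dwt.filts.dtype, y.dtype)  # torch.float64 torch.float64

# Option 2: build with the float32 default, then cast buffers/parameters.
dwt_cast = ToyHaarDWT().double()
print(dwt_cast.filts.dtype)  # torch.float64
```

Restoring the previous default is a precaution so that unrelated tensors created later are not silently float64. Without either option, passing a float64 input to a module whose filters are float32 makes `F.conv2d` raise a dtype-mismatch error, which is consistent with the kind of failure described in the original report.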

fbcotter commented 4 years ago

I have some tests for double precision in https://github.com/fbcotter/pytorch_wavelets/blob/52964116161a6f6ba409586e42d7c85f64cb36cb/tests/test_dwt.py#L161, which leads me to believe it should work.

fsherry commented 4 years ago

Hi @fbcotter, thanks very much for your help. I didn't know about torch.set_default_dtype before, but with it everything works as you said.