Closed fsherry closed 4 years ago
Thanks @fsherry! It's been a while since I looked at the precision code, but I think my modules should accept both. If you call `torch.set_default_dtype(torch.float64)`, it might just work. Let me know if it doesn't and I'll have a look into it.
I have some tests for double precision in https://github.com/fbcotter/pytorch_wavelets/blob/52964116161a6f6ba409586e42d7c85f64cb36cb/tests/test_dwt.py#L161, which lead me to believe it should work.
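For illustration, a minimal sketch of the suggestion above: setting the default dtype before creating any tensors makes newly constructed tensors (and, by extension, the filter buffers of modules built afterwards) default to `float64`. The `DWTForward`/`DWTInverse` usage shown in the trailing comment is an assumption based on the pytorch_wavelets README, so it is left as a comment rather than executed here.

```python
import torch

# Switch the global default floating-point dtype to double precision.
# Any tensor created after this call without an explicit dtype is float64.
torch.set_default_dtype(torch.float64)

x = torch.randn(1, 3, 32, 32)
print(x.dtype)  # torch.float64

# With pytorch_wavelets installed, the same idea should apply (untested sketch):
#   from pytorch_wavelets import DWTForward, DWTInverse
#   xfm = DWTForward(J=2, wave='db2')  # filter buffers created as float64
#   yl, yh = xfm(x)
```

Note that `set_default_dtype` is process-global, so it should be called once at startup, before constructing any modules whose buffers need to be double precision.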
Hi @fbcotter, thanks very much for your help. I didn't know about `torch.set_default_dtype` before, but with it everything works as you described.
Thank you for publishing this package; I have found it very useful. My application uses double-precision floating-point numbers, so I had to change some of the dtypes in your code to make it work for me. Would you consider adding an optional `dtype` argument to your module constructors, to make using other dtypes easier?