bwpriest closed this issue 1 year ago
This was an easy fix: I just removed the `.float()` calls from the code, and with the new refactoring there shouldn't be any hardcoding of this anywhere in the library.
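For concreteness, here is a minimal sketch of the kind of change being described, where the cast's target dtype is derived from the configured ftype rather than hardcoded. The `mm.ftype` and `MUYGPYS_FTYPE` names come from the issue below, but the `to_backend_dtype` helper and its string-keyed mapping are assumptions for illustration, not MuyGPyS's actual API:

```python
import torch

# Hypothetical helper: map the configured ftype to a torch dtype instead of
# unconditionally calling .float() (which always forces float32).
_DTYPE_MAP = {"32": torch.float32, "64": torch.float64}

def to_backend_dtype(tensor: torch.Tensor, ftype: str = "64") -> torch.Tensor:
    """Cast `tensor` to the dtype implied by the backend ftype setting."""
    return tensor.to(_DTYPE_MAP[ftype])

# Before (hardcoded):  out = out.float()
# After (ftype-aware): out = to_backend_dtype(out, ftype)
```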
Ah, it's always nice when the fix is easy.
Should I remove the flag in `MuyGPyS.examples.muygps_torch` that forces an ftype of 32?
I guess we can? On the other hand, it seems like we should keep to the "default" for torch (32-bit floats), whereas the default for MuyGPyS is 64-bit. I'm torn.
Let's leave the flag for now.
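As a sketch of what that flag amounts to (an assumption about the example's behavior, not the actual `MuyGPyS.examples.muygps_torch` code): honor an explicit `MUYGPYS_FTYPE` setting if present, otherwise fall back to torch's customary 32-bit default.

```python
import os
import torch

# Hypothetical sketch of the flag under discussion: respect an explicit
# MUYGPYS_FTYPE setting, but default to torch's conventional float32.
ftype = os.environ.get("MUYGPYS_FTYPE", "32")
torch.set_default_dtype(torch.float64 if ftype == "64" else torch.float32)
```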
This issue was addressed in PR #146.
We currently need to use `$ export MUYGPYS_FTYPE=32` for `MuyGPyS.torch.muygpys_layer` to perform correctly during optimization. This is because `.float()` is hardcoded therein. We need to modify this behavior so that it depends on `mm.ftype`.
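To see why the hardcoded cast forces the 32-bit setting (a minimal illustration, not library code): any 64-bit tensor passed through `.float()` is silently downcast, so a 64-bit MuyGPyS configuration would disagree with the arithmetic inside the torch layer.

```python
import torch

# .float() always yields float32, regardless of the input's precision,
# so a 64-bit pipeline silently loses precision inside the torch layer.
x = torch.randn(4, dtype=torch.float64)
print(x.dtype)          # torch.float64
print(x.float().dtype)  # torch.float32
```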