maxjiang93 closed this issue 6 years ago.
This has already been reported in https://github.com/pytorch/extension-cpp/issues/14. Let me know if the fixes there work for you.
Oh, sorry, I didn't notice #14. Though I'm not sure what the suggested fix in #14 is. It seems that the OP's suggested fix was to recompile PyTorch from scratch? I didn't go through recompiling PyTorch from scratch. My fix above, which explicitly casts the constants to scalar_t, seems easier :)
I couldn't compile the CUDA code with:
python setup.py install
I'm rather surprised that this issue has not been brought up before. Here's the error message:
Most of this is probably irrelevant except for the gcc version:
Here's my hacky fix that worked: simply wrapping the double literals in scalar_t(...). I'm not sure this is the most elegant solution.
lltm_cuda_kernel.cu, lines 26-29:
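As a rough, host-only C++ illustration of the casting idea (the `elu` helper below is a hypothetical stand-in modeled on a templated activation function, not the actual contents of lines 26-29): when a helper templated on `scalar_t` passes a bare `0.0` alongside a `scalar_t` argument, and `scalar_t` is `float`, the call mixes `double` and `float`. If only `fmax(float, float)` and `fmax(double, double)` are visible, as can happen in device code, overload resolution is ambiguous. Casting every literal to `scalar_t` sidesteps that.

```cpp
#include <cmath>

// Hypothetical host-side analog of a templated activation helper in a .cu file.
// The problematic pattern is fmax(0.0, z): with scalar_t = float this passes
// (double, float), which is ambiguous if only fmax(float, float) and
// fmax(double, double) are available. Wrapping each literal in scalar_t(...)
// keeps every argument the same type, so exactly one overload matches.
template <typename scalar_t>
scalar_t elu(scalar_t z, scalar_t alpha = scalar_t(1.0)) {
  const scalar_t zero = scalar_t(0.0);  // was: 0.0 (a double literal)
  const scalar_t one = scalar_t(1.0);   // was: 1.0 (a double literal)
  // ELU: z for z > 0, alpha * (exp(z) - 1) for z <= 0.
  return std::fmax(zero, z) + std::fmin(zero, alpha * (std::exp(z) - one));
}
```

With `scalar_t = float`, `elu(2.0f)` stays entirely in `float` and returns `2.0f`. In the real `.cu` file the same cast would go inside the `__device__` helpers. Whether the uncast version fails to compile may depend on the CUDA and gcc versions in use.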