aiqm / torchani

Accurate Neural Network Potential on PyTorch
https://aiqm.github.io/torchani/
MIT License

Add a function to recast all buffer tensors #473

IgnacioJPickering closed this issue 4 years ago

IgnacioJPickering commented 4 years ago

When the model is used from C++, model.to(torch::kDouble) converts all parameters and all buffers to double. As a consequence, every torch.long buffer in the model is converted to torch.double, which we definitely don't want.

After that, it becomes a hassle to convert specific named buffers to double. I didn't find anything in the C++ API that exposes a handle to a torch::jit::script::Module buffer, so the workaround is to add a function on the Python side that recasts the torch.long buffers to double once again.
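The issue does not show the helper itself. The following is a minimal sketch of a Python-side buffer recast, assuming the goal is to change the dtype of floating-point buffers while leaving integer (torch.long) buffers untouched. The helper name recast_float_buffers and the Toy module are hypothetical, not part of the torchani API.

```python
import torch
from torch import nn


def recast_float_buffers(module: nn.Module, dtype: torch.dtype) -> None:
    """Hypothetical helper: cast only floating-point buffers to `dtype`.

    Integer buffers (e.g. torch.long index tensors) keep their dtype.
    """
    for submodule in module.modules():
        # recurse=False: only the buffers registered directly on this submodule;
        # list(...) so we do not iterate while reassigning
        for name, buf in list(submodule.named_buffers(recurse=False)):
            if buf.is_floating_point():
                # Assigning to a registered buffer name replaces the buffer
                setattr(submodule, name, buf.to(dtype))


class Toy(nn.Module):
    # Hypothetical module with one integer and one floating-point buffer
    def __init__(self) -> None:
        super().__init__()
        self.register_buffer("species", torch.tensor([1, 6, 8], dtype=torch.long))
        self.register_buffer("cutoff", torch.tensor(5.2))


toy = Toy()
recast_float_buffers(toy, torch.double)
print(toy.species.dtype, toy.cutoff.dtype)  # torch.int64 torch.float64
```

Checking is_floating_point() per buffer avoids the blanket cast described above, so integer buffers never pass through double.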

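I'm not sure whether this is a feature or a bug of libtorch, but it differs from the Python behaviour, where Module.to(dtype) only casts floating-point parameters and buffers and leaves integral ones with their original dtype.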
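For comparison, here is a small Python check of that behaviour with a plain nn.Module. The module and buffer names are illustrative only.

```python
import torch
from torch import nn


class BufferDemo(nn.Module):
    # Illustrative module: float parameters, one long buffer, one float buffer
    def __init__(self) -> None:
        super().__init__()
        self.linear = nn.Linear(3, 1)                             # float32 parameters
        self.register_buffer("species", torch.tensor([1, 6, 8]))  # int64 buffer
        self.register_buffer("cutoff", torch.tensor(5.2))         # float32 buffer


model = BufferDemo().to(torch.double)
print(model.linear.weight.dtype)  # torch.float64
print(model.cutoff.dtype)         # torch.float64
print(model.species.dtype)        # torch.int64 (integral buffer is not cast)
```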

Sorry for the long explanation, but this is complex.